HaryaniAnjali commited on
Commit
6c8529b
·
verified ·
1 Parent(s): 14d3de3

Create emotion_model.py

Browse files
Files changed (1) hide show
  1. emotion_model.py +52 -0
emotion_model.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+
6
+ # Load the Keras model
7
+ keras_model = tf.keras.models.load_model('wav2vec_model.h5')
8
+
9
+ # Create a PyTorch model with the same architecture
10
+ class EmotionClassifier(torch.nn.Module):
11
+ def __init__(self, input_shape, num_classes):
12
+ super().__init__()
13
+ # Adjust this architecture to match your Keras model
14
+ self.flatten = torch.nn.Flatten()
15
+ self.layers = torch.nn.Sequential(
16
+ torch.nn.Linear(input_shape, 128),
17
+ torch.nn.ReLU(),
18
+ torch.nn.Dropout(0.3),
19
+ torch.nn.Linear(128, 64),
20
+ torch.nn.ReLU(),
21
+ torch.nn.Dropout(0.3),
22
+ torch.nn.Linear(64, num_classes)
23
+ )
24
+
25
+ def forward(self, x):
26
+ x = self.flatten(x)
27
+ return self.layers(x)
28
+
29
+ # Create PyTorch model
30
+ # Adjust these parameters based on your Keras model
31
+ input_shape = 13 * 128 # n_mfcc * max_length
32
+ num_classes = 7 # Number of emotions
33
+ pytorch_model = EmotionClassifier(input_shape, num_classes)
34
+
35
+ # Copy weights from Keras to PyTorch
36
+ # This would need to be adjusted based on your exact architecture
37
+ for i, layer in enumerate(keras_model.layers):
38
+ if isinstance(layer, tf.keras.layers.Dense):
39
+ # Get Keras weights and bias
40
+ keras_weights = layer.get_weights()[0]
41
+ keras_bias = layer.get_weights()[1]
42
+
43
+ # Find the corresponding PyTorch layer
44
+ # This is simplified; you'd need to match layers properly
45
+ pytorch_layer = pytorch_model.layers[i * 2]
46
+
47
+ # Copy weights and bias
48
+ pytorch_layer.weight.data = torch.tensor(keras_weights.T, dtype=torch.float32)
49
+ pytorch_layer.bias.data = torch.tensor(keras_bias, dtype=torch.float32)
50
+
51
+ # Save the PyTorch model
52
+ torch.save(pytorch_model.state_dict(), 'emotion_model.pt')