ashutoshroy02 committed on
Commit
0f74a15
·
verified ·
1 Parent(s): 3c0a049

Create model.py

Browse files
Files changed (1) hide show
  1. src/model.py +31 -0
src/model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import Wav2Vec2Model
4
+
5
class Wav2Vec2_LSTM_MultiTask(nn.Module):
    """Wav2Vec2 encoder followed by a BiLSTM and two task heads.

    Data flow: ``input_values`` -> Wav2Vec2 ``last_hidden_state`` ->
    bidirectional LSTM -> mean over the sequence axis -> shared linear layer
    with ReLU -> (emotion head, stress head).

    The default keyword arguments reproduce the original fixed architecture
    (``facebook/wav2vec2-base``, 2-layer BiLSTM with 256 units per direction,
    256-wide shared layer). Sub-module attribute names are part of the
    ``state_dict`` layout and must not be renamed.

    Args:
        num_emotions: Output width of the emotion head.
        pretrained_name: Checkpoint identifier or path handed to
            ``Wav2Vec2Model.from_pretrained``.
        lstm_hidden_size: Hidden units per LSTM direction.
        lstm_num_layers: Number of stacked LSTM layers.
        shared_size: Width of the shared layer that feeds both heads.
    """

    def __init__(
        self,
        num_emotions: int,
        *,
        pretrained_name: str = "facebook/wav2vec2-base",
        lstm_hidden_size: int = 256,
        lstm_num_layers: int = 2,
        shared_size: int = 256,
    ) -> None:
        super().__init__()

        self.wav2vec = Wav2Vec2Model.from_pretrained(pretrained_name)

        # Read the feature width from the loaded encoder instead of assuming
        # 768, so a checkpoint with a different hidden size still fits.
        encoder_dim = self.wav2vec.config.hidden_size

        self.lstm = nn.LSTM(
            input_size=encoder_dim,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            batch_first=True,
            bidirectional=True,
        )

        # A bidirectional LSTM concatenates both directions, so its output
        # feature width is twice the per-direction hidden size.
        self.shared_fc = nn.Linear(2 * lstm_hidden_size, shared_size)
        self.emotion_head = nn.Linear(shared_size, num_emotions)
        self.stress_head = nn.Linear(shared_size, 1)

    def forward(self, input_values: torch.Tensor):
        """Run the encoder, BiLSTM and both heads on a batch.

        Args:
            input_values: Tensor passed unchanged to the Wav2Vec2 encoder
                (presumably batched raw audio, as that model expects --
                confirm against the data pipeline).

        Returns:
            A tuple ``(emotion_out, stress_out)``: the raw outputs of the two
            linear heads, with last dimensions ``num_emotions`` and ``1``.
            No softmax or sigmoid is applied here.
        """
        frames = self.wav2vec(input_values).last_hidden_state

        lstm_out, _ = self.lstm(frames)

        # batch_first=True puts the sequence on dim 1; average it away.
        # NOTE(review): this averages every time step, including any that
        # originate from padded audio -- confirm batches are unpadded or add
        # masking before relying on this with variable-length inputs.
        pooled = torch.mean(lstm_out, dim=1)

        shared = torch.relu(self.shared_fc(pooled))
        return self.emotion_head(shared), self.stress_head(shared)