Raemih committed on
Commit
efeeffd
·
verified ·
1 Parent(s): 9bb9236

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +82 -0
model.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel
4
+
5
+
6
class LanguageIdentificationLayer(nn.Module):
    """Two-layer MLP head that maps a feature vector to per-language logits.

    The head is Linear -> ReLU -> Dropout -> Linear and operates on the last
    dimension of its input.

    Args:
        hidden_size: Size of the last dimension of the input features; also
            the width of the intermediate layer.
        num_languages: Number of language classes (size of the output's last
            dimension).
        dropout: Dropout probability applied after the ReLU. Defaults to the
            previously hard-coded value of 0.1.
    """

    def __init__(self, hidden_size, num_languages=3, dropout=0.1):
        super().__init__()

        # Kept as an nn.Sequential with the same layer order so parameter
        # names (lid_head.0.*, lid_head.3.*) stay compatible with existing
        # checkpoints.
        self.lid_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_languages),
        )

    def forward(self, x):
        """Return unnormalised language logits for ``x``.

        Args:
            x: Tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``num_languages``.
        """
        return self.lid_head(x)
19
+
20
+
21
class LanguageAwareEmotionHead(nn.Module):
    """Emotion classifier conditioned on the predicted language.

    The most likely language is looked up in a learned embedding table, added
    to the input features, and the sum is passed through
    Linear -> ReLU -> Dropout -> Linear to produce emotion logits.

    Args:
        hidden_size: Size of the last dimension of the input features; also
            the language-embedding width.
        num_emotions: Number of emotion classes (size of the output's last
            dimension).
        num_languages: Number of rows in the language-embedding table; must
            match the last dimension of ``language_logits``.
        dropout: Dropout probability applied before the final classifier.
            Defaults to the previously hard-coded value of 0.1.
    """

    def __init__(self, hidden_size, num_emotions=5, num_languages=3, dropout=0.1):
        super().__init__()

        # Attribute names and registration order are unchanged so existing
        # state dicts keep loading.
        self.lang_embeddings = nn.Embedding(num_languages, hidden_size)
        self.pre_classifier = nn.Linear(hidden_size, hidden_size)
        self.classifier = nn.Linear(hidden_size, num_emotions)
        self.dropout = nn.Dropout(dropout)

    def forward(self, features, language_logits):
        """Return emotion logits for ``features`` given language logits.

        Args:
            features: Tensor whose last dimension is ``hidden_size``.
            language_logits: Tensor whose last dimension is ``num_languages``
                and whose leading dimensions match those of ``features``.

        Returns:
            Tensor of emotion logits with a last dimension of
            ``num_emotions``.
        """
        # Hard decision: argmax yields integer indices and is not
        # differentiable, so no gradient reaches ``language_logits`` through
        # this head.
        language_ids = torch.argmax(language_logits, dim=-1)
        lang_embed = self.lang_embeddings(language_ids)

        conditioned = features + lang_embed
        conditioned = torch.relu(self.pre_classifier(conditioned))
        conditioned = self.dropout(conditioned)

        return self.classifier(conditioned)
46
+
47
+
48
class MMSForMultilingualSER(Wav2Vec2PreTrainedModel):
    """Wav2Vec2 encoder with a language-aware emotion classification head.

    The encoder's last hidden states are mean-pooled over time, passed
    through dropout, fed to a language-identification head, and the resulting
    language logits condition the emotion head. Only the emotion logits are
    returned.

    Args:
        config: Wav2Vec2 configuration. ``config.hidden_size`` sets the width
            of both heads. The optional attributes ``config.num_languages``
            (default 3) and ``config.num_emotions`` (default 5) set the
            number of classes; when absent the original defaults are used.
    """

    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)

        hidden_size = config.hidden_size
        # Fall back to the heads' original defaults when the config does not
        # carry these fields, so existing configs behave exactly as before.
        num_languages = getattr(config, "num_languages", 3)
        num_emotions = getattr(config, "num_emotions", 5)

        self.lid_layer = LanguageIdentificationLayer(
            hidden_size, num_languages=num_languages
        )
        self.emotion_head = LanguageAwareEmotionHead(
            hidden_size, num_emotions=num_emotions, num_languages=num_languages
        )
        self.dropout = nn.Dropout(0.1)

        # post_init() runs weight initialisation plus the library's other
        # end-of-constructor setup; it is the documented replacement for a
        # bare init_weights() call.
        self.post_init()

    def forward(self, input_values, attention_mask=None):
        """Compute emotion logits for a batch of audio inputs.

        Args:
            input_values: Input passed straight to ``Wav2Vec2Model``
                (presumably raw waveforms shaped (batch, samples) -- confirm
                against the caller).
            attention_mask: Optional sample-level mask passed to the encoder.
                When given, padded frames are excluded from the pooled
                average.

        Returns:
            Emotion logits produced by ``self.emotion_head``.
        """
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
        )
        hidden_states = outputs.last_hidden_state

        if attention_mask is None:
            pooled = hidden_states.mean(dim=1)
        else:
            # The sample-level mask must be converted to frame resolution
            # before it can be applied to the encoder output; averaging over
            # padded frames would otherwise dilute the utterance embedding.
            frame_mask = self._get_feature_vector_attention_mask(
                hidden_states.shape[1], attention_mask
            )
            weights = frame_mask.unsqueeze(-1).to(hidden_states.dtype)
            # clamp guards against division by zero for a fully masked row.
            valid_frames = weights.sum(dim=1).clamp(min=1.0)
            pooled = (hidden_states * weights).sum(dim=1) / valid_frames

        pooled = self.dropout(pooled)

        language_logits = self.lid_layer(pooled)
        emotion_logits = self.emotion_head(pooled, language_logits)

        return emotion_logits