import torch
import torch.nn as nn
from transformers import Wav2Vec2Model


class Wav2Vec2_LSTM_MultiTask(nn.Module):
    def __init__(self, num_emotions):
        super().__init__()
        # Pretrained wav2vec 2.0 encoder; the base model emits 768-dim frame features
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        # Bidirectional LSTM over the frame sequence; output dim is 2 * 256 = 512
        self.lstm = nn.LSTM(
            input_size=768,
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Shared projection feeding both task-specific heads
        self.shared_fc = nn.Linear(512, 256)
        self.emotion_head = nn.Linear(256, num_emotions)  # emotion class logits
        self.stress_head = nn.Linear(256, 1)              # scalar stress score

    def forward(self, input_values):
        # input_values: raw waveform batch of shape (batch, samples)
        outputs = self.wav2vec(input_values)
        x = outputs.last_hidden_state         # (batch, frames, 768)
        lstm_out, _ = self.lstm(x)            # (batch, frames, 512)
        pooled = torch.mean(lstm_out, dim=1)  # mean-pool over time -> (batch, 512)
        shared = torch.relu(self.shared_fc(pooled))
        return self.emotion_head(shared), self.stress_head(shared)
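For reference, a minimal smoke test of the module above; the batch size, clip length, and num_emotions=4 are arbitrary placeholders, not settings from the original:

# Shape check with placeholder values (num_emotions=4 is an assumption).
model = Wav2Vec2_LSTM_MultiTask(num_emotions=4)
model.eval()

# Two one-second clips at 16 kHz; wav2vec2-base expects 16 kHz audio.
dummy_audio = torch.randn(2, 16000)

with torch.no_grad():
    emotion_logits, stress_score = model(dummy_audio)

print(emotion_logits.shape)  # torch.Size([2, 4])
print(stress_score.shape)    # torch.Size([2, 1])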