import torch
import torch.nn as nn
from transformers import Wav2Vec2Model


class Wav2Vec2_LSTM_MultiTask(nn.Module):
    """Multi-task speech model: a pretrained wav2vec 2.0 encoder feeds a
    BiLSTM; the mean-pooled sequence is projected into a shared
    representation driving two heads, emotion classification logits and
    a scalar stress value."""

    def __init__(self, num_emotions):
        super().__init__()
        # Pretrained wav2vec 2.0 base encoder (768-dim hidden states).
        self.wav2vec = Wav2Vec2Model.from_pretrained(
            "facebook/wav2vec2-base"
        )
        # Bidirectional LSTM over the encoder's frame sequence;
        # output is 2 * 256 = 512-dim because bidirectional=True.
        self.lstm = nn.LSTM(
            input_size=768,
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Shared projection, then one head per task.
        self.shared_fc = nn.Linear(512, 256)
        self.emotion_head = nn.Linear(256, num_emotions)  # class logits
        self.stress_head = nn.Linear(256, 1)              # scalar regression

    def forward(self, input_values):
        # input_values: raw waveform batch, shape (batch, samples).
        outputs = self.wav2vec(input_values)
        x = outputs.last_hidden_state          # (batch, frames, 768)
        lstm_out, _ = self.lstm(x)             # (batch, frames, 512)
        pooled = torch.mean(lstm_out, dim=1)   # mean-pool over time
        shared = torch.relu(self.shared_fc(pooled))
        emotion_logits = self.emotion_head(shared)
        stress_value = self.stress_head(shared)
        return emotion_logits, stress_value
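
# Minimal usage sketch (not part of the original file): the batch size,
# the one-second 16 kHz dummy waveform, and num_emotions=7 below are
# illustrative assumptions, not values taken from the source.
if __name__ == "__main__":
    model = Wav2Vec2_LSTM_MultiTask(num_emotions=7)
    model.eval()
    waveform = torch.randn(2, 16000)  # (batch=2, samples) random audio
    with torch.no_grad():
        emotion_logits, stress_value = model(waveform)
    print(emotion_logits.shape)  # torch.Size([2, 7])
    print(stress_value.shape)    # torch.Size([2, 1])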