File size: 916 Bytes
0f74a15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
from transformers import Wav2Vec2Model

class Wav2Vec2_LSTM_MultiTask(nn.Module):
    """Multi-task speech model: wav2vec 2.0 encoder -> BiLSTM -> shared layer -> two heads.

    The pretrained ``Wav2Vec2Model`` turns raw audio into frame features. A
    bidirectional LSTM runs over those frames, and its outputs are averaged over
    time. A shared ReLU layer then feeds an emotion head and a one-unit stress head.

    Args:
        num_emotions: Number of outputs of the emotion head.
        pretrained_name: Hugging Face checkpoint loaded via
            ``Wav2Vec2Model.from_pretrained``. Defaults to the original
            ``"facebook/wav2vec2-base"``.
        lstm_hidden_size: Hidden size per LSTM direction.
        lstm_layers: Number of stacked LSTM layers.
        shared_dim: Width of the shared fully connected layer.
    """

    def __init__(
        self,
        num_emotions,
        pretrained_name="facebook/wav2vec2-base",
        lstm_hidden_size=256,
        lstm_layers=2,
        shared_dim=256,
    ):
        super().__init__()

        self.wav2vec = Wav2Vec2Model.from_pretrained(pretrained_name)

        # Take the encoder width from the checkpoint's config rather than
        # hard-coding 768 (wav2vec2-base), so checkpoints with a different
        # hidden size also build a consistent LSTM.
        encoder_dim = self.wav2vec.config.hidden_size

        self.lstm = nn.LSTM(
            input_size=encoder_dim,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
        )

        # A bidirectional LSTM concatenates both directions: 2 * hidden features.
        self.shared_fc = nn.Linear(2 * lstm_hidden_size, shared_dim)
        self.emotion_head = nn.Linear(shared_dim, num_emotions)
        self.stress_head = nn.Linear(shared_dim, 1)

    @staticmethod
    def _masked_mean(x, mask):
        """Average ``x`` (batch, time, features) over time, counting only frames where ``mask`` is true."""
        weights = mask.unsqueeze(-1).to(x.dtype)
        total = (x * weights).sum(dim=1)
        # Clamp so a fully masked row yields zeros instead of NaN (0 / 0).
        count = weights.sum(dim=1).clamp(min=1.0)
        return total / count

    def _pool_sequence(self, hidden, frame_mask):
        """Run the BiLSTM over ``hidden`` and mean-pool over time, skipping padded frames if ``frame_mask`` is given."""
        if frame_mask is None:
            # No padding information: plain mean over every frame.
            lstm_out, _ = self.lstm(hidden)
            return lstm_out.mean(dim=1)

        # Assumes right padding (valid frames form a prefix of each row), which
        # is what a length-derived mask looks like. Packing keeps the backward
        # LSTM direction from starting on padded frames.
        lengths = frame_mask.sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            hidden, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=hidden.shape[1]
        )
        return self._masked_mean(lstm_out, frame_mask)

    def forward(self, input_values, attention_mask=None):
        """Compute emotion and stress outputs for a batch of raw audio.

        Args:
            input_values: Waveform batch passed straight to ``Wav2Vec2Model``.
                Its documented input layout is ``(batch, num_samples)``.
            attention_mask: Optional sample-level padding mask, same layout as
                ``input_values``. When ``None`` (the default), every frame is
                pooled, as before. When given, it is passed to the backbone and
                padded frames are excluded from the LSTM and the pooling.
                NOTE: the Hugging Face docs advise against passing a mask to
                checkpoints whose feature extractor has
                ``return_attention_mask=False`` (e.g. wav2vec2-base). For those,
                zero-pad the audio and omit the mask.

        Returns:
            Tuple ``(emotion, stress)``:
            - ``emotion`` has shape ``(batch, num_emotions)`` and is the raw
              output of ``emotion_head``; no softmax is applied here.
            - ``stress`` has shape ``(batch, 1)``. Whether it is a regression
              value or a binary logit depends on the training loss, which is
              not visible here.
        """
        outputs = self.wav2vec(input_values, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        frame_mask = None
        if attention_mask is not None:
            # The CNN feature extractor downsamples audio samples to frames,
            # so the sample-level mask must be mapped to the frame axis.
            frame_mask = self.wav2vec._get_feature_vector_attention_mask(
                hidden.shape[1], attention_mask
            )

        pooled = self._pool_sequence(hidden, frame_mask)

        shared = torch.relu(self.shared_fc(pooled))
        return self.emotion_head(shared), self.stress_head(shared)