# File size: 1,938 Bytes
# 31d7b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

import torch
import torch.nn as nn
from transformers import AutoModel

class TaskHead(nn.Module):
    """Small MLP classification head mapping a feature vector to class logits.

    Layout: Linear(hidden_size -> 256) -> LayerNorm -> GELU -> Dropout ->
    Linear(256 -> n_classes). The layers live in ``self.net`` in that order,
    so parameter names are ``net.0.*``, ``net.1.*`` and ``net.4.*``.

    Args:
        hidden_size: Size of the last dimension of the input features.
        n_classes: Number of output logits.
        dropout: Drop probability applied after the GELU activation.
    """

    def __init__(self, hidden_size, n_classes, dropout=0.3):
        super().__init__()
        bottleneck = 256
        layers = [
            nn.Linear(hidden_size, bottleneck),
            nn.LayerNorm(bottleneck),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck, n_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalized logits for ``x`` (last dim must equal ``hidden_size``)."""
        logits = self.net(x)
        return logits

class UnifiedMASRIHead(nn.Module):
    """Multi-task classifier over a transformer [CLS] vector plus BiLSTM/BiGRU features.

    The transformer loaded from ``bert_model_name`` encodes ``input_ids``.
    A bidirectional LSTM and a bidirectional GRU both read ``ft_embeds``.
    The [CLS] vector and the recurrent features are concatenated and routed
    to one of three task heads: sarcasm (2 logits), sentiment (3 logits) or
    emotion (8 logits).

    Args:
        bert_model_name: Name or path passed to ``AutoModel.from_pretrained``.
        ft_dim: Feature size of each time step of ``ft_embeds``.
        rnn_hidden: Hidden size of each direction of the LSTM and the GRU.
        num_layers: Number of stacked recurrent layers in each RNN.
        dropout: Inter-layer dropout for the RNNs and dropout applied to the
            concatenated RNN features.
    """

    # Supported ``task_name`` values -> attribute holding that task's head.
    # The attribute names are kept unchanged so existing state_dict keys load.
    _TASK_HEADS = {
        "sarcasm": "sarcasm_head",
        "sentiment": "sentiment_head",
        "emotion": "emotion_head",
    }

    def __init__(self, bert_model_name="T0KII/MASRIBERTv3", ft_dim=300, rnn_hidden=256, num_layers=2, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.bert_hidden = self.bert.config.hidden_size
        # PyTorch applies RNN ``dropout`` only between stacked layers; with a
        # single layer it has no effect and just emits a UserWarning.
        recurrent_dropout = dropout if num_layers > 1 else 0.0
        self.bilstm = nn.LSTM(ft_dim, rnn_hidden, num_layers, batch_first=True, bidirectional=True, dropout=recurrent_dropout)
        self.bigru = nn.GRU(ft_dim, rnn_hidden, num_layers, batch_first=True, bidirectional=True, dropout=recurrent_dropout)
        self.rnn_dropout = nn.Dropout(dropout)
        # [CLS] vector + BiLSTM output (2 * rnn_hidden) + BiGRU output (2 * rnn_hidden).
        combined_dim = self.bert_hidden + (rnn_hidden * 4)
        self.sarcasm_head = TaskHead(combined_dim, 2, dropout=0.5)
        self.sentiment_head = TaskHead(combined_dim, 3, dropout=0.3)
        self.emotion_head = TaskHead(combined_dim, 8, dropout=0.3)

    def forward(self, input_ids, attention_mask, ft_embeds, task_name):
        """Return logits from the head selected by ``task_name``.

        Args:
            input_ids: Token ids passed to the transformer.
            attention_mask: Attention mask passed to the transformer.
            ft_embeds: Batch-first input for the LSTM/GRU; presumably
                (batch, seq, ft_dim) -- TODO confirm against the caller.
            task_name: One of ``"sarcasm"``, ``"sentiment"``, ``"emotion"``.

        Returns:
            Logits from the chosen head (2, 3 or 8 per example respectively).

        Raises:
            ValueError: If ``task_name`` is not a supported task.
        """
        # Validate up front: an unknown task used to fall through and return
        # None, and only after running the full transformer + RNN forward pass.
        head_attr = self._TASK_HEADS.get(task_name)
        if head_attr is None:
            raise ValueError(f"Unknown task_name {task_name!r}; expected one of {sorted(self._TASK_HEADS)}")

        bert_out = self.bert(input_ids, attention_mask=attention_mask)
        cls_vec = bert_out.last_hidden_state[:, 0, :]
        lstm_out, _ = self.bilstm(ft_embeds)
        gru_out, _ = self.bigru(ft_embeds)
        # NOTE(review): ``[:, -1, :]`` takes the final time step of a
        # bidirectional output. Its reverse-direction half has only seen the
        # last token, and if ft_embeds is right-padded that step is padding.
        # Using the final hidden states (h_n) would summarize the whole
        # sequence, but it would change outputs of already-trained weights,
        # so it is left unchanged here.
        rnn_feat = torch.cat([lstm_out[:, -1, :], gru_out[:, -1, :]], dim=1)
        combined = torch.cat([cls_vec, self.rnn_dropout(rnn_feat)], dim=1)
        return getattr(self, head_attr)(combined)