File size: 1,938 Bytes
31d7b01 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import torch
import torch.nn as nn
from transformers import AutoModel
class TaskHead(nn.Module):
    """Two-layer MLP classification head.

    Projects the input features to ``inner_dim`` units, applies LayerNorm,
    GELU and dropout, then maps to ``n_classes`` outputs. No softmax is
    applied, so the outputs are unnormalised scores (logits).

    Args:
        hidden_size: Size of the last dimension of the input features.
        n_classes: Number of output scores.
        dropout: Dropout probability applied after the activation.
        inner_dim: Width of the intermediate layer. Defaults to 256, the
            value that was previously hard-coded.
    """

    def __init__(self, hidden_size: int, n_classes: int, dropout: float = 0.3, inner_dim: int = 256):
        super().__init__()
        # Layer order is fixed so the state_dict keys (net.0, net.1, net.4)
        # stay compatible with checkpoints saved from earlier versions.
        self.net = nn.Sequential(
            nn.Linear(hidden_size, inner_dim),
            nn.LayerNorm(inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, n_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised scores of shape ``(..., n_classes)`` for *x*."""
        return self.net(x)
class UnifiedMASRIHead(nn.Module):
    """Multi-task classifier combining a transformer encoder with BiLSTM/BiGRU features.

    The encoder's first-token vector is concatenated with the last-time-step
    outputs of a bidirectional LSTM and a bidirectional GRU, which both read
    ``ft_embeds``. The combined vector is routed to one of three task heads.

    Args:
        bert_model_name: Model name or path passed to ``AutoModel.from_pretrained``.
        ft_dim: Feature size of each step of ``ft_embeds`` (the RNN input size).
        rnn_hidden: Hidden size per direction for both RNNs.
        num_layers: Number of stacked layers in each RNN.
        dropout: Dropout between stacked RNN layers and on the pooled RNN
            features.
    """

    # Public task name -> attribute holding its head. The attribute names are
    # unchanged so saved state_dict keys (sarcasm_head.*, ...) still match.
    _TASK_HEADS = {
        "sarcasm": "sarcasm_head",
        "sentiment": "sentiment_head",
        "emotion": "emotion_head",
    }

    def __init__(self, bert_model_name="T0KII/MASRIBERTv3", ft_dim=300, rnn_hidden=256, num_layers=2, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.bert_hidden = self.bert.config.hidden_size
        # PyTorch applies RNN dropout only between stacked layers and warns
        # when it is non-zero with a single layer; passing 0.0 in that case
        # silences the warning without changing the computation.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.bilstm = nn.LSTM(ft_dim, rnn_hidden, num_layers, batch_first=True, bidirectional=True, dropout=inter_layer_dropout)
        self.bigru = nn.GRU(ft_dim, rnn_hidden, num_layers, batch_first=True, bidirectional=True, dropout=inter_layer_dropout)
        self.rnn_dropout = nn.Dropout(dropout)
        # BiLSTM contributes 2 * rnn_hidden and BiGRU another 2 * rnn_hidden.
        combined_dim = self.bert_hidden + (rnn_hidden * 4)
        self.sarcasm_head = TaskHead(combined_dim, 2, dropout=0.5)
        self.sentiment_head = TaskHead(combined_dim, 3, dropout=0.3)
        self.emotion_head = TaskHead(combined_dim, 8, dropout=0.3)

    def forward(self, input_ids, attention_mask, ft_embeds, task_name):
        """Compute head outputs for a single task.

        Args:
            input_ids: Token ids forwarded to the transformer encoder.
            attention_mask: Attention mask forwarded to the transformer encoder.
            ft_embeds: Input sequence for the RNN branch. The RNNs use
                ``batch_first=True``, so this is presumably
                ``(batch, seq, ft_dim)``; TODO confirm against callers.
            task_name: One of ``"sarcasm"``, ``"sentiment"`` or ``"emotion"``.

        Returns:
            The selected ``TaskHead`` output: 2, 3 or 8 scores per example.

        Raises:
            ValueError: If ``task_name`` is not one of the known tasks.
        """
        # Validate first so an unknown task fails fast, instead of running the
        # whole encoder and then silently returning None.
        head_attr = self._TASK_HEADS.get(task_name)
        if head_attr is None:
            raise ValueError(
                f"Unknown task_name {task_name!r}; expected one of {sorted(self._TASK_HEADS)}"
            )

        bert_out = self.bert(input_ids, attention_mask=attention_mask)
        # First-token vector; presumably the [CLS] position for this tokenizer.
        cls_vec = bert_out.last_hidden_state[:, 0, :]

        lstm_out, _ = self.bilstm(ft_embeds)
        gru_out, _ = self.bigru(ft_embeds)
        # NOTE(review): for a bidirectional RNN, out[:, -1, :] holds the forward
        # direction's final state but the backward direction's state after
        # reading only the last step. Padded positions in ft_embeds are also not
        # masked here. Kept unchanged so existing checkpoints behave the same;
        # consider pooling from h_n (or packed sequences) when retraining.
        rnn_feat = torch.cat([lstm_out[:, -1, :], gru_out[:, -1, :]], dim=1)
        combined = torch.cat([cls_vec, self.rnn_dropout(rnn_feat)], dim=1)
        return getattr(self, head_attr)(combined)
|