import torch
import torch.nn as nn
from transformers import AutoModel


class TaskHead(nn.Module):
    """Two-layer classification head: Linear -> LayerNorm -> GELU -> Dropout -> Linear.

    Args:
        hidden_size: Size of the last dimension of the input features.
        n_classes: Number of output classes (size of the final linear layer).
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, hidden_size: int, n_classes: int, dropout: float = 0.3) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, n_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores (no softmax is applied) for ``x``."""
        return self.net(x)


class UnifiedMASRIHead(nn.Module):
    """Multi-task classifier combining a pretrained encoder with BiLSTM/BiGRU features.

    The first-position vector of the encoder's last hidden state is
    concatenated with the last-time-step outputs of a bidirectional LSTM and a
    bidirectional GRU (both run over ``ft_embeds``), and the result is fed to
    one of three task-specific heads.

    Args:
        bert_model_name: Model identifier passed to ``AutoModel.from_pretrained``.
        ft_dim: Feature size of ``ft_embeds`` (input size of both RNNs).
        rnn_hidden: Hidden size per direction of each RNN.
        num_layers: Number of stacked layers in each RNN.
        dropout: Dropout probability used between RNN layers (only when
            ``num_layers > 1``) and on the concatenated RNN features.
    """

    # Tasks accepted by ``forward``; each maps to the attribute ``<task>_head``.
    TASK_NAMES = ("sarcasm", "sentiment", "emotion")

    def __init__(
        self,
        bert_model_name: str = "T0KII/MASRIBERTv3",
        ft_dim: int = 300,
        rnn_hidden: int = 256,
        num_layers: int = 2,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.bert_hidden = self.bert.config.hidden_size

        # Inter-layer dropout is a no-op for a single-layer RNN, and PyTorch
        # warns when it is non-zero in that case, so only pass it when it applies.
        rnn_dropout = dropout if num_layers > 1 else 0.0
        self.bilstm = nn.LSTM(
            ft_dim,
            rnn_hidden,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=rnn_dropout,
        )
        self.bigru = nn.GRU(
            ft_dim,
            rnn_hidden,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=rnn_dropout,
        )
        self.rnn_dropout = nn.Dropout(dropout)

        # Two bidirectional RNNs, each contributing 2 * rnn_hidden features.
        combined_dim = self.bert_hidden + (rnn_hidden * 4)

        self.sarcasm_head = TaskHead(combined_dim, 2, dropout=0.5)
        self.sentiment_head = TaskHead(combined_dim, 3, dropout=0.3)
        self.emotion_head = TaskHead(combined_dim, 8, dropout=0.3)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        ft_embeds: torch.Tensor,
        task_name: str,
    ) -> torch.Tensor:
        """Compute class scores for one task.

        Args:
            input_ids: Token ids forwarded to the encoder
                (presumably ``(batch, seq)`` -- confirm against the tokenizer).
            attention_mask: Attention mask forwarded to the encoder.
            ft_embeds: Batch-first sequence features of shape
                ``(batch, seq, ft_dim)`` fed to both RNNs.
            task_name: One of ``'sarcasm'``, ``'sentiment'`` or ``'emotion'``.

        Returns:
            Scores of shape ``(batch, n_classes)`` from the selected head
            (2, 3 or 8 classes respectively).

        Raises:
            ValueError: If ``task_name`` is not a supported task.
        """
        # Validate before the expensive encoder/RNN pass; previously an unknown
        # task fell through the elif chain and silently returned None.
        if task_name not in self.TASK_NAMES:
            raise ValueError(
                f"Unknown task_name {task_name!r}; expected one of {self.TASK_NAMES}"
            )

        bert_out = self.bert(input_ids, attention_mask=attention_mask)
        # Vector at sequence position 0 (presumably the [CLS] token).
        cls_vec = bert_out.last_hidden_state[:, 0, :]

        lstm_out, _ = self.bilstm(ft_embeds)
        gru_out, _ = self.bigru(ft_embeds)

        # NOTE(review): index -1 is the last time step of the output. For the
        # backward direction that is the state after a single step, and if
        # ft_embeds is right-padded it may be a padding position. Kept as-is
        # because changing it would alter the features existing checkpoints
        # were trained on -- consider using the final hidden states instead.
        rnn_feat = torch.cat([lstm_out[:, -1, :], gru_out[:, -1, :]], dim=1)

        combined = torch.cat([cls_vec, self.rnn_dropout(rnn_feat)], dim=1)

        head = getattr(self, f"{task_name}_head")
        return head(combined)