# taMASRIBERT / modeling.py
# Committed by T0KII in 31d7b01 (verified): "Initial Deployment Package with Inference Script"
# File size: 1.94 kB
import torch
import torch.nn as nn
from transformers import AutoModel
class TaskHead(nn.Module):
    """Two-layer MLP classification head.

    Architecture: Linear(hidden_size -> bottleneck_dim) -> LayerNorm -> GELU
    -> Dropout -> Linear(bottleneck_dim -> n_classes). No activation follows
    the final Linear, so the output is unnormalized scores.

    Args:
        hidden_size: Size of the last dimension of the input tensor.
        n_classes: Number of output scores.
        dropout: Dropout probability applied after the GELU.
        bottleneck_dim: Width of the intermediate layer. Defaults to 256, the
            value that was previously hard-coded.
    """

    def __init__(self, hidden_size, n_classes, dropout=0.3, *, bottleneck_dim=256):
        super().__init__()
        # Keep the layer order fixed: state_dict keys (net.0, net.1, net.4)
        # are derived from it, and saved weights are matched by those keys.
        self.net = nn.Sequential(
            nn.Linear(hidden_size, bottleneck_dim),
            nn.LayerNorm(bottleneck_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck_dim, n_classes),
        )

    def forward(self, x):
        """Map ``x`` of shape (..., hidden_size) to scores of shape (..., n_classes)."""
        return self.net(x)
class UnifiedMASRIHead(nn.Module):
    """Multi-task classifier combining a BERT encoder with BiLSTM/BiGRU features.

    The position-0 hidden state of the BERT encoder is concatenated with the
    final-timestep outputs of a bidirectional LSTM and a bidirectional GRU.
    Both RNNs read ``ft_embeds``. The combined vector goes to one of three
    task heads, selected by ``task_name``:

    * ``'sarcasm'``   -> 2 outputs
    * ``'sentiment'`` -> 3 outputs
    * ``'emotion'``   -> 8 outputs

    Args:
        bert_model_name: Model id or path passed to ``AutoModel.from_pretrained``.
        ft_dim: Feature size of each step of ``ft_embeds`` (the RNN
            ``input_size``). Presumably pre-computed word vectors; TODO confirm.
        rnn_hidden: Hidden size per direction of both the LSTM and the GRU.
        num_layers: Number of stacked layers in each RNN.
        dropout: Inter-layer RNN dropout, also applied to the RNN features.
    """

    # Task names accepted by ``forward`` -> attribute holding that task's head.
    # The attribute names are unchanged, so state_dict keys stay the same.
    _TASK_HEADS = {
        "sarcasm": "sarcasm_head",
        "sentiment": "sentiment_head",
        "emotion": "emotion_head",
    }

    def __init__(self, bert_model_name="T0KII/MASRIBERTv3", ft_dim=300, rnn_hidden=256, num_layers=2, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.bert_hidden = self.bert.config.hidden_size
        # PyTorch applies RNN dropout only between stacked layers and warns when
        # num_layers == 1. Passing 0.0 then gives the same computation without the warning.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.bilstm = nn.LSTM(
            ft_dim, rnn_hidden, num_layers,
            batch_first=True, bidirectional=True, dropout=inter_layer_dropout,
        )
        self.bigru = nn.GRU(
            ft_dim, rnn_hidden, num_layers,
            batch_first=True, bidirectional=True, dropout=inter_layer_dropout,
        )
        self.rnn_dropout = nn.Dropout(dropout)
        # BERT vector + LSTM (2 directions * rnn_hidden) + GRU (2 directions * rnn_hidden).
        combined_dim = self.bert_hidden + (rnn_hidden * 4)
        self.sarcasm_head = TaskHead(combined_dim, 2, dropout=0.5)
        self.sentiment_head = TaskHead(combined_dim, 3, dropout=0.3)
        self.emotion_head = TaskHead(combined_dim, 8, dropout=0.3)

    def forward(self, input_ids, attention_mask, ft_embeds, task_name):
        """Return the selected task head's scores for one batch.

        Args:
            input_ids: Token ids for the BERT encoder. Assumed to be
                (batch, seq_len); TODO confirm.
            attention_mask: Passed through unchanged to the BERT encoder.
            ft_embeds: Tensor of shape (batch, seq, ft_dim). Both RNNs are built
                with ``batch_first=True`` and ``input_size=ft_dim``.
            task_name: One of ``'sarcasm'``, ``'sentiment'``, ``'emotion'``.

        Returns:
            Output of the chosen task head on the combined features. No
            softmax is applied here.

        Raises:
            ValueError: If ``task_name`` is not one of the known tasks.
        """
        # Resolve the head first, so an unknown task fails before the expensive
        # encoder pass. Previously this fell through and returned None.
        head_attr = self._TASK_HEADS.get(task_name)
        if head_attr is None:
            raise ValueError(
                f"Unknown task_name {task_name!r}; expected one of {sorted(self._TASK_HEADS)}"
            )
        head = getattr(self, head_attr)

        bert_out = self.bert(input_ids, attention_mask=attention_mask)
        # Hidden state at sequence position 0 (conventionally the [CLS] token).
        cls_vec = bert_out.last_hidden_state[:, 0, :]

        lstm_out, _ = self.bilstm(ft_embeds)
        gru_out, _ = self.bigru(ft_embeds)
        # NOTE(review): for a bidirectional RNN, output[:, -1, :] holds the
        # forward direction's final state but the backward direction's state
        # after only one step. h_n would summarize the whole sequence in both
        # directions. If ft_embeds is right-padded, position -1 may also be
        # padding. Left unchanged because the released weights were presumably
        # trained on these exact features; revisit when retraining.
        rnn_feat = torch.cat([lstm_out[:, -1, :], gru_out[:, -1, :]], dim=1)
        combined = torch.cat([cls_vec, self.rnn_dropout(rnn_feat)], dim=1)
        return head(combined)