# model.py — CommentResponse multi-task model (author: Jet-12138, rev 89386d1).
import torch
import torch.nn as nn
from transformers import BertModel, AutoModel
class CommentMTLModel(nn.Module):
    """
    Multi-Task Learning model using a BERT-style encoder base and separate
    heads for sentiment classification and toxicity multi-label classification.
    """

    def __init__(self, model_name, num_sentiment_labels, num_toxicity_labels, dropout_prob=0.1):
        """
        Args:
            model_name (str): Name of the pre-trained encoder from Hugging Face.
            num_sentiment_labels (int): Number of classes for sentiment analysis.
            num_toxicity_labels (int): Number of labels for toxicity detection.
            dropout_prob (float): Dropout probability for the classification heads.
        """
        super().__init__()
        # Load the pre-trained encoder (BERT or a compatible architecture).
        self.bert = AutoModel.from_pretrained(model_name)
        # Shared dropout applied to the pooled representation before both heads.
        self.dropout = nn.Dropout(dropout_prob)
        hidden_size = self.bert.config.hidden_size
        # --- Sentiment Head ---
        # Maps the pooled [CLS] representation to sentiment class logits.
        self.sentiment_classifier = nn.Linear(hidden_size, num_sentiment_labels)
        # --- Toxicity Head ---
        # Maps the pooled [CLS] representation to toxicity logits (multi-label).
        self.toxicity_classifier = nn.Linear(hidden_size, num_toxicity_labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Forward pass of the model.

        Args:
            input_ids (torch.Tensor): Input token IDs (batch_size, seq_length).
            attention_mask (torch.Tensor): Attention mask (batch_size, seq_length).
            token_type_ids (torch.Tensor, optional): Segment IDs. Only forwarded
                to the encoder when provided, since some backbones
                (e.g. DistilBERT) do not accept this argument.

        Returns:
            dict: Raw output logits for each task:
                'sentiment_logits': (batch_size, num_sentiment_labels)
                'toxicity_logits': (batch_size, num_toxicity_labels)
        """
        # Only pass token_type_ids when given: AutoModel may load a backbone
        # whose forward() signature has no such parameter, and passing it
        # unconditionally would raise a TypeError.
        encoder_kwargs = {'input_ids': input_ids, 'attention_mask': attention_mask}
        if token_type_ids is not None:
            encoder_kwargs['token_type_ids'] = token_type_ids
        outputs = self.bert(**encoder_kwargs)

        # Prefer the pooler output; fall back to the [CLS] token's last hidden
        # state for encoders that have no pooler head (pooler_output is then
        # None or absent, e.g. DistilBERT or add_pooling_layer=False).
        pooled_output = getattr(outputs, 'pooler_output', None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]

        # Regularize the shared representation, then branch into the two heads.
        pooled_output = self.dropout(pooled_output)
        sentiment_logits = self.sentiment_classifier(pooled_output)
        toxicity_logits = self.toxicity_classifier(pooled_output)

        return {
            'sentiment_logits': sentiment_logits,
            'toxicity_logits': toxicity_logits
        }