import torch
import torch.nn as nn
from transformers import BertModel, AutoModel


class CommentMTLModel(nn.Module):
    """Multi-task model: a shared BERT-style encoder with two task heads.

    One linear head produces sentiment logits (single-label classification)
    and a second linear head produces toxicity logits (multi-label
    classification). Both heads consume the encoder's pooled [CLS]
    representation after dropout.
    """

    def __init__(self, model_name, num_sentiment_labels, num_toxicity_labels, dropout_prob=0.1):
        """
        Args:
            model_name (str): Name of the pre-trained model on the Hugging Face hub.
            num_sentiment_labels (int): Number of classes for sentiment analysis.
            num_toxicity_labels (int): Number of labels for toxicity detection
                (multi-label).
            dropout_prob (float): Dropout probability for the classification heads.
        """
        super().__init__()  # zero-arg super: file is Python 3 (uses f-string-era libs)

        # Shared pre-trained encoder. NOTE: AutoModel may load architectures
        # that have no pooling layer (see fallback in forward()).
        self.bert = AutoModel.from_pretrained(model_name)

        # Regularization applied to the pooled output, before both heads.
        self.dropout = nn.Dropout(dropout_prob)

        # Both heads map the encoder's hidden size to their label space.
        hidden_size = self.bert.config.hidden_size

        # --- Sentiment head: pooled [CLS] representation -> sentiment logits ---
        self.sentiment_classifier = nn.Linear(hidden_size, num_sentiment_labels)

        # --- Toxicity head: pooled [CLS] representation -> toxicity logits ---
        self.toxicity_classifier = nn.Linear(hidden_size, num_toxicity_labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Run the shared encoder and both task-specific heads.

        Args:
            input_ids (torch.Tensor): Input token IDs, shape (batch_size, seq_length).
            attention_mask (torch.Tensor): Attention mask, shape (batch_size, seq_length).
            token_type_ids (torch.Tensor, optional): Segment IDs; may be None for
                models that do not use them.

        Returns:
            dict: Raw output logits for each task:
                'sentiment_logits': shape (batch_size, num_sentiment_labels).
                'toxicity_logits': shape (batch_size, num_toxicity_labels).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Bug fix: not every architecture loaded via AutoModel has a pooling
        # layer; in that case `pooler_output` is absent or None and the original
        # code would crash. Fall back to the first ([CLS]) token of the last
        # hidden state so the heads always receive a (batch, hidden) tensor.
        pooled_output = getattr(outputs, "pooler_output", None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]

        # Apply dropout for regularization, then the task-specific heads.
        pooled_output = self.dropout(pooled_output)

        return {
            'sentiment_logits': self.sentiment_classifier(pooled_output),
            'toxicity_logits': self.toxicity_classifier(pooled_output),
        }