Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from transformers import BertModel, AutoModel | |
class CommentMTLModel(nn.Module):
    """
    Multi-Task Learning model: a shared pre-trained transformer encoder with
    two task-specific linear heads — one for sentiment classification and one
    for toxicity multi-label classification.
    """

    def __init__(self, model_name, num_sentiment_labels, num_toxicity_labels, dropout_prob=0.1):
        """
        Args:
            model_name (str): Name of the pre-trained model on Hugging Face.
            num_sentiment_labels (int): Number of classes for sentiment analysis.
            num_toxicity_labels (int): Number of classes for toxicity detection.
            dropout_prob (float): Dropout probability for the classification heads.
        """
        super().__init__()
        # Shared encoder. NOTE: AutoModel may resolve to architectures that do
        # not ship a pooling layer (e.g. DistilBERT); forward() handles both.
        self.bert = AutoModel.from_pretrained(model_name)
        # Regularization between the encoder output and the task heads.
        self.dropout = nn.Dropout(dropout_prob)
        hidden_size = self.bert.config.hidden_size
        # --- Sentiment head: pooled representation -> sentiment logits ---
        self.sentiment_classifier = nn.Linear(hidden_size, num_sentiment_labels)
        # --- Toxicity head: pooled representation -> toxicity logits (multi-label) ---
        self.toxicity_classifier = nn.Linear(hidden_size, num_toxicity_labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Forward pass of the model.

        Args:
            input_ids (torch.Tensor): Input token IDs (batch_size, seq_length).
            attention_mask (torch.Tensor): Attention masks (batch_size, seq_length).
            token_type_ids (torch.Tensor, optional): Segment IDs, forwarded to
                the encoder when provided.

        Returns:
            dict: Raw output logits for each task:
                'sentiment_logits': (batch_size, num_sentiment_labels)
                'toxicity_logits': (batch_size, num_toxicity_labels)
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Bug fix: not every AutoModel architecture provides a pooler, and on
        # those the original `outputs.pooler_output` is None/absent and the
        # heads crash. Fall back to the first-token ([CLS]) hidden state.
        pooled_output = getattr(outputs, "pooler_output", None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]
        # Apply dropout for regularization before the task-specific heads.
        pooled_output = self.dropout(pooled_output)
        sentiment_logits = self.sentiment_classifier(pooled_output)
        toxicity_logits = self.toxicity_classifier(pooled_output)
        return {
            'sentiment_logits': sentiment_logits,
            'toxicity_logits': toxicity_logits
        }