File size: 2,802 Bytes
018d244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff59e33
018d244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89386d1
ff59e33
018d244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
from transformers import BertModel, AutoModel

class CommentMTLModel(nn.Module):
    """
    Multi-Task Learning model using a shared transformer encoder and separate
    heads for sentiment classification and toxicity multi-label classification.
    """
    def __init__(self, model_name, num_sentiment_labels, num_toxicity_labels, dropout_prob=0.1):
        """
        Args:
            model_name (str): Name of the pre-trained BERT model from Hugging Face.
            num_sentiment_labels (int): Number of classes for sentiment analysis.
            num_toxicity_labels (int): Number of classes for toxicity detection.
            dropout_prob (float): Dropout probability for the classification heads.
        """
        super().__init__()

        # Load the pre-trained encoder (any BERT-like model resolvable by AutoModel)
        self.bert = AutoModel.from_pretrained(model_name)

        # Dropout layer for regularization - applied after BERT output, before heads
        self.dropout = nn.Dropout(dropout_prob)

        # --- Sentiment Head ---
        # Maps the pooled [CLS] representation to sentiment logits
        self.sentiment_classifier = nn.Linear(self.bert.config.hidden_size, num_sentiment_labels)

        # --- Toxicity Head ---
        # Maps the pooled [CLS] representation to toxicity logits (multi-label)
        self.toxicity_classifier = nn.Linear(self.bert.config.hidden_size, num_toxicity_labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Forward pass of the model.

        Args:
            input_ids (torch.Tensor): Tensor of input token IDs (batch_size, seq_length).
            attention_mask (torch.Tensor): Tensor of attention masks (batch_size, seq_length).
            token_type_ids (torch.Tensor, optional): Segment IDs (batch_size, seq_length).
                Only forwarded to the encoder when provided, so encoders whose
                ``forward`` has no ``token_type_ids`` parameter (e.g. DistilBERT)
                still work.

        Returns:
            dict: A dictionary containing the raw output logits for each task:
                'sentiment_logits': Logits for sentiment classification (batch_size, num_sentiment_labels).
                'toxicity_logits': Logits for toxicity multi-label classification (batch_size, num_toxicity_labels).
        """
        # Build encoder kwargs; omit token_type_ids when not supplied so that
        # encoders without that parameter do not raise a TypeError.
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids

        # Pass input through the encoder
        outputs = self.bert(**encoder_kwargs)

        # Get the pooled output. Some encoders (e.g. DistilBERT, or models
        # loaded with add_pooling_layer=False) expose no `pooler_output`;
        # fall back to the [CLS] token's final hidden state in that case.
        pooled_output = getattr(outputs, "pooler_output", None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]

        # Apply dropout for regularization
        pooled_output = self.dropout(pooled_output)

        # Pass the pooled output through the task-specific heads
        sentiment_logits = self.sentiment_classifier(pooled_output)
        toxicity_logits = self.toxicity_classifier(pooled_output)

        return {
            'sentiment_logits': sentiment_logits,
            'toxicity_logits': toxicity_logits
        }