# models/parallel_bert_deberta.py

import torch
import torch.nn as nn
from transformers import BertModel, DebertaModel
from config import DROPOUT_RATE, BERT_MODEL_NAME, DEBERTA_MODEL_NAME  # Project hyperparameters and checkpoint names

class Attention(nn.Module):
    """
    Simple Attention layer to compute a context vector from a sequence of hidden states.
    It learns a single weight for each hidden state in the sequence, then uses softmax
    to normalize these weights and compute a weighted sum of the hidden states.
    """
    def __init__(self, hidden_size):
        """
        Initializes the Attention layer.

        Args:
            hidden_size (int): The dimensionality of the input hidden states.
        """
        super().__init__()
        # A linear layer to project the hidden state to a single scalar (attention score)
        self.attn = nn.Linear(hidden_size, 1)

    def forward(self, encoder_output):
        """
        Performs the forward pass of the attention mechanism.

        Args:
            encoder_output (torch.Tensor): Tensor of hidden states from an encoder.
                                           Shape: (batch_size, sequence_length, hidden_size)

        Returns:
            torch.Tensor: The context vector, a weighted sum of the hidden states.
                          Shape: (batch_size, hidden_size)
        """
        # Project each hidden state to a scalar score, then normalize the
        # scores over the sequence with softmax.
        # self.attn(encoder_output) -> (batch_size, sequence_length, 1)
        # .squeeze(-1) -> (batch_size, sequence_length)
        # Note: no attention mask is applied here, so padded positions also
        # receive (typically small) attention weight.
        attn_weights = torch.softmax(self.attn(encoder_output).squeeze(-1), dim=1)

        # Compute the context vector as a weighted sum of encoder_output.
        # attn_weights.unsqueeze(-1) adds a dimension for broadcasting: (batch_size, sequence_length, 1)
        # This allows element-wise multiplication with encoder_output.
        # torch.sum(..., dim=1) sums along the sequence_length dimension.
        context_vector = torch.sum(attn_weights.unsqueeze(-1) * encoder_output, dim=1)
        return context_vector
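
# Minimal usage sketch for Attention (shapes are illustrative assumptions,
# not values taken from this project's config):
#
#   attn = Attention(hidden_size=768)
#   hidden_states = torch.randn(4, 128, 768)  # (batch, seq_len, hidden)
#   context = attn(hidden_states)             # -> (4, 768)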

class ParallelMultiOutputModel(nn.Module):
    """
    Hybrid model that leverages both BERT and DeBERTa in parallel.
    It extracts features from both models, applies an attention mechanism to their outputs,
    projects these attended features to a common dimension, concatenates them, and then
    uses this combined representation for multi-output classification.
    """
    # Statically set the tokenizer name to BERT's for this combined model.
    # Caveat: BERT and DeBERTa checkpoints generally ship different
    # vocabularies, so sharing one tokenizer's input_ids across both backbones
    # is only safe when the two checkpoints use a compatible vocabulary.
    tokenizer_name = BERT_MODEL_NAME
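    # Sketch of how this attribute is presumably consumed elsewhere in the
    # project (an assumption, using the standard transformers API):
    #   from transformers import AutoTokenizer
    #   tokenizer = AutoTokenizer.from_pretrained(ParallelMultiOutputModel.tokenizer_name)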

    def __init__(self, num_labels):
        """
        Initializes the ParallelMultiOutputModel.

        Args:
            num_labels (list[int]): Number of classes for each output; one
                                    classification head is built per entry.
        """
        super().__init__()
        # Load pre-trained BERT and DeBERTa models
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        self.deberta = DebertaModel.from_pretrained(DEBERTA_MODEL_NAME)

        # Initialize attention layers for each backbone model
        self.attn_bert = Attention(self.bert.config.hidden_size)
        self.attn_deberta = Attention(self.deberta.config.hidden_size)

        # Projection layers to reduce dimensionality of the context vectors
        # before concatenation. This helps manage the combined feature size.
        self.proj_bert = nn.Linear(self.bert.config.hidden_size, 256)
        self.proj_deberta = nn.Linear(self.deberta.config.hidden_size, 256)

        self.dropout = nn.Dropout(DROPOUT_RATE) # Dropout layer for regularization

        # Define classification heads. The input feature size is the sum of
        # the projected sizes from BERT and DeBERTa (256 + 256 = 512).
        self.classifiers = nn.ModuleList([
            nn.Linear(512, n_classes) for n_classes in num_labels
        ])

    def forward(self, input_ids, attention_mask):
        """
        Performs the forward pass of the parallel model.

        Args:
            input_ids (torch.Tensor): Token IDs, shape (batch_size, sequence_length).
            attention_mask (torch.Tensor): Mask marking real tokens (1) vs. padding (0).

        Returns:
            list: A list of logit tensors, one for each classification head.
        """
        # Get the full sequence of last hidden states from both BERT and
        # DeBERTa (rather than the pooled output), since the Attention
        # layers pool over all token positions themselves.
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        deberta_output = self.deberta(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # Apply attention to get a single context vector from each model's output
        context_bert = self.attn_bert(bert_output)
        context_deberta = self.attn_deberta(deberta_output)

        # Project the context vectors to their reduced dimensions
        reduced_bert = self.proj_bert(context_bert)
        reduced_deberta = self.proj_deberta(context_deberta)

        # Concatenate the reduced feature vectors from both models
        combined = torch.cat((reduced_bert, reduced_deberta), dim=1)
        combined = self.dropout(combined) # Apply dropout to the combined features

        # Pass the combined features through each classification head
        return [classifier(combined) for classifier in self.classifiers]
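

if __name__ == "__main__":
    # Minimal smoke-test sketch: builds the model and runs one dummy forward
    # pass. Assumes the checkpoints named in config can be downloaded; the
    # num_labels value below is a hypothetical example, not a repo setting.
    model = ParallelMultiOutputModel(num_labels=[3, 5])
    model.eval()
    dummy_ids = torch.randint(0, model.bert.config.vocab_size, (2, 16))
    dummy_mask = torch.ones_like(dummy_ids)
    with torch.no_grad():
        logits = model(dummy_ids, dummy_mask)
    for i, out in enumerate(logits):
        print(f"head {i}: logits shape {tuple(out.shape)}")  # (2, n_classes)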