File size: 3,192 Bytes
5e6ef00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# models/text_and_metadata_model.py

import torch
import torch.nn as nn
from transformers import BertModel # Can be extended to RoBERTa, DeBERTa etc.
from config import DROPOUT_RATE, BERT_MODEL_NAME # Import BERT_MODEL_NAME

class BertWithMetadataModel(nn.Module):
    """
    Hybrid model that combines text features (extracted by BERT) with additional
    numerical metadata features.

    The text is encoded by a pre-trained BERT model and its pooled output is
    passed through dropout. The metadata is encoded by a small two-layer MLP.
    Both representations are concatenated along dimension 1 and fed to one
    linear classification head per label column.
    """
    # Statically set tokenizer name
    tokenizer_name = BERT_MODEL_NAME

    def __init__(
        self,
        num_labels: list,
        metadata_dim: int,
        *,
        metadata_hidden_dim: int = 128,
        metadata_output_dim: int = 64,
    ):
        """
        Initializes the BertWithMetadataModel.

        Args:
            num_labels (list): A list where each element is the number of classes
                                for a corresponding label column. One linear head
                                is created per element, in iteration order.
            metadata_dim (int): The number of features in the numerical metadata.
            metadata_hidden_dim (int): Width of the metadata MLP's hidden layer.
                                Keyword-only; defaults to 128.
            metadata_output_dim (int): Width of the metadata MLP's output, i.e. the
                                number of metadata features concatenated with the
                                BERT pooled output. Keyword-only; defaults to 64.
        """
        super().__init__()
        # Load pre-trained BERT model for text processing
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        self.dropout = nn.Dropout(DROPOUT_RATE)  # Dropout for BERT's output

        # MLP for processing numerical metadata features.
        # Layer order (and therefore the state_dict keys) is kept stable:
        # 0 = Linear, 1 = ReLU, 2 = Dropout, 3 = Linear.
        self.metadata_mlp = nn.Sequential(
            nn.Linear(metadata_dim, metadata_hidden_dim),
            nn.ReLU(),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(metadata_hidden_dim, metadata_output_dim),
        )

        # Input size of every classification head: BERT's hidden size plus the
        # metadata MLP's output size (the two are concatenated in forward()).
        combined_feature_size = self.bert.config.hidden_size + metadata_output_dim

        # Create classification heads, one for each label column
        self.classifiers = nn.ModuleList([
            nn.Linear(combined_feature_size, n_classes) for n_classes in num_labels
        ])

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        metadata: torch.Tensor,
    ) -> list:
        """
        Performs the forward pass of the hybrid model.

        Args:
            input_ids (torch.Tensor): Tensor of token IDs for text
                (presumably shaped (batch, seq_len) -- confirm against the caller).
            attention_mask (torch.Tensor): Tensor indicating attention for text,
                passed to BERT together with ``input_ids``.
            metadata (torch.Tensor): Tensor of numerical metadata features. Its
                last dimension must equal ``metadata_dim``; it is concatenated
                with the BERT pooled output along dimension 1, so it is expected
                to be 2-D. Any numeric dtype is accepted; it is cast to the
                metadata MLP's parameter dtype.

        Returns:
            list: A list of logit tensors, one for each classification head,
            in the same order as ``num_labels`` given at construction.
        """
        # Process text input through BERT and regularize the pooled output
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_features = self.dropout(bert_outputs.pooler_output)

        # nn.Linear raises on a dtype mismatch between input and weights (e.g.
        # float64 or integer metadata coming from NumPy/pandas), so align the
        # metadata with the MLP's parameters. This is a no-op when they match.
        mlp_dtype = self.metadata_mlp[0].weight.dtype
        metadata_features = self.metadata_mlp(metadata.to(dtype=mlp_dtype))

        # Concatenate the processed text features and metadata features
        combined_features = torch.cat((text_features, metadata_features), dim=1)

        # Pass the combined features through each classification head
        return [classifier(combined_features) for classifier in self.classifiers]