# models/text_and_metadata_model.py

import torch
import torch.nn as nn
from transformers import BertModel  # Can be extended to RoBERTa, DeBERTa, etc.

from config import DROPOUT_RATE, BERT_MODEL_NAME  # Import BERT_MODEL_NAME


class BertWithMetadataModel(nn.Module):
    """
    Hybrid model that combines text features (extracted by BERT) with
    additional numerical metadata features.

    The text features are processed by BERT, the metadata features by a simple
    MLP, and their outputs are concatenated before being fed into the final
    classification heads.
    """

    # Statically set tokenizer name
    tokenizer_name = BERT_MODEL_NAME

    def __init__(self, num_labels, metadata_dim):
        """
        Initializes the BertWithMetadataModel.

        Args:
            num_labels (list): A list where each element is the number of classes
                for a corresponding label column.
            metadata_dim (int): The number of features in the numerical metadata.
        """
        super(BertWithMetadataModel, self).__init__()

        # Load pre-trained BERT model for text processing
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        self.dropout = nn.Dropout(DROPOUT_RATE)  # Dropout for BERT's output

        # MLP for processing numerical metadata features
        self.metadata_mlp = nn.Sequential(
            nn.Linear(metadata_dim, 128),  # First linear layer
            nn.ReLU(),                     # Activation function
            nn.Dropout(DROPOUT_RATE),      # Dropout for metadata features
            nn.Linear(128, 64)             # Second linear layer
        )

        # Calculate the total input feature size for the classification heads.
        # This is the sum of BERT's pooled output size and the metadata MLP's
        # output size.
        combined_feature_size = self.bert.config.hidden_size + 64

        # Create classification heads, one for each label column
        self.classifiers = nn.ModuleList([
            nn.Linear(combined_feature_size, n_classes) for n_classes in num_labels
        ])

    def forward(self, input_ids, attention_mask, metadata):
        """
        Performs the forward pass of the hybrid model.

        Args:
            input_ids (torch.Tensor): Tensor of token IDs for the text.
            attention_mask (torch.Tensor): Attention mask for the text.
            metadata (torch.Tensor): Tensor of numerical metadata features.

        Returns:
            list: A list of logit tensors, one for each classification head.
        """
        # Process text input through BERT
        bert_pooled_output = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output
        bert_pooled_output = self.dropout(bert_pooled_output)  # Apply dropout

        # Process metadata through the MLP
        metadata_output = self.metadata_mlp(metadata)

        # Concatenate the processed text features and metadata features
        combined_features = torch.cat((bert_pooled_output, metadata_output), dim=1)

        # Pass the combined features through each classification head
        return [classifier(combined_features) for classifier in self.classifiers]
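

# --- Usage sketch (illustrative only, not part of the original module) ---
# Demonstrates how the model might be instantiated and called with dummy
# inputs. The label layout (two label columns with 3 and 5 classes), the
# metadata width of 10, and the max sequence length of 32 are arbitrary
# assumptions for this example; real values come from the dataset and config.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained(BertWithMetadataModel.tokenizer_name)
    model = BertWithMetadataModel(num_labels=[3, 5], metadata_dim=10)
    model.eval()

    # Tokenize a toy sentence and build a matching metadata tensor.
    encoded = tokenizer(
        "Example product description",
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=32,
    )
    metadata = torch.zeros((1, 10))  # batch of 1, 10 metadata features

    with torch.no_grad():
        logits_per_head = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            metadata=metadata,
        )

    # One logits tensor per label column: shapes [1, 3] and [1, 5] here.
    for i, logits in enumerate(logits_per_head):
        print(f"Head {i}: logits shape {tuple(logits.shape)}")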