# Source: voting-ensemble/models/text_and_metadata_model.py
# (Hugging Face upload by namanpenguin, commit 5e6ef00 — "Upload 10 files")
# models/text_and_metadata_model.py
import torch
import torch.nn as nn
from transformers import BertModel # Can be extended to RoBERTa, DeBERTa etc.
from config import DROPOUT_RATE, BERT_MODEL_NAME # Import BERT_MODEL_NAME
class BertWithMetadataModel(nn.Module):
    """
    Multi-head classifier that fuses BERT text representations with
    numerical metadata features.

    The text branch encodes the input with a pre-trained BERT model and
    takes its pooled output; the metadata branch projects the numeric
    features through a small two-layer MLP. Both representations are
    concatenated and scored by one linear classification head per label
    column.
    """

    # Tokenizer identifier exposed at class level so callers can build a
    # matching tokenizer without instantiating the model.
    tokenizer_name = BERT_MODEL_NAME

    def __init__(self, num_labels, metadata_dim):
        """
        Build the text encoder, the metadata MLP, and the classification heads.

        Args:
            num_labels (list): Number of classes for each label column;
                one classification head is created per entry.
            metadata_dim (int): Width of the numerical metadata vector.
        """
        super().__init__()

        # Text branch: pre-trained BERT encoder, with dropout applied to
        # its pooled representation during the forward pass.
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        self.dropout = nn.Dropout(DROPOUT_RATE)

        # Metadata branch: project raw features to a fixed 64-dim vector.
        self.metadata_mlp = nn.Sequential(
            nn.Linear(metadata_dim, 128),
            nn.ReLU(),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(128, 64),
        )

        # Every head consumes the concatenation of both branches:
        # BERT's hidden size plus the 64-dim metadata projection.
        fused_dim = self.bert.config.hidden_size + 64
        self.classifiers = nn.ModuleList(
            [nn.Linear(fused_dim, n_classes) for n_classes in num_labels]
        )

    def forward(self, input_ids, attention_mask, metadata):
        """
        Run both branches and score every label column.

        Args:
            input_ids (torch.Tensor): Token IDs of the text input.
            attention_mask (torch.Tensor): Attention mask for the text input.
            metadata (torch.Tensor): Numerical metadata features.

        Returns:
            list: One logits tensor per classification head.
        """
        # Text branch: dropout-regularized pooled BERT representation.
        text_repr = self.dropout(
            self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        )

        # Metadata branch.
        meta_repr = self.metadata_mlp(metadata)

        # Fuse along the feature dimension and apply every head.
        fused = torch.cat((text_repr, meta_repr), dim=1)
        return [head(fused) for head in self.classifiers]