Spaces:
Runtime error
Runtime error
| # models/text_and_metadata_model.py | |
| import torch | |
| import torch.nn as nn | |
| from transformers import BertModel # Can be extended to RoBERTa, DeBERTa etc. | |
| from config import DROPOUT_RATE, BERT_MODEL_NAME # Import BERT_MODEL_NAME | |
class BertWithMetadataModel(nn.Module):
    """Multi-head classifier that fuses BERT text features with numeric metadata.

    Text is encoded by a pre-trained BERT backbone, the metadata vector passes
    through a small two-layer MLP, and the two representations are concatenated
    before being routed to one linear classification head per label column.
    """

    # Tokenizer name exposed statically so callers can construct a matching tokenizer.
    tokenizer_name = BERT_MODEL_NAME

    def __init__(self, num_labels, metadata_dim):
        """Build the text backbone, the metadata MLP, and the classifier heads.

        Args:
            num_labels (list): Number of classes for each label column; one
                classification head is created per entry.
            metadata_dim (int): Width of the numerical metadata feature vector.
        """
        super().__init__()

        # Pre-trained text encoder, with dropout applied to its pooled output.
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        self.dropout = nn.Dropout(DROPOUT_RATE)

        # Two-layer MLP that projects the metadata vector down to 64 dims.
        self.metadata_mlp = nn.Sequential(
            nn.Linear(metadata_dim, 128),
            nn.ReLU(),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(128, 64),
        )

        # Each head consumes [BERT pooled output | 64-dim metadata embedding].
        fused_width = self.bert.config.hidden_size + 64
        self.classifiers = nn.ModuleList(
            nn.Linear(fused_width, n_classes) for n_classes in num_labels
        )

    def forward(self, input_ids, attention_mask, metadata):
        """Run one fused forward pass.

        Args:
            input_ids (torch.Tensor): Token-id tensor for the text input.
            attention_mask (torch.Tensor): Attention mask matching ``input_ids``.
            metadata (torch.Tensor): Numerical metadata features, one row per sample.

        Returns:
            list: One logits tensor per classification head, in head order.
        """
        # Encode text and regularize the pooled representation.
        text_repr = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output
        text_repr = self.dropout(text_repr)

        # Encode metadata and fuse both representations along the feature axis.
        meta_repr = self.metadata_mlp(metadata)
        fused = torch.cat((text_repr, meta_repr), dim=1)

        return [head(fused) for head in self.classifiers]