Upload 10 files
- models/__pycache__/bert_model.cpython-311.pyc +0 -0
- models/__pycache__/deberta_model.cpython-311.pyc +0 -0
- models/__pycache__/parallel_bert_deberta.cpython-311.pyc +0 -0
- models/__pycache__/roberta_model.cpython-311.pyc +0 -0
- models/__pycache__/text_and_metadata_model.cpython-311.pyc +0 -0
- models/bert_model.py +59 -0
- models/deberta_model.py +55 -0
- models/parallel_bert_deberta.py +119 -0
- models/roberta_model.py +56 -0
- models/text_and_metadata_model.py +72 -0
models/__pycache__/bert_model.cpython-311.pyc
ADDED
Binary file (3.29 kB).
models/__pycache__/deberta_model.cpython-311.pyc
ADDED
Binary file (3.15 kB).
models/__pycache__/parallel_bert_deberta.cpython-311.pyc
ADDED
Binary file (6.45 kB).
models/__pycache__/roberta_model.cpython-311.pyc
ADDED
Binary file (3.18 kB).
models/__pycache__/text_and_metadata_model.cpython-311.pyc
ADDED
Binary file (4.09 kB).
models/bert_model.py
ADDED
@@ -0,0 +1,59 @@
+# models/bert_model.py
+
+import torch
+import torch.nn as nn
+from transformers import BertModel
+from config import DROPOUT_RATE, BERT_MODEL_NAME  # Import BERT_MODEL_NAME from config
+
+class BertMultiOutputModel(nn.Module):
+    """
+    BERT-based model for multi-output classification.
+    It uses a pre-trained BERT model as its backbone and adds a dropout layer
+    followed by separate linear classification heads for each target label.
+    """
+    # Statically set tokenizer name for easy access in main.py
+    tokenizer_name = BERT_MODEL_NAME
+
+    def __init__(self, num_labels):
+        """
+        Initializes the BertMultiOutputModel.
+
+        Args:
+            num_labels (list): A list where each element is the number of classes
+                               for a corresponding label column.
+        """
+        super(BertMultiOutputModel, self).__init__()
+        # Load the pre-trained BERT model.
+        # BertModel provides contextual embeddings and a pooled output for classification.
+        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
+        self.dropout = nn.Dropout(DROPOUT_RATE)  # Dropout layer for regularization
+
+        # Create a list of classification heads, one for each label column.
+        # Each head is a linear layer mapping BERT's pooled output size to the number of classes for that label.
+        self.classifiers = nn.ModuleList([
+            nn.Linear(self.bert.config.hidden_size, n_classes) for n_classes in num_labels
+        ])
+
+    def forward(self, input_ids, attention_mask):
+        """
+        Performs the forward pass of the model.
+
+        Args:
+            input_ids (torch.Tensor): Tensor of token IDs (from tokenizer).
+            attention_mask (torch.Tensor): Tensor indicating attention (from tokenizer).
+
+        Returns:
+            list: A list of logit tensors, one for each classification head.
+                  Each tensor has shape (batch_size, num_classes_for_that_label).
+        """
+        # Pass input_ids and attention_mask through BERT.
+        # .pooler_output typically represents the hidden state of the [CLS] token,
+        # processed through a linear layer and tanh activation, often used for classification.
+        pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
+
+        # Apply dropout for regularization
+        pooled_output = self.dropout(pooled_output)
+
+        # Pass the pooled output through each classification head.
+        # The result is a list of logits (raw scores before softmax/sigmoid) for each label.
+        return [classifier(pooled_output) for classifier in self.classifiers]
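Note: all five model files import their hyperparameters and checkpoint names from a shared config module that is not part of this commit. A minimal sketch of what config.py is assumed to provide; the constant names come from the imports above, the values are illustrative placeholders only:

# config.py (assumed) -- names taken from the imports in this commit, values are placeholders
DROPOUT_RATE = 0.3                            # dropout probability shared by all models
BERT_MODEL_NAME = "bert-base-uncased"         # any BERT checkpoint on the Hugging Face Hub
DEBERTA_MODEL_NAME = "microsoft/deberta-base"
ROBERTA_MODEL_NAME = "roberta-base"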
models/deberta_model.py
ADDED
@@ -0,0 +1,55 @@
+# models/deberta_model.py
+
+import torch.nn as nn
+from transformers import DebertaModel
+from config import DROPOUT_RATE, DEBERTA_MODEL_NAME  # Import DEBERTA_MODEL_NAME
+
+class DebertaMultiOutputModel(nn.Module):
+    """
+    DeBERTa-based model for multi-output classification.
+    Similar structure to the BERT model, using a pre-trained DeBERTa model
+    as the backbone for text feature extraction.
+    """
+    # Statically set tokenizer name for easy access in main.py
+    tokenizer_name = DEBERTA_MODEL_NAME
+
+    def __init__(self, num_labels):
+        """
+        Initializes the DebertaMultiOutputModel.
+
+        Args:
+            num_labels (list): A list where each element is the number of classes
+                               for a corresponding label column.
+        """
+        super(DebertaMultiOutputModel, self).__init__()
+        # Load the pre-trained DeBERTa model.
+        # Note: unlike BertModel, DebertaModel does not expose a 'pooler_output', so we pool manually.
+        self.deberta = DebertaModel.from_pretrained(DEBERTA_MODEL_NAME)
+        self.dropout = nn.Dropout(DROPOUT_RATE)  # Dropout layer for regularization
+
+        # Create classification heads for each label column.
+        # Each head maps DeBERTa's hidden size to the number of classes for that label.
+        self.classifiers = nn.ModuleList([
+            nn.Linear(self.deberta.config.hidden_size, n_classes) for n_classes in num_labels
+        ])
+
+    def forward(self, input_ids, attention_mask):
+        """
+        Performs the forward pass of the model.
+
+        Args:
+            input_ids (torch.Tensor): Tensor of token IDs.
+            attention_mask (torch.Tensor): Tensor indicating attention.
+
+        Returns:
+            list: A list of logit tensors, one for each classification head.
+        """
+        # Pass input_ids and attention_mask through DeBERTa.
+        # DebertaModel returns only last_hidden_state, so take the first ([CLS]) token as the pooled representation.
+        pooled_output = self.deberta(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
+
+        # Apply dropout
+        pooled_output = self.dropout(pooled_output)
+
+        # Pass the pooled output through each classification head.
+        return [classifier(pooled_output) for classifier in self.classifiers]
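For orientation, a minimal usage sketch of the single-backbone models above (BERT shown; the DeBERTa and RoBERTa variants are called the same way). The label layout and example sentences are invented for illustration and do not come from this repository:

# Hypothetical smoke test for the multi-output classifiers
from transformers import AutoTokenizer
from models.bert_model import BertMultiOutputModel

num_labels = [4, 2, 5]  # invented: three label columns with 4, 2 and 5 classes
model = BertMultiOutputModel(num_labels)
tokenizer = AutoTokenizer.from_pretrained(BertMultiOutputModel.tokenizer_name)

batch = tokenizer(["an example sentence", "another one"],
                  padding=True, truncation=True, return_tensors="pt")
logits = model(batch["input_ids"], batch["attention_mask"])
# logits is a list of three tensors with shapes (2, 4), (2, 2) and (2, 5)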
models/parallel_bert_deberta.py
ADDED
@@ -0,0 +1,119 @@
+# models/parallel_bert_deberta.py
+
+import torch
+import torch.nn as nn
+from transformers import BertModel, DebertaModel
+from config import DROPOUT_RATE, BERT_MODEL_NAME, DEBERTA_MODEL_NAME  # Import model names
+
+class Attention(nn.Module):
+    """
+    Simple Attention layer to compute a context vector from a sequence of hidden states.
+    It learns a single weight for each hidden state in the sequence, then uses softmax
+    to normalize these weights and compute a weighted sum of the hidden states.
+    """
+    def __init__(self, hidden_size):
+        """
+        Initializes the Attention layer.
+
+        Args:
+            hidden_size (int): The dimensionality of the input hidden states.
+        """
+        super(Attention, self).__init__()
+        # A linear layer to project each hidden state to a single scalar (attention score)
+        self.attn = nn.Linear(hidden_size, 1)
+
+    def forward(self, encoder_output):
+        """
+        Performs the forward pass of the attention mechanism.
+
+        Args:
+            encoder_output (torch.Tensor): Tensor of hidden states from an encoder.
+                Shape: (batch_size, sequence_length, hidden_size)
+
+        Returns:
+            torch.Tensor: The context vector, a weighted sum of the hidden states.
+                Shape: (batch_size, hidden_size)
+        """
+        # Calculate normalized attention scores.
+        # self.attn(encoder_output) -> (batch_size, sequence_length, 1)
+        # .squeeze(-1) removes the last dimension, giving (batch_size, sequence_length)
+        attn_weights = torch.softmax(self.attn(encoder_output).squeeze(-1), dim=1)
+
+        # Compute the context vector as a weighted sum of encoder_output.
+        # attn_weights.unsqueeze(-1) adds a dimension for broadcasting: (batch_size, sequence_length, 1),
+        # which allows element-wise multiplication with encoder_output.
+        # torch.sum(..., dim=1) sums along the sequence_length dimension.
+        context_vector = torch.sum(attn_weights.unsqueeze(-1) * encoder_output, dim=1)
+        return context_vector
+
+class ParallelMultiOutputModel(nn.Module):
+    """
+    Hybrid model that leverages both BERT and DeBERTa in parallel.
+    It extracts features from both models, applies an attention mechanism to their outputs,
+    projects these attended features to a common dimension, concatenates them, and then
+    uses this combined representation for multi-output classification.
+    """
+    # Statically set tokenizer name to BERT's for this combined model
+    # (assuming BERT's tokenizer is compatible or primary for the combined input)
+    tokenizer_name = BERT_MODEL_NAME
+
+    def __init__(self, num_labels):
+        """
+        Initializes the ParallelMultiOutputModel.
+
+        Args:
+            num_labels (list): A list where each element is the number of classes
+                               for a corresponding label column.
+        """
+        super(ParallelMultiOutputModel, self).__init__()
+        # Load pre-trained BERT and DeBERTa models
+        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
+        self.deberta = DebertaModel.from_pretrained(DEBERTA_MODEL_NAME)
+
+        # Initialize attention layers for each backbone model
+        self.attn_bert = Attention(self.bert.config.hidden_size)
+        self.attn_deberta = Attention(self.deberta.config.hidden_size)
+
+        # Projection layers to reduce dimensionality of the context vectors
+        # before concatenation. This helps manage the combined feature size.
+        self.proj_bert = nn.Linear(self.bert.config.hidden_size, 256)
+        self.proj_deberta = nn.Linear(self.deberta.config.hidden_size, 256)
+
+        self.dropout = nn.Dropout(DROPOUT_RATE)  # Dropout layer for regularization
+
+        # Define classification heads. The input feature size is the sum of
+        # the projected sizes from BERT and DeBERTa (256 + 256 = 512).
+        self.classifiers = nn.ModuleList([
+            nn.Linear(512, n_classes) for n_classes in num_labels
+        ])
+
+    def forward(self, input_ids, attention_mask):
+        """
+        Performs the forward pass of the parallel model.
+
+        Args:
+            input_ids (torch.Tensor): Tensor of token IDs.
+            attention_mask (torch.Tensor): Tensor indicating attention.
+
+        Returns:
+            list: A list of logit tensors, one for each classification head.
+        """
+        # Get the last hidden states (sequence of hidden states for all tokens)
+        # from both BERT and DeBERTa. These are the inputs to the attention layers.
+        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
+        deberta_output = self.deberta(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
+
+        # Apply attention to get a single context vector from each model's output
+        context_bert = self.attn_bert(bert_output)
+        context_deberta = self.attn_deberta(deberta_output)
+
+        # Project the context vectors to their reduced dimensions
+        reduced_bert = self.proj_bert(context_bert)
+        reduced_deberta = self.proj_deberta(context_deberta)
+
+        # Concatenate the reduced feature vectors from both models
+        combined = torch.cat((reduced_bert, reduced_deberta), dim=1)
+        combined = self.dropout(combined)  # Apply dropout to the combined features
+
+        # Pass the combined features through each classification head
+        return [classifier(combined) for classifier in self.classifiers]
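All of these models return one logit tensor per label column, so training needs one loss term per head. The training loop is not part of this commit; a hedged sketch of how such a multi-output loss is typically combined (function name and summation strategy are assumptions, not taken from this repository):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

def multi_output_loss(outputs, targets):
    # outputs: list of (batch_size, n_classes_i) logit tensors, one per classification head
    # targets: (batch_size, num_label_columns) tensor of integer class indices
    losses = [criterion(logits, targets[:, i]) for i, logits in enumerate(outputs)]
    return torch.stack(losses).sum()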
models/roberta_model.py
ADDED
@@ -0,0 +1,56 @@
+# models/roberta_model.py
+
+import torch
+import torch.nn as nn
+from transformers import RobertaModel
+from config import DROPOUT_RATE, ROBERTA_MODEL_NAME  # Import ROBERTA_MODEL_NAME
+
+class RobertaMultiOutputModel(nn.Module):
+    """
+    RoBERTa-based model for multi-output classification.
+    Uses a pre-trained RoBERTa model as its backbone. RoBERTa is an optimized
+    version of BERT, often performing better.
+    """
+    # Statically set tokenizer name for easy access in main.py
+    tokenizer_name = ROBERTA_MODEL_NAME
+
+    def __init__(self, num_labels):
+        """
+        Initializes the RobertaMultiOutputModel.
+
+        Args:
+            num_labels (list): A list where each element is the number of classes
+                               for a corresponding label column.
+        """
+        super(RobertaMultiOutputModel, self).__init__()
+        # Load the pre-trained RoBERTa model.
+        # RoBERTa's pooler_output corresponds to the hidden state of the
+        # first token (<s>), which is often used for sequence classification.
+        self.roberta = RobertaModel.from_pretrained(ROBERTA_MODEL_NAME)
+        self.dropout = nn.Dropout(DROPOUT_RATE)  # Dropout layer
+
+        # Create classification heads for each label column.
+        self.classifiers = nn.ModuleList([
+            nn.Linear(self.roberta.config.hidden_size, n_classes) for n_classes in num_labels
+        ])
+
+    def forward(self, input_ids, attention_mask):
+        """
+        Performs the forward pass of the model.
+
+        Args:
+            input_ids (torch.Tensor): Tensor of token IDs.
+            attention_mask (torch.Tensor): Tensor indicating attention.
+
+        Returns:
+            list: A list of logit tensors, one for each classification head.
+        """
+        # Pass input_ids and attention_mask through RoBERTa.
+        # .pooler_output is used for classification.
+        pooled_output = self.roberta(input_ids=input_ids, attention_mask=attention_mask).pooler_output
+
+        # Apply dropout
+        pooled_output = self.dropout(pooled_output)
+
+        # Pass the pooled output through each classification head.
+        return [classifier(pooled_output) for classifier in self.classifiers]
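Each model class carries its checkpoint name in the tokenizer_name class attribute, presumably so main.py can pair the right tokenizer with whichever model is selected. A sketch of that pattern; the MODEL_REGISTRY dict and build() helper are hypothetical, since main.py is not part of this commit:

from transformers import AutoTokenizer
from models.bert_model import BertMultiOutputModel
from models.roberta_model import RobertaMultiOutputModel
from models.deberta_model import DebertaMultiOutputModel

# Hypothetical registry; the real selection logic lives in main.py, not shown here.
MODEL_REGISTRY = {
    "bert": BertMultiOutputModel,
    "roberta": RobertaMultiOutputModel,
    "deberta": DebertaMultiOutputModel,
}

def build(model_key, num_labels):
    model_cls = MODEL_REGISTRY[model_key]
    tokenizer = AutoTokenizer.from_pretrained(model_cls.tokenizer_name)
    return model_cls(num_labels), tokenizer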
models/text_and_metadata_model.py
ADDED
@@ -0,0 +1,72 @@
+# models/text_and_metadata_model.py
+
+import torch
+import torch.nn as nn
+from transformers import BertModel  # Can be extended to RoBERTa, DeBERTa, etc.
+from config import DROPOUT_RATE, BERT_MODEL_NAME  # Import BERT_MODEL_NAME
+
+class BertWithMetadataModel(nn.Module):
+    """
+    Hybrid model that combines text features (extracted by BERT) with additional
+    numerical metadata features. The text features are processed by BERT,
+    metadata features by a simple MLP, and then their outputs are concatenated
+    before being fed into the final classification heads.
+    """
+    # Statically set tokenizer name
+    tokenizer_name = BERT_MODEL_NAME
+
+    def __init__(self, num_labels, metadata_dim):
+        """
+        Initializes the BertWithMetadataModel.
+
+        Args:
+            num_labels (list): A list where each element is the number of classes
+                               for a corresponding label column.
+            metadata_dim (int): The number of features in the numerical metadata.
+        """
+        super(BertWithMetadataModel, self).__init__()
+        # Load pre-trained BERT model for text processing
+        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
+        self.dropout = nn.Dropout(DROPOUT_RATE)  # Dropout for BERT's output
+
+        # MLP for processing numerical metadata features
+        self.metadata_mlp = nn.Sequential(
+            nn.Linear(metadata_dim, 128),  # First linear layer
+            nn.ReLU(),                     # Activation function
+            nn.Dropout(DROPOUT_RATE),      # Dropout for metadata features
+            nn.Linear(128, 64)             # Second linear layer
+        )
+
+        # Calculate the total input feature size for the classification heads.
+        # This is the sum of BERT's pooled output size and the metadata MLP's output size.
+        combined_feature_size = self.bert.config.hidden_size + 64
+
+        # Create classification heads, one for each label column
+        self.classifiers = nn.ModuleList([
+            nn.Linear(combined_feature_size, n_classes) for n_classes in num_labels
+        ])
+
+    def forward(self, input_ids, attention_mask, metadata):
+        """
+        Performs the forward pass of the hybrid model.
+
+        Args:
+            input_ids (torch.Tensor): Tensor of token IDs for text.
+            attention_mask (torch.Tensor): Tensor indicating attention for text.
+            metadata (torch.Tensor): Tensor of numerical metadata features.
+
+        Returns:
+            list: A list of logit tensors, one for each classification head.
+        """
+        # Process text input through BERT
+        bert_pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
+        bert_pooled_output = self.dropout(bert_pooled_output)  # Apply dropout
+
+        # Process metadata through the MLP
+        metadata_output = self.metadata_mlp(metadata)
+
+        # Concatenate the processed text features and metadata features
+        combined_features = torch.cat((bert_pooled_output, metadata_output), dim=1)
+
+        # Pass the combined features through each classification head
+        return [classifier(combined_features) for classifier in self.classifiers]
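A minimal sketch of how the text-plus-metadata model is called; the metadata dimensionality, label layout, and feature values are invented for illustration (the real feature set is defined by the training pipeline, which is not part of this commit):

import torch
from transformers import AutoTokenizer
from models.text_and_metadata_model import BertWithMetadataModel

metadata_dim = 6  # invented number of numerical metadata features
model = BertWithMetadataModel(num_labels=[4, 2], metadata_dim=metadata_dim)
tokenizer = AutoTokenizer.from_pretrained(BertWithMetadataModel.tokenizer_name)

batch = tokenizer(["text with extra signals"], return_tensors="pt")
metadata = torch.randn(1, metadata_dim)  # shape (batch_size, metadata_dim), float features
logits = model(batch["input_ids"], batch["attention_mask"], metadata)
# logits: [tensor of shape (1, 4), tensor of shape (1, 2)]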