Spaces:
Runtime error
Runtime error
models/bert_model.py
Browse files- bert_model.py +59 -0
bert_model.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models/bert_model.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from transformers import BertModel
|
| 6 |
+
from config import DROPOUT_RATE, BERT_MODEL_NAME # Import BERT_MODEL_NAME from config
|
| 7 |
+
|
| 8 |
+
class BertMultiOutputModel(nn.Module):
    """Multi-output text classifier built on a pre-trained BERT backbone.

    The pooled BERT output is passed through a dropout layer and then fed
    into one independent linear classification head per target label column.
    """

    # Exposed statically so callers (e.g. main.py) can look up the matching
    # tokenizer name without instantiating the model.
    tokenizer_name = BERT_MODEL_NAME

    def __init__(self, num_labels):
        """Build the BERT backbone, dropout layer, and per-label heads.

        Args:
            num_labels (list): Number of classes for each label column;
                one linear head is created per entry.
        """
        super().__init__()
        # Pre-trained backbone providing contextual embeddings and a pooled
        # sentence-level output suitable for classification.
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        # Regularization applied to the pooled representation.
        self.dropout = nn.Dropout(DROPOUT_RATE)
        # One classification head per label column, each mapping BERT's
        # pooled-output size to that column's class count.
        hidden_size = self.bert.config.hidden_size
        self.classifiers = nn.ModuleList(
            nn.Linear(hidden_size, n_classes) for n_classes in num_labels
        )

    def forward(self, input_ids, attention_mask):
        """Score every label column for a batch of tokenized inputs.

        Args:
            input_ids (torch.Tensor): Token-id tensor from the tokenizer.
            attention_mask (torch.Tensor): Attention-mask tensor from the
                tokenizer.

        Returns:
            list: One logits tensor per classification head, each of shape
            (batch_size, num_classes_for_that_label). Logits are raw scores,
            i.e. before any softmax/sigmoid.
        """
        # pooler_output is the [CLS] hidden state passed through a linear
        # layer and tanh — the conventional classification representation.
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        features = self.dropout(encoded.pooler_output)
        return [head(features) for head in self.classifiers]
|