subbu123456 committed on
Commit
11a9e1e
·
verified ·
1 Parent(s): 6d76c14

models/bert_model.py

Browse files
Files changed (1) hide show
  1. bert_model.py +59 -0
bert_model.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/bert_model.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import BertModel
6
+ from config import DROPOUT_RATE, BERT_MODEL_NAME # Import BERT_MODEL_NAME from config
7
+
8
class BertMultiOutputModel(nn.Module):
    """Multi-output classifier built on a pre-trained BERT backbone.

    The pooled BERT representation is passed through dropout and then
    through one independent linear head per target label column, so a
    single forward pass yields logits for every label at once.
    """

    # Exposed as a class attribute so callers (e.g. main.py) can fetch
    # the matching tokenizer name without instantiating the model.
    tokenizer_name = BERT_MODEL_NAME

    def __init__(self, num_labels):
        """Build the backbone, dropout, and per-label classification heads.

        Args:
            num_labels (list): Number of classes for each label column;
                one linear head is created per entry.
        """
        super().__init__()
        # Pre-trained BERT backbone; supplies contextual embeddings and
        # a pooled output suitable for sentence-level classification.
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        # Regularization applied to the pooled features before the heads.
        self.dropout = nn.Dropout(DROPOUT_RATE)
        # One independent head per label column, each mapping BERT's
        # hidden size to that column's class count.
        hidden_size = self.bert.config.hidden_size
        self.classifiers = nn.ModuleList(
            nn.Linear(hidden_size, class_count) for class_count in num_labels
        )

    def forward(self, input_ids, attention_mask):
        """Run one forward pass over a batch of tokenized inputs.

        Args:
            input_ids (torch.Tensor): Token-id tensor from the tokenizer.
            attention_mask (torch.Tensor): Attention mask from the tokenizer.

        Returns:
            list: One logit tensor per classification head, each shaped
            (batch_size, num_classes_for_that_label).
        """
        # pooler_output is the [CLS] hidden state run through BERT's
        # dense + tanh pooling layer — the standard classification feature.
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        features = self.dropout(encoded.pooler_output)

        # Collect raw scores (pre-softmax/sigmoid) from every head.
        logits = []
        for head in self.classifiers:
            logits.append(head(features))
        return logits