voxmenthe committed
Commit 50006df · verified · 1 parent: 9d0afef

Upload src/models.py with huggingface_hub

Files changed (1)
  1. src/models.py +172 -0
src/models.py ADDED
@@ -0,0 +1,172 @@
from transformers import ModernBertModel, ModernBertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from torch import nn
import torch
from src.train_utils import SentimentWeightedLoss, SentimentFocalLoss
import torch.nn.functional as F

from src.classifiers import ClassifierHead, ConcatClassifierHead


class ModernBertForSentiment(ModernBertPreTrainedModel):
    """ModernBERT encoder with a dynamically configurable classification head and pooling strategy."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = ModernBertModel(config)  # Base encoder; config may set output_hidden_states=True

        # Store the pooling strategy from config
        self.pooling_strategy = getattr(config, 'pooling_strategy', 'cls')  # Defaults to 'cls'
        self.num_weighted_layers = getattr(config, 'num_weighted_layers', 4)

        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat'] and not config.output_hidden_states:
            # Effectively an assertion; train.py is expected to set output_hidden_states=True
            raise ValueError(
                "output_hidden_states must be True in the model config for weighted-layer pooling."
            )

        # Initialize weights for weighted-layer pooling
        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat']:
            # num_weighted_layers specifies how many *top* encoder layers to use;
            # e.g. num_weighted_layers=4 means the last 4 layers.
            self.layer_weights = nn.Parameter(torch.ones(self.num_weighted_layers) / self.num_weighted_layers)
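            # e.g. num_weighted_layers=4 initializes uniform weights [0.25, 0.25, 0.25, 0.25];
            # they are softmax-normalized again in _weighted_layer_pool before use.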

        # Determine classifier input size and choose head
        classifier_input_size = config.hidden_size
        if self.pooling_strategy in ['cls_mean_concat', 'cls_weighted_concat']:
            classifier_input_size = config.hidden_size * 2

        # Dropout applied to the pooled features fed into the classifier head
        classifier_dropout_prob = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.features_dropout = nn.Dropout(classifier_dropout_prob)

        # Select the appropriate classifier head based on the pooled feature dimension
        if classifier_input_size == config.hidden_size:
            self.classifier = ClassifierHead(
                hidden_size=config.hidden_size,  # ClassifierHead's input size is just hidden_size
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob
            )
        elif classifier_input_size == config.hidden_size * 2:
            self.classifier = ConcatClassifierHead(
                input_size=config.hidden_size * 2,
                hidden_size=config.hidden_size,  # Internal hidden size of the head
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob
            )
        else:
            # Unreachable with the current strategies; guards against future additions
            raise ValueError(f"Unexpected classifier_input_size: {classifier_input_size}")

        # Initialize the loss function from config
        loss_config = getattr(config, 'loss_function', {'name': 'SentimentWeightedLoss', 'params': {}})
        loss_name = loss_config.get('name', 'SentimentWeightedLoss')
        loss_params = loss_config.get('params', {})
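        # Hypothetical example of the expected shape of this config entry:
        #   config.loss_function = {"name": "SentimentFocalLoss",
        #                           "params": {"gamma_focal": 2.0, "label_smoothing_epsilon": 0.05}}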

        if loss_name == "SentimentWeightedLoss":
            self.loss_fct = SentimentWeightedLoss()  # SentimentWeightedLoss takes no constructor arguments
        elif loss_name == "SentimentFocalLoss":
            # SentimentFocalLoss expects 'gamma_focal' and 'label_smoothing_epsilon';
            # loss_params should contain only these keys.
            self.loss_fct = SentimentFocalLoss(**loss_params)
        else:
            raise ValueError(f"Unsupported loss function: {loss_name}")

        self.post_init()  # Initialize weights and apply final processing

    def _mean_pool(self, last_hidden_state, attention_mask):
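        # Masked mean over tokens: last_hidden_state is (batch, seq_len, hidden),
        # attention_mask is (batch, seq_len); returns (batch, hidden).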
        if attention_mask is None:
            # Fall back to attending over every position: a (batch, seq_len) mask of ones
            attention_mask = torch.ones_like(last_hidden_state[:, :, 0])
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _weighted_layer_pool(self, all_hidden_states):
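        # Returns (batch_size, hidden_size): the [CLS] position of a learned,
        # softmax-weighted mixture of the top encoder layers.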
        # all_hidden_states contains the embeddings plus the output of each layer;
        # we want the outputs of the last num_weighted_layers.
        # Example: a 12-layer encoder yields 13 entries (embeddings + 12 layers);
        # num_weighted_layers=4 selects layers 9-12 (indices -4 through -1).
        layers_to_weigh = torch.stack(all_hidden_states[-self.num_weighted_layers:], dim=0)
        # layers_to_weigh shape: (num_weighted_layers, batch_size, sequence_length, hidden_size)

        # Normalize the learned weights to sum to 1
        normalized_weights = F.softmax(self.layer_weights, dim=-1)

        # Weighted sum across layers; reshape weights to (num_weighted_layers, 1, 1, 1) for broadcasting
        weighted_hidden_states = layers_to_weigh * normalized_weights.view(-1, 1, 1, 1)
        weighted_sum_hidden_states = torch.sum(weighted_hidden_states, dim=0)
        # weighted_sum_hidden_states shape: (batch_size, sequence_length, hidden_size)

        # Pool the result by taking the [CLS] token of the weighted sum
        return weighted_sum_hidden_states[:, 0]

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        lengths=None,
        return_dict=None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            return_dict=return_dict,
            output_hidden_states=self.config.output_hidden_states  # Controlled by train.py
        )

        last_hidden_state = bert_outputs[0]  # Equivalent to bert_outputs.last_hidden_state
        pooled_features = None

        if self.pooling_strategy == 'cls':
            pooled_features = last_hidden_state[:, 0]  # [CLS] token
        elif self.pooling_strategy == 'mean':
            pooled_features = self._mean_pool(last_hidden_state, attention_mask)
        elif self.pooling_strategy == 'cls_mean_concat':
            cls_output = last_hidden_state[:, 0]
            mean_output = self._mean_pool(last_hidden_state, attention_mask)
            pooled_features = torch.cat((cls_output, mean_output), dim=1)
        elif self.pooling_strategy == 'weighted_layer':
            if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
                raise ValueError("Weighted-layer pooling requires output_hidden_states=True and hidden_states in the encoder output.")
            all_hidden_states = bert_outputs.hidden_states
            pooled_features = self._weighted_layer_pool(all_hidden_states)
        elif self.pooling_strategy == 'cls_weighted_concat':
            if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
                raise ValueError("Weighted-layer pooling requires output_hidden_states=True and hidden_states in the encoder output.")
            cls_output = last_hidden_state[:, 0]
            all_hidden_states = bert_outputs.hidden_states
            weighted_output = self._weighted_layer_pool(all_hidden_states)
            pooled_features = torch.cat((cls_output, weighted_output), dim=1)
        else:
            raise ValueError(f"Unknown pooling_strategy: {self.pooling_strategy}")

        pooled_features = self.features_dropout(pooled_features)
        logits = self.classifier(pooled_features)

        loss = None
        if labels is not None:
            if lengths is None:
                raise ValueError("lengths must be provided when labels are specified for loss calculation.")
            loss = self.loss_fct(logits.squeeze(-1), labels, lengths)

        if not return_dict:
            # If the encoder returned a tuple, forward its extra outputs as-is;
            # otherwise rebuild them from the output object's fields.
            bert_model_outputs = bert_outputs[1:] if isinstance(bert_outputs, tuple) else (bert_outputs.hidden_states, bert_outputs.attentions)
            output = (logits,) + bert_model_outputs
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )
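
A minimal usage sketch for this class. The checkpoint name and config values are illustrative, src/ is assumed to be on the Python path, and lengths is passed as per-example token counts on the assumption that src.train_utils.SentimentWeightedLoss consumes it that way, matching the self.loss_fct(logits.squeeze(-1), labels, lengths) call in forward. The classifier-head weights are newly initialized when loading an encoder-only checkpoint.

import torch
from transformers import AutoConfig, AutoTokenizer

from src.models import ModernBertForSentiment

model_name = "answerdotai/ModernBERT-base"          # illustrative checkpoint choice

config = AutoConfig.from_pretrained(model_name)
config.num_labels = 1                               # single sentiment logit
config.pooling_strategy = "cls_weighted_concat"     # custom attribute read in __init__
config.num_weighted_layers = 4
config.output_hidden_states = True                  # required by the weighted strategies
config.loss_function = {"name": "SentimentWeightedLoss", "params": {}}

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = ModernBertForSentiment.from_pretrained(model_name, config=config)

batch = tokenizer(["great movie", "awful movie"], padding=True, return_tensors="pt")
labels = torch.tensor([1.0, 0.0])                   # float targets for the 1-logit head
lengths = batch["attention_mask"].sum(dim=1)        # assumed: token counts per example

outputs = model(**batch, labels=labels, lengths=lengths)
print(outputs.loss, outputs.logits.shape)           # scalar loss and per-example logits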