3v324v23 committed on
Commit
acbb45f
·
1 Parent(s): 3fd3cf4

Adding curriculum face model

Browse files
Files changed (1) hide show
  1. models.py +224 -13
models.py CHANGED
@@ -6,11 +6,26 @@ import os
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
 
9
  from transformers import MegatronBertConfig, MegatronBertModel, MegatronBertForMaskedLM, MegatronBertPreTrainedModel, PreTrainedModel
10
  from transformers.modeling_outputs import SequenceClassifierOutput
11
  from transformers.utils.hub import cached_file
 
12
  #from prokbert.training_utils import compute_metrics_eval_prediction
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class BertForBinaryClassificationWithPooling(nn.Module):
15
  """
16
  ProkBERT model for binary classification with custom pooling.
@@ -128,9 +143,6 @@ class BertForBinaryClassificationWithPooling(nn.Module):
128
 
129
  return model
130
 
131
-
132
-
133
-
134
  class ProkBertConfig(MegatronBertConfig):
135
  model_type = "prokbert"
136
 
@@ -138,18 +150,36 @@ class ProkBertConfig(MegatronBertConfig):
138
  self,
139
  kmer: int = 6,
140
  shift: int = 1,
141
- num_labels: int = 2,
142
  classification_dropout_rate: float = 0.1,
143
  **kwargs,
144
  ):
145
  super().__init__(**kwargs)
146
  self.kmer = kmer
147
  self.shift = shift
148
- self.num_labels = num_labels
149
  self.classification_dropout_rate = classification_dropout_rate
150
 
 
 
151
 
152
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  class ProkBertClassificationConfig(ProkBertConfig):
155
  model_type = "prokbert"
@@ -186,9 +216,6 @@ class ProkBertPreTrainedModel(PreTrainedModel):
186
  if isinstance(module, nn.Linear) and module.bias is not None:
187
  module.bias.data.zero_()
188
 
189
-
190
-
191
-
192
  class ProkBertModel(MegatronBertModel):
193
  config_class = ProkBertConfig
194
 
@@ -224,7 +251,7 @@ class ProkBertForSequenceClassification(ProkBertPreTrainedModel):
224
  self.bert = ProkBertModel(config)
225
  self.weighting_layer = nn.Linear(self.config.hidden_size, 1)
226
  self.dropout = nn.Dropout(self.config.classification_dropout_rate)
227
- self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)
228
  self.loss_fct = torch.nn.CrossEntropyLoss()
229
 
230
  self.post_init()
@@ -245,8 +272,8 @@ class ProkBertForSequenceClassification(ProkBertPreTrainedModel):
245
  r"""
246
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
247
  Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
248
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
249
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
250
  """
251
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
252
 
@@ -273,7 +300,7 @@ class ProkBertForSequenceClassification(ProkBertPreTrainedModel):
273
  logits = self.classifier(pooled_output)
274
  loss = None
275
  if labels is not None:
276
- loss = self.loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
277
 
278
  classification_output = SequenceClassifierOutput(
279
  loss=loss,
@@ -283,3 +310,187 @@ class ProkBertForSequenceClassification(ProkBertPreTrainedModel):
283
  )
284
  return classification_output
285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
+ from torch.nn.parameter import Parameter
10
  from transformers import MegatronBertConfig, MegatronBertModel, MegatronBertForMaskedLM, MegatronBertPreTrainedModel, PreTrainedModel
11
  from transformers.modeling_outputs import SequenceClassifierOutput
12
  from transformers.utils.hub import cached_file
13
+ import math
14
  #from prokbert.training_utils import compute_metrics_eval_prediction
15
 
16
+
17
def l2_norm(input, axis=1, epsilon=1e-12):
    """Scale *input* to unit L2 norm along *axis*.

    The norm is clamped from below at *epsilon* so all-zero rows are
    returned unchanged instead of producing NaNs from a zero division.
    """
    magnitude = torch.clamp(torch.norm(input, 2, axis, True), min=epsilon)
    return input / magnitude
22
+
23
def initialize_linear_kaiming(layer: nn.Linear):
    """Re-initialize a ``nn.Linear`` layer in place with Kaiming-uniform weights.

    Weights get ``kaiming_uniform_`` (linear nonlinearity); the bias, when
    present, is zeroed. Modules that are not ``nn.Linear`` are left untouched.
    """
    if not isinstance(layer, nn.Linear):
        return
    nn.init.kaiming_uniform_(layer.weight, nonlinearity='linear')
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)
28
+
29
  class BertForBinaryClassificationWithPooling(nn.Module):
30
  """
31
  ProkBERT model for binary classification with custom pooling.
 
143
 
144
  return model
145
 
 
 
 
146
  class ProkBertConfig(MegatronBertConfig):
147
  model_type = "prokbert"
148
 
 
150
  self,
151
  kmer: int = 6,
152
  shift: int = 1,
153
+ num_class_labels: int = 2,
154
  classification_dropout_rate: float = 0.1,
155
  **kwargs,
156
  ):
157
  super().__init__(**kwargs)
158
  self.kmer = kmer
159
  self.shift = shift
160
+ self.num_class_labels = num_class_labels
161
  self.classification_dropout_rate = classification_dropout_rate
162
 
163
class ProkBertConfigCurr(ProkBertConfig):
    """Configuration for ProkBERT models that use a CurricularFace head.

    Extends :class:`ProkBertConfig` with the margin (``curricular_face_m``),
    scale (``curricular_face_s``), label count and optional projection size
    used by the curricular classification model. ``curriculum_hidden_size``
    of ``-1`` means "no projection layer" (use the encoder hidden size).
    """
    model_type = "prokbert"

    def __init__(
        self,
        bert_base_model="neuralbioinfo/prokbert-mini",
        curricular_face_m=0.5,
        curricular_face_s=64.,
        curricular_num_labels=2,
        curriculum_hidden_size=-1,
        classification_dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Name/path of the backbone checkpoint the head is trained on top of.
        self.bert_base_model = bert_base_model
        # CurricularFace hyperparameters: angular margin m and logit scale s.
        self.curricular_face_m = curricular_face_m
        self.curricular_face_s = curricular_face_s
        self.curricular_num_labels = curricular_num_labels
        self.curriculum_hidden_size = curriculum_hidden_size
        self.classification_dropout_rate = classification_dropout_rate
183
 
184
  class ProkBertClassificationConfig(ProkBertConfig):
185
  model_type = "prokbert"
 
216
  if isinstance(module, nn.Linear) and module.bias is not None:
217
  module.bias.data.zero_()
218
 
 
 
 
219
  class ProkBertModel(MegatronBertModel):
220
  config_class = ProkBertConfig
221
 
 
251
  self.bert = ProkBertModel(config)
252
  self.weighting_layer = nn.Linear(self.config.hidden_size, 1)
253
  self.dropout = nn.Dropout(self.config.classification_dropout_rate)
254
+ self.classifier = nn.Linear(self.config.hidden_size, self.config.num_class_labels)
255
  self.loss_fct = torch.nn.CrossEntropyLoss()
256
 
257
  self.post_init()
 
272
  r"""
273
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
274
  Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
275
+ config.num_class_labels - 1]`. If `config.num_class_labels == 1` a regression loss is computed (Mean-Square loss), If
276
+ `config.num_class_labels > 1` a classification loss is computed (Cross-Entropy).
277
  """
278
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
279
 
 
300
  logits = self.classifier(pooled_output)
301
  loss = None
302
  if labels is not None:
303
+ loss = self.loss_fct(logits.view(-1, self.config.num_class_labels), labels.view(-1))
304
 
305
  classification_output = SequenceClassifierOutput(
306
  loss=loss,
 
310
  )
311
  return classification_output
312
 
313
class CurricularFace(nn.Module):
    """CurricularFace margin-based classification head (Huang et al., CVPR 2020).

    Computes cosine logits between L2-normalized embeddings and a learned
    class-kernel, adds an adaptive angular margin ``m`` to the target class
    and re-weights "hard" negatives via a running statistic ``t``.

    Fixes vs. the previous revision:
    - the kernel was created with ``torch.Tensor(...)`` (uninitialized
      memory); it is now explicitly initialized with ``normal_(std=0.01)``
      as in the reference implementation, so the module works even without
      an external ``_init_weights`` pass,
    - leftover debug ``try/except`` + prints around a boolean-mask indexing
      (which is not where an out-of-range label would actually raise) removed.

    Args:
        in_features: embedding dimensionality.
        out_features: number of classes.
        m: angular margin added to the target class.
        s: scale applied to the output logits.
    """

    def __init__(self, in_features, out_features, m=0.5, s=64.):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.m = m
        self.s = s
        # Precomputed margin terms for cos(theta + m) and the fallback branch.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.threshold = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        self.kernel = Parameter(torch.empty(in_features, out_features))
        # Explicit init: torch.empty/torch.Tensor holds garbage otherwise.
        nn.init.normal_(self.kernel, std=0.01)
        # Running mean of target cosines; modulates hard-example re-weighting.
        self.register_buffer('t', torch.zeros(1))

    def forward(self, embeddings, label):
        """Return ``(margin_logits, original_logits)``, both scaled by ``s``.

        Args:
            embeddings: float tensor of shape ``(batch, in_features)``.
            label: long tensor of shape ``(batch,)`` with values in
                ``[0, out_features)``.
        """
        # Normalize embeddings (rows) and kernel (columns); F.normalize with
        # eps=1e-12 divides by max(norm, eps), identical to the old l2_norm.
        embeddings = F.normalize(embeddings, p=2, dim=1, eps=1e-12)
        kernel_norm = F.normalize(self.kernel, p=2, dim=0, eps=1e-12)

        # Cosine similarity between each embedding and each class direction.
        cos_theta = torch.mm(embeddings, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # numerical stability

        # Keep the unmodified cosines for the caller (analysis/inference).
        with torch.no_grad():
            origin_cos = cos_theta.clone()

        # Cosines of the ground-truth classes, shape (batch, 1).
        target_logit = cos_theta[torch.arange(0, embeddings.size(0)), label].view(-1, 1)
        sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m  # cos(theta + m)

        # Negatives whose cosine exceeds the margin-shifted target are "hard".
        mask = cos_theta > cos_theta_m

        # Standard ArcFace-style easy-margin fallback: beyond the threshold the
        # additive margin is no longer monotone, so subtract a fixed term.
        final_target_logit = torch.where(target_logit > self.threshold,
                                         cos_theta_m,
                                         target_logit - self.mm)

        # EMA update of t (controls how strongly hard examples are boosted).
        with torch.no_grad():
            self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t

        # Curricular re-scaling of hard negatives: cos -> cos * (t + cos).
        hard_example = cos_theta[mask]
        cos_theta[mask] = hard_example * (self.t + hard_example)

        # Write the margin-adjusted logit into the target-class positions.
        final_target_logit = final_target_logit.to(cos_theta.dtype)
        cos_theta.scatter_(1, label.view(-1, 1).long(), final_target_logit)
        output = cos_theta * self.s
        return output, origin_cos * self.s
378
+
379
class ProkBertForCurricularClassification(ProkBertPreTrainedModel):
    """ProkBERT encoder with weighted pooling and a CurricularFace head.

    The sequence output is pooled with a learned per-token weighting
    (attention-like softmax over the sequence), optionally projected to
    ``config.curriculum_hidden_size``, then fed to :class:`CurricularFace`.

    Behavior split on ``labels``:
    - training (``labels`` given): returns a ``SequenceClassifierOutput``
      whose ``logits`` are the CurricularFace margin logits and whose
      ``loss`` is cross-entropy over them;
    - inference (``labels is None``): returns the L2-normalized pooled
      embedding instead of a classifier output.

    Fix vs. the previous revision: ``attention_mask`` is declared Optional
    but was dereferenced unconditionally; a ``None`` mask now defaults to
    attending over every position.
    """
    config_class = ProkBertConfigCurr
    base_model_prefix = "bert"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.bert = ProkBertModel(config)

        # Learned per-token weighting used for pooling the sequence output.
        self.weighting_layer = nn.Linear(self.config.hidden_size, 1)
        self.dropout = nn.Dropout(self.config.classification_dropout_rate)

        if config.curriculum_hidden_size != -1:
            # Project pooled embeddings to the curriculum hidden size before
            # the CurricularFace head.
            self.linear = nn.Linear(self.config.hidden_size, config.curriculum_hidden_size)
            self.curricular_face = CurricularFace(config.curriculum_hidden_size,
                                                  self.config.curricular_num_labels,
                                                  m=self.config.curricular_face_m,
                                                  s=self.config.curricular_face_s)
        else:
            # No projection: the head consumes the raw encoder hidden size.
            self.linear = nn.Identity()
            self.curricular_face = CurricularFace(self.config.hidden_size,
                                                  self.config.curricular_num_labels,
                                                  m=self.config.curricular_face_m,
                                                  s=self.config.curricular_face_s)

        self.loss_fct = torch.nn.CrossEntropyLoss()
        self.post_init()

    def _init_weights(self, module: nn.Module):
        """Base-class init plus custom init for the pooling/projection heads."""
        # First let the base class initialize everything else.
        super()._init_weights(module)

        # Pooling head: Xavier weights, zero bias.
        if module is getattr(self, "weighting_layer", None):
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)

        # Optional projection layer (nn.Identity is skipped by the helper).
        if module is getattr(self, "linear", None):
            initialize_linear_kaiming(module)

        # CurricularFace kernel.
        if module is getattr(self, "curricular_face", None):
            nn.init.kaiming_uniform_(module.kernel, a=math.sqrt(self.config.curricular_num_labels))

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the base ProkBert encoder.
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]  # (batch_size, seq_length, hidden_size)

        # Pool the sequence output using a learned weighting (attention-like).
        weights = self.weighting_layer(sequence_output)  # (batch_size, seq_length, 1)

        # Normalize the mask to (batch_size, seq_length); None means "attend
        # everywhere" (attention_mask is Optional in the signature).
        if attention_mask is None:
            mask = torch.ones(sequence_output.shape[:2],
                              dtype=torch.long,
                              device=sequence_output.device)
        elif attention_mask.dim() == 2:
            mask = attention_mask
        elif attention_mask.dim() == 4:
            mask = attention_mask.squeeze(1).squeeze(1)  # (batch_size, seq_length)
        else:
            raise ValueError(f"Unexpected attention_mask shape {attention_mask.shape}")

        # Masked positions -> -inf so softmax gives them zero weight.
        weights = weights.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))
        weights = torch.nn.functional.softmax(weights, dim=1)  # (batch_size, seq_length, 1)

        # Weighted pooling over the sequence dimension.
        pooled_output = torch.sum(weights * sequence_output, dim=1)  # (batch_size, hidden_size)
        pooled_output = self.dropout(pooled_output)
        pooled_output = self.linear(pooled_output)

        # Inference: no labels -> return the L2-normalized embedding only.
        if labels is None:
            return l2_norm(pooled_output, axis=1)

        # Training: CurricularFace logits + cross-entropy loss.
        logits, origin_cos = self.curricular_face(pooled_output, labels)
        loss = self.loss_fct(logits, labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )