TCMVince committed on
Commit
b153482
·
verified ·
1 Parent(s): 7a20355

Update mlm.py

Browse files
Files changed (1) hide show
  1. mlm.py +45 -1
mlm.py CHANGED
@@ -525,7 +525,7 @@ class BertEnergyModelForSequenceClassification(BertPreTrainedModel):
525
  def set_output_embeddings(self, new_embeddings):
526
  self.lm_head.decoder = new_embeddings
527
 
528
- def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
529
  outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
530
  logits = self.lm_head(outputs.last_hidden_state)
531
 
@@ -542,4 +542,48 @@ class BertEnergyModelForSequenceClassification(BertPreTrainedModel):
542
  logits=logits,
543
  hidden_states=outputs.hidden_states,
544
  attentions=outputs.attentions,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
  )
 
525
  def set_output_embeddings(self, new_embeddings):
526
  self.lm_head.decoder = new_embeddings
527
 
528
+ def forward_mlm(self, input_ids, attention_mask=None, labels=None, **kwargs):
529
  outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
530
  logits = self.lm_head(outputs.last_hidden_state)
531
 
 
542
  logits=logits,
543
  hidden_states=outputs.hidden_states,
544
  attentions=outputs.attentions,
545
+ )
546
+
547
def forward(self, input_ids, attention_mask=None, labels=None, return_dict=None, **kwargs):
    """Run the sequence-classification head on top of the backbone.

    The backbone (``self.model``) is run on ``input_ids``; its
    ``last_hidden_state`` is passed through ``self.norm``, the vector at
    position 0 of the sequence axis is selected, and that vector goes
    through ``self.dropout`` and ``self.classifier`` to produce the logits.

    Args:
        input_ids: Token ids forwarded to ``self.model``.
        attention_mask: Optional mask forwarded to ``self.model``. The
            previous version of this method referenced ``attention_mask``
            without accepting it, which raised ``NameError`` on every call;
            it is now an explicit parameter, in the same position as in
            ``forward_mlm``.
        labels: Optional targets. When given, a loss is computed according
            to ``self.config.problem_type`` (inferred and stored on the
            config the first time if it is ``None``).
        return_dict: When falsy, return a plain tuple; when ``None``, fall
            back to ``self.return_dict``.
        **kwargs: Extra keyword arguments forwarded to ``self.model``.

    Returns:
        A ``SequenceClassifierOutput`` with ``loss``, ``logits``,
        ``hidden_states`` and ``attentions``; or, when ``return_dict`` is
        falsy, the tuple ``(logits, hidden_states, attentions)`` with
        ``loss`` prepended when labels were supplied.
    """
    if return_dict is None:
        # NOTE(review): Hugging Face models usually expose this default as
        # `self.config.use_return_dict` -- confirm `self.return_dict` exists
        # on this class.
        return_dict = self.return_dict

    outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
    last_hidden_state = self.norm(outputs.last_hidden_state)

    # Pool by taking the first position along the sequence axis.
    pooled = last_hidden_state[:, 0, :]
    pooled = self.dropout(pooled)
    logits = self.classifier(pooled)

    loss = None
    if labels is not None:
        # Labels may live on a different device than the classifier output.
        labels = labels.to(logits.device)

        # Infer the problem type once and cache it on the config.
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        if self.config.problem_type == "regression":
            loss_fct = MSELoss()
            if self.num_labels == 1:
                loss = loss_fct(logits.squeeze(), labels.squeeze())
            else:
                loss = loss_fct(logits, labels)
        elif self.config.problem_type == "single_label_classification":
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        else:
            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)

    if not return_dict:
        output = (logits, outputs.hidden_states, outputs.attentions)
        return ((loss,) + output) if loss is not None else output

    return SequenceClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )