davda54 committed on
Commit
2d53d20
·
verified ·
1 Parent(s): 0e52af4

fixed NaNs

Browse files
Files changed (1) hide show
  1. modeling_gptbert.py +0 -38
modeling_gptbert.py CHANGED
@@ -823,13 +823,6 @@ class GptBertForCausalLM(GptBertModel):
823
  subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
824
  causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
825
 
826
- if not return_dict:
827
- output = (
828
- subword_prediction,
829
- *([contextualized_embeddings] if output_hidden_states else [])
830
- )
831
- return ((causal_lm_loss,) + output) if masked_lm_loss is not None else output
832
-
833
  return CausalLMOutput(
834
  loss=causal_lm_loss,
835
  logits=subword_prediction,
@@ -936,13 +929,6 @@ class GptBertForSequenceClassification(GptBertModel):
936
  loss_fct = nn.BCEWithLogitsLoss()
937
  loss = loss_fct(logits, labels)
938
 
939
- if not return_dict:
940
- output = (
941
- logits,
942
- *([contextualized_embeddings] if output_hidden_states else [])
943
- )
944
- return ((loss,) + output) if loss is not None else output
945
-
946
  return SequenceClassifierOutput(
947
  loss=loss,
948
  logits=logits,
@@ -980,19 +966,10 @@ class GptBertForTokenClassification(GptBertModel):
980
  loss_fct = nn.CrossEntropyLoss()
981
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
982
 
983
- if not return_dict:
984
- output = (
985
- logits,
986
- *([contextualized_embeddings] if output_hidden_states else []),
987
- *([attention_probs] if output_attentions else [])
988
- )
989
- return ((loss,) + output) if loss is not None else output
990
-
991
  return TokenClassifierOutput(
992
  loss=loss,
993
  logits=logits,
994
  hidden_states=contextualized_embeddings if output_hidden_states else None,
995
- attentions=attention_probs if output_attentions else None
996
  )
997
 
998
 
@@ -1044,14 +1021,6 @@ class GptBertForQuestionAnswering(GptBertModel):
1044
  end_loss = loss_fct(end_logits, end_positions)
1045
  total_loss = (start_loss + end_loss) / 2
1046
 
1047
- if not return_dict:
1048
- output = (
1049
- start_logits,
1050
- end_logits,
1051
- *([contextualized_embeddings] if output_hidden_states else [])
1052
- )
1053
- return ((total_loss,) + output) if total_loss is not None else output
1054
-
1055
  return QuestionAnsweringModelOutput(
1056
  loss=total_loss,
1057
  start_logits=start_logits,
@@ -1095,13 +1064,6 @@ class GptBertForMultipleChoice(GptBertModel):
1095
  loss_fct = nn.CrossEntropyLoss()
1096
  loss = loss_fct(reshaped_logits, labels)
1097
 
1098
- if not return_dict:
1099
- output = (
1100
- reshaped_logits,
1101
- *([contextualized_embeddings] if output_hidden_states else [])
1102
- )
1103
- return ((loss,) + output) if loss is not None else output
1104
-
1105
  return MultipleChoiceModelOutput(
1106
  loss=loss,
1107
  logits=reshaped_logits,
 
823
  subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
824
  causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
825
 
 
 
 
 
 
 
 
826
  return CausalLMOutput(
827
  loss=causal_lm_loss,
828
  logits=subword_prediction,
 
929
  loss_fct = nn.BCEWithLogitsLoss()
930
  loss = loss_fct(logits, labels)
931
 
 
 
 
 
 
 
 
932
  return SequenceClassifierOutput(
933
  loss=loss,
934
  logits=logits,
 
966
  loss_fct = nn.CrossEntropyLoss()
967
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
968
 
 
 
 
 
 
 
 
 
969
  return TokenClassifierOutput(
970
  loss=loss,
971
  logits=logits,
972
  hidden_states=contextualized_embeddings if output_hidden_states else None,
 
973
  )
974
 
975
 
 
1021
  end_loss = loss_fct(end_logits, end_positions)
1022
  total_loss = (start_loss + end_loss) / 2
1023
 
 
 
 
 
 
 
 
 
1024
  return QuestionAnsweringModelOutput(
1025
  loss=total_loss,
1026
  start_logits=start_logits,
 
1064
  loss_fct = nn.CrossEntropyLoss()
1065
  loss = loss_fct(reshaped_logits, labels)
1066
 
 
 
 
 
 
 
 
1067
  return MultipleChoiceModelOutput(
1068
  loss=loss,
1069
  logits=reshaped_logits,