fixed NaNs
Browse files- modeling_gptbert.py +0 -38
modeling_gptbert.py
CHANGED
|
@@ -823,13 +823,6 @@ class GptBertForCausalLM(GptBertModel):
|
|
| 823 |
subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
|
| 824 |
causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
|
| 825 |
|
| 826 |
-
if not return_dict:
|
| 827 |
-
output = (
|
| 828 |
-
subword_prediction,
|
| 829 |
-
*([contextualized_embeddings] if output_hidden_states else [])
|
| 830 |
-
)
|
| 831 |
-
return ((causal_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 832 |
-
|
| 833 |
return CausalLMOutput(
|
| 834 |
loss=causal_lm_loss,
|
| 835 |
logits=subword_prediction,
|
|
@@ -936,13 +929,6 @@ class GptBertForSequenceClassification(GptBertModel):
|
|
| 936 |
loss_fct = nn.BCEWithLogitsLoss()
|
| 937 |
loss = loss_fct(logits, labels)
|
| 938 |
|
| 939 |
-
if not return_dict:
|
| 940 |
-
output = (
|
| 941 |
-
logits,
|
| 942 |
-
*([contextualized_embeddings] if output_hidden_states else [])
|
| 943 |
-
)
|
| 944 |
-
return ((loss,) + output) if loss is not None else output
|
| 945 |
-
|
| 946 |
return SequenceClassifierOutput(
|
| 947 |
loss=loss,
|
| 948 |
logits=logits,
|
|
@@ -980,19 +966,10 @@ class GptBertForTokenClassification(GptBertModel):
|
|
| 980 |
loss_fct = nn.CrossEntropyLoss()
|
| 981 |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 982 |
|
| 983 |
-
if not return_dict:
|
| 984 |
-
output = (
|
| 985 |
-
logits,
|
| 986 |
-
*([contextualized_embeddings] if output_hidden_states else []),
|
| 987 |
-
*([attention_probs] if output_attentions else [])
|
| 988 |
-
)
|
| 989 |
-
return ((loss,) + output) if loss is not None else output
|
| 990 |
-
|
| 991 |
return TokenClassifierOutput(
|
| 992 |
loss=loss,
|
| 993 |
logits=logits,
|
| 994 |
hidden_states=contextualized_embeddings if output_hidden_states else None,
|
| 995 |
-
attentions=attention_probs if output_attentions else None
|
| 996 |
)
|
| 997 |
|
| 998 |
|
|
@@ -1044,14 +1021,6 @@ class GptBertForQuestionAnswering(GptBertModel):
|
|
| 1044 |
end_loss = loss_fct(end_logits, end_positions)
|
| 1045 |
total_loss = (start_loss + end_loss) / 2
|
| 1046 |
|
| 1047 |
-
if not return_dict:
|
| 1048 |
-
output = (
|
| 1049 |
-
start_logits,
|
| 1050 |
-
end_logits,
|
| 1051 |
-
*([contextualized_embeddings] if output_hidden_states else [])
|
| 1052 |
-
)
|
| 1053 |
-
return ((total_loss,) + output) if total_loss is not None else output
|
| 1054 |
-
|
| 1055 |
return QuestionAnsweringModelOutput(
|
| 1056 |
loss=total_loss,
|
| 1057 |
start_logits=start_logits,
|
|
@@ -1095,13 +1064,6 @@ class GptBertForMultipleChoice(GptBertModel):
|
|
| 1095 |
loss_fct = nn.CrossEntropyLoss()
|
| 1096 |
loss = loss_fct(reshaped_logits, labels)
|
| 1097 |
|
| 1098 |
-
if not return_dict:
|
| 1099 |
-
output = (
|
| 1100 |
-
reshaped_logits,
|
| 1101 |
-
*([contextualized_embeddings] if output_hidden_states else [])
|
| 1102 |
-
)
|
| 1103 |
-
return ((loss,) + output) if loss is not None else output
|
| 1104 |
-
|
| 1105 |
return MultipleChoiceModelOutput(
|
| 1106 |
loss=loss,
|
| 1107 |
logits=reshaped_logits,
|
|
|
|
| 823 |
subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
|
| 824 |
causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
|
| 825 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 826 |
return CausalLMOutput(
|
| 827 |
loss=causal_lm_loss,
|
| 828 |
logits=subword_prediction,
|
|
|
|
| 929 |
loss_fct = nn.BCEWithLogitsLoss()
|
| 930 |
loss = loss_fct(logits, labels)
|
| 931 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 932 |
return SequenceClassifierOutput(
|
| 933 |
loss=loss,
|
| 934 |
logits=logits,
|
|
|
|
| 966 |
loss_fct = nn.CrossEntropyLoss()
|
| 967 |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 968 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 969 |
return TokenClassifierOutput(
|
| 970 |
loss=loss,
|
| 971 |
logits=logits,
|
| 972 |
hidden_states=contextualized_embeddings if output_hidden_states else None,
|
|
|
|
| 973 |
)
|
| 974 |
|
| 975 |
|
|
|
|
| 1021 |
end_loss = loss_fct(end_logits, end_positions)
|
| 1022 |
total_loss = (start_loss + end_loss) / 2
|
| 1023 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1024 |
return QuestionAnsweringModelOutput(
|
| 1025 |
loss=total_loss,
|
| 1026 |
start_logits=start_logits,
|
|
|
|
| 1064 |
loss_fct = nn.CrossEntropyLoss()
|
| 1065 |
loss = loss_fct(reshaped_logits, labels)
|
| 1066 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1067 |
return MultipleChoiceModelOutput(
|
| 1068 |
loss=loss,
|
| 1069 |
logits=reshaped_logits,
|