Update modeling_gptbert.py
Browse files- modeling_gptbert.py +0 -37
modeling_gptbert.py
CHANGED
|
@@ -819,13 +819,6 @@ class GptBertForCausalLM(GptBertModel):
|
|
| 819 |
subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
|
| 820 |
causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
|
| 821 |
|
| 822 |
-
if not return_dict:
|
| 823 |
-
output = (
|
| 824 |
-
subword_prediction,
|
| 825 |
-
*([contextualized_embeddings] if output_hidden_states else [])
|
| 826 |
-
)
|
| 827 |
-
return ((causal_lm_loss,) + output) if causal_lm_loss is not None else output
|
| 828 |
-
|
| 829 |
return CausalLMOutput(
|
| 830 |
loss=causal_lm_loss,
|
| 831 |
logits=subword_prediction,
|
|
@@ -932,13 +925,6 @@ class GptBertForSequenceClassification(GptBertModel):
|
|
| 932 |
loss_fct = nn.BCEWithLogitsLoss()
|
| 933 |
loss = loss_fct(logits, labels)
|
| 934 |
|
| 935 |
-
if not return_dict:
|
| 936 |
-
output = (
|
| 937 |
-
logits,
|
| 938 |
-
*([contextualized_embeddings] if output_hidden_states else [])
|
| 939 |
-
)
|
| 940 |
-
return ((loss,) + output) if loss is not None else output
|
| 941 |
-
|
| 942 |
return SequenceClassifierOutput(
|
| 943 |
loss=loss,
|
| 944 |
logits=logits,
|
|
@@ -976,14 +962,6 @@ class GptBertForTokenClassification(GptBertModel):
|
|
| 976 |
loss_fct = nn.CrossEntropyLoss()
|
| 977 |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 978 |
|
| 979 |
-
if not return_dict:
|
| 980 |
-
output = (
|
| 981 |
-
logits,
|
| 982 |
-
*([contextualized_embeddings] if output_hidden_states else []),
|
| 983 |
-
*([attention_probs] if output_attentions else [])
|
| 984 |
-
)
|
| 985 |
-
return ((loss,) + output) if loss is not None else output
|
| 986 |
-
|
| 987 |
return TokenClassifierOutput(
|
| 988 |
loss=loss,
|
| 989 |
logits=logits,
|
|
@@ -1040,14 +1018,6 @@ class GptBertForQuestionAnswering(GptBertModel):
|
|
| 1040 |
end_loss = loss_fct(end_logits, end_positions)
|
| 1041 |
total_loss = (start_loss + end_loss) / 2
|
| 1042 |
|
| 1043 |
-
if not return_dict:
|
| 1044 |
-
output = (
|
| 1045 |
-
start_logits,
|
| 1046 |
-
end_logits,
|
| 1047 |
-
*([contextualized_embeddings] if output_hidden_states else [])
|
| 1048 |
-
)
|
| 1049 |
-
return ((total_loss,) + output) if total_loss is not None else output
|
| 1050 |
-
|
| 1051 |
return QuestionAnsweringModelOutput(
|
| 1052 |
loss=total_loss,
|
| 1053 |
start_logits=start_logits,
|
|
@@ -1091,13 +1061,6 @@ class GptBertForMultipleChoice(GptBertModel):
|
|
| 1091 |
loss_fct = nn.CrossEntropyLoss()
|
| 1092 |
loss = loss_fct(reshaped_logits, labels)
|
| 1093 |
|
| 1094 |
-
if not return_dict:
|
| 1095 |
-
output = (
|
| 1096 |
-
reshaped_logits,
|
| 1097 |
-
*([contextualized_embeddings] if output_hidden_states else [])
|
| 1098 |
-
)
|
| 1099 |
-
return ((loss,) + output) if loss is not None else output
|
| 1100 |
-
|
| 1101 |
return MultipleChoiceModelOutput(
|
| 1102 |
loss=loss,
|
| 1103 |
logits=reshaped_logits,
|
|
|
|
| 819 |
subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
|
| 820 |
causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
|
| 821 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 822 |
return CausalLMOutput(
|
| 823 |
loss=causal_lm_loss,
|
| 824 |
logits=subword_prediction,
|
|
|
|
| 925 |
loss_fct = nn.BCEWithLogitsLoss()
|
| 926 |
loss = loss_fct(logits, labels)
|
| 927 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 928 |
return SequenceClassifierOutput(
|
| 929 |
loss=loss,
|
| 930 |
logits=logits,
|
|
|
|
| 962 |
loss_fct = nn.CrossEntropyLoss()
|
| 963 |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 964 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 965 |
return TokenClassifierOutput(
|
| 966 |
loss=loss,
|
| 967 |
logits=logits,
|
|
|
|
| 1018 |
end_loss = loss_fct(end_logits, end_positions)
|
| 1019 |
total_loss = (start_loss + end_loss) / 2
|
| 1020 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1021 |
return QuestionAnsweringModelOutput(
|
| 1022 |
loss=total_loss,
|
| 1023 |
start_logits=start_logits,
|
|
|
|
| 1061 |
loss_fct = nn.CrossEntropyLoss()
|
| 1062 |
loss = loss_fct(reshaped_logits, labels)
|
| 1063 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1064 |
return MultipleChoiceModelOutput(
|
| 1065 |
loss=loss,
|
| 1066 |
logits=reshaped_logits,
|