davda54 committed on
Commit
b05a2b2
·
verified ·
1 Parent(s): 88b5507

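Update modeling_gptbert.py: remove the `if not return_dict:` tuple-return branches from the five GptBertFor* task classes (CausalLM, SequenceClassification, TokenClassification, QuestionAnswering, MultipleChoice), so each now always returns its ModelOutput object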

Browse files
Files changed (1) hide show
  1. modeling_gptbert.py +0 -37
modeling_gptbert.py CHANGED
@@ -819,13 +819,6 @@ class GptBertForCausalLM(GptBertModel):
819
  subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
820
  causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
821
 
822
- if not return_dict:
823
- output = (
824
- subword_prediction,
825
- *([contextualized_embeddings] if output_hidden_states else [])
826
- )
827
- return ((causal_lm_loss,) + output) if causal_lm_loss is not None else output
828
-
829
  return CausalLMOutput(
830
  loss=causal_lm_loss,
831
  logits=subword_prediction,
@@ -932,13 +925,6 @@ class GptBertForSequenceClassification(GptBertModel):
932
  loss_fct = nn.BCEWithLogitsLoss()
933
  loss = loss_fct(logits, labels)
934
 
935
- if not return_dict:
936
- output = (
937
- logits,
938
- *([contextualized_embeddings] if output_hidden_states else [])
939
- )
940
- return ((loss,) + output) if loss is not None else output
941
-
942
  return SequenceClassifierOutput(
943
  loss=loss,
944
  logits=logits,
@@ -976,14 +962,6 @@ class GptBertForTokenClassification(GptBertModel):
976
  loss_fct = nn.CrossEntropyLoss()
977
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
978
 
979
- if not return_dict:
980
- output = (
981
- logits,
982
- *([contextualized_embeddings] if output_hidden_states else []),
983
- *([attention_probs] if output_attentions else [])
984
- )
985
- return ((loss,) + output) if loss is not None else output
986
-
987
  return TokenClassifierOutput(
988
  loss=loss,
989
  logits=logits,
@@ -1040,14 +1018,6 @@ class GptBertForQuestionAnswering(GptBertModel):
1040
  end_loss = loss_fct(end_logits, end_positions)
1041
  total_loss = (start_loss + end_loss) / 2
1042
 
1043
- if not return_dict:
1044
- output = (
1045
- start_logits,
1046
- end_logits,
1047
- *([contextualized_embeddings] if output_hidden_states else [])
1048
- )
1049
- return ((total_loss,) + output) if total_loss is not None else output
1050
-
1051
  return QuestionAnsweringModelOutput(
1052
  loss=total_loss,
1053
  start_logits=start_logits,
@@ -1091,13 +1061,6 @@ class GptBertForMultipleChoice(GptBertModel):
1091
  loss_fct = nn.CrossEntropyLoss()
1092
  loss = loss_fct(reshaped_logits, labels)
1093
 
1094
- if not return_dict:
1095
- output = (
1096
- reshaped_logits,
1097
- *([contextualized_embeddings] if output_hidden_states else [])
1098
- )
1099
- return ((loss,) + output) if loss is not None else output
1100
-
1101
  return MultipleChoiceModelOutput(
1102
  loss=loss,
1103
  logits=reshaped_logits,
 
819
  subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
820
  causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
821
 
 
 
 
 
 
 
 
822
  return CausalLMOutput(
823
  loss=causal_lm_loss,
824
  logits=subword_prediction,
 
925
  loss_fct = nn.BCEWithLogitsLoss()
926
  loss = loss_fct(logits, labels)
927
 
 
 
 
 
 
 
 
928
  return SequenceClassifierOutput(
929
  loss=loss,
930
  logits=logits,
 
962
  loss_fct = nn.CrossEntropyLoss()
963
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
964
 
 
 
 
 
 
 
 
 
965
  return TokenClassifierOutput(
966
  loss=loss,
967
  logits=logits,
 
1018
  end_loss = loss_fct(end_logits, end_positions)
1019
  total_loss = (start_loss + end_loss) / 2
1020
 
 
 
 
 
 
 
 
 
1021
  return QuestionAnsweringModelOutput(
1022
  loss=total_loss,
1023
  start_logits=start_logits,
 
1061
  loss_fct = nn.CrossEntropyLoss()
1062
  loss = loss_fct(reshaped_logits, labels)
1063
 
 
 
 
 
 
 
 
1064
  return MultipleChoiceModelOutput(
1065
  loss=loss,
1066
  logits=reshaped_logits,