SufurElite committed · Commit 0bdc170 · Parent(s): eff5003

added unzipped gz predictions, the checkpoint with values, and the tree output possibility in the model

Files changed:
- ELC_ParserBERT_10M_textonly_predictions.json +0 -0
- checkpoint/checkpoint.bin +3 -0
- modeling_ltgbert.py +141 -48
ELC_ParserBERT_10M_textonly_predictions.json ADDED

The diff for this file is too large to render. See the raw diff.
checkpoint/checkpoint.bin ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:797ed09acace61d3d397b9d425d28c272f2d6ce8bda8161f6a865fda491526f4
+size 427253662
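The three lines above are a Git LFS pointer, not the weights themselves; the actual 427253662-byte binary is fetched with git lfs pull. A minimal loading sketch, assuming the resolved file is an ordinary torch.save artifact (the commit does not document its internal layout):

import torch

# Assumes git lfs pull has replaced the pointer with the real binary;
# treating it as a plain torch.save state dict is an assumption.
state = torch.load("checkpoint/checkpoint.bin", map_location="cpu")
if isinstance(state, dict):
    print(list(state.keys())[:5])  # peek at the first few entries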
modeling_ltgbert.py CHANGED
@@ -374,7 +374,7 @@ class ParserNetwork(nn.Module):
 
         distance, height = self.parse(x, embeddings)
         att_mask, cibling, head, block = self.generate_mask(x, distance, height)
-        return att_mask, cibling, head, block
+        return att_mask, cibling, head, block, distance, height
 
 
 class Encoder(nn.Module):
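ParserNetwork.forward now returns six values instead of four: the distance and height tensors computed by self.parse are passed through alongside the attention masks instead of being discarded, which is what later lets the model expose its induced tree. Every caller must unpack the wider tuple; a sketch of the new calling convention, mirroring the call site updated below:

# Hypothetical caller; previously only the first four values were returned.
att_mask, cibling, head, block, distance, height = self.parser_network(
    input_ids.t(), static_embeddings
)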
@@ -790,10 +790,11 @@ LTG_BERT_INPUTS_DOCSTRING = r"""
     LTG_BERT_START_DOCSTRING,
 )
 class LtgBertModel(LtgBertPreTrainedModel):
-    def __init__(self, config, add_mlm_layer=False):
+    def __init__(self, config, add_mlm_layer=False, tree_output=False):
         super().__init__(config)
         self.config = config
-
+        self.tree_output=tree_output
+
         self.embedding = Embedding(config)
         self.parser_network = ParserNetwork(config, pad=config.pad_token_id)
         self.transformer = Encoder(config, activation_checkpointing=False)
@@ -823,7 +824,7 @@ class LtgBertModel(LtgBertPreTrainedModel):
         device = input_ids.device
 
         static_embeddings, relative_embedding = self.embedding(input_ids.t())
-        att_mask, cibling, head, block = self.parser_network(
+        att_mask, cibling, head, block, distance, height = self.parser_network(
             input_ids.t(), static_embeddings
         )
         contextualized_embeddings, attention_probs = self.transformer(
@@ -837,6 +838,9 @@
             contextualized_embeddings[i] - contextualized_embeddings[i - 1]
             for i in range(1, len(contextualized_embeddings))
         ]
+        if self.tree_output:
+            return last_layer, contextualized_embeddings, attention_probs, {'distance': distance, 'height': height,
+                                                                            'cibling': cibling, 'head': head, 'block': block}
         return last_layer, contextualized_embeddings, attention_probs
 
     @add_start_docstrings_to_model_forward(
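With tree_output set, get_contextualized_embeddings gains a fourth return element: a dict bundling the five parser tensors under the keys shown above. A consumer sketch, where model is a hypothetical LtgBertModel built with tree_output=True (tensor shapes are not documented in this commit):

last_layer, hidden_states, attention_probs, tree = model.get_contextualized_embeddings(
    input_ids, attention_mask
)
for key in ("distance", "height", "cibling", "head", "block"):
    print(key, tuple(tree[key].shape))  # shapes depend on the parser internals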
@@ -863,13 +867,28 @@ class LtgBertModel(LtgBertPreTrainedModel):
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
-
-        (
-            sequence_output,
-            contextualized_embeddings,
-            attention_probs,
-        ) = self.get_contextualized_embeddings(input_ids, attention_mask)
-
+        tree_values = {} if self.tree_output else None
+        if self.tree_output:
+            (
+                sequence_output,
+                contextualized_embeddings,
+                attention_probs,
+                tree_values
+            ) = self.get_contextualized_embeddings(input_ids, attention_mask)
+        else:
+            (
+                sequence_output,
+                contextualized_embeddings,
+                attention_probs
+            ) = self.get_contextualized_embeddings(input_ids, attention_mask)
+
+        if self.tree_output:
+            return (
+                sequence_output,
+                tree_values,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else []),
+            )
         if not return_dict:
             return (
                 sequence_output,
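When the flag is on, forward returns early with (sequence_output, tree_values, ...) and never reaches the return_dict branch, so the usual ModelOutput handling is bypassed; hidden states and attention maps are appended only when explicitly requested. A usage sketch, assuming config and the input tensors are prepared as usual and that output_hidden_states and output_attentions default to falsy values:

model = LtgBertModel(config, tree_output=True)
sequence_output, tree_values = model(input_ids, attention_mask)
distance, height = tree_values["distance"], tree_values["height"]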
@@ -891,8 +910,8 @@ class LtgBertModel(LtgBertPreTrainedModel):
 class LtgBertForMaskedLM(LtgBertModel):
     _keys_to_ignore_on_load_unexpected = ["head"]
 
-    def __init__(self, config):
-        super().__init__(config, add_mlm_layer=True)
+    def __init__(self, config, tree_output=False):
+        super().__init__(config, add_mlm_layer=True, tree_output=tree_output)
 
     def get_output_embeddings(self):
         return self.classifier.nonlinearity[-1].weight
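LtgBertForMaskedLM is the only task head whose constructor is widened in this commit; it simply forwards the flag to the base class:

# config is assumed available; the other heads are not given this kwarg here.
mlm = LtgBertForMaskedLM(config, tree_output=True)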
@@ -921,12 +940,20 @@ class LtgBertForMaskedLM(LtgBertModel):
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
-        (
-            sequence_output,
-            contextualized_embeddings,
-            attention_probs,
-        ) = self.get_contextualized_embeddings(input_ids, attention_mask)
-
+        tree_values = {} if self.tree_output else None
+        if self.tree_output:
+            (
+                sequence_output,
+                contextualized_embeddings,
+                attention_probs,
+                tree_values
+            ) = self.get_contextualized_embeddings(input_ids, attention_mask)
+        else:
+            (
+                sequence_output,
+                contextualized_embeddings,
+                attention_probs
+            ) = self.get_contextualized_embeddings(input_ids, attention_mask)
         subword_prediction = self.classifier(sequence_output)
 
         masked_lm_loss = None
@@ -934,7 +961,13 @@ class LtgBertForMaskedLM(LtgBertModel):
             masked_lm_loss = F.cross_entropy(
                 subword_prediction.flatten(0, 1), labels.flatten()
             )
-
+        if self.tree_output:
+            return (
+                sequence_output,
+                tree_values,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else []),
+            )
         if not return_dict:
             output = (
                 subword_prediction,
@@ -1027,12 +1060,20 @@ class LtgBertForSequenceClassification(LtgBertModel):
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
-        (
-            sequence_output,
-            contextualized_embeddings,
-            attention_probs,
-        ) = self.get_contextualized_embeddings(input_ids, attention_mask)
-
+        tree_values = {} if self.tree_output else None
+        if self.tree_output:
+            (
+                sequence_output,
+                contextualized_embeddings,
+                attention_probs,
+                tree_values
+            ) = self.get_contextualized_embeddings(input_ids, attention_mask)
+        else:
+            (
+                sequence_output,
+                contextualized_embeddings,
+                attention_probs
+            ) = self.get_contextualized_embeddings(input_ids, attention_mask)
         logits = self.head(sequence_output[:, 0, :])
 
         loss = None
@@ -1059,7 +1100,14 @@ class LtgBertForSequenceClassification(LtgBertModel):
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = nn.BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
-
+
+        if self.tree_output:
+            return (
+                sequence_output,
+                tree_values,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else []),
+            )
         if not return_dict:
             output = (
                 logits,
@@ -1110,19 +1158,34 @@ class LtgBertForTokenClassification(LtgBertModel):
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
-        (
-            sequence_output,
-            contextualized_embeddings,
-            attention_probs,
-        ) = self.get_contextualized_embeddings(input_ids, attention_mask)
-
+        tree_values = {} if self.tree_output else None
+        if self.tree_output:
+            (
+                sequence_output,
+                contextualized_embeddings,
+                attention_probs,
+                tree_values
+            ) = self.get_contextualized_embeddings(input_ids, attention_mask)
+        else:
+            (
+                sequence_output,
+                contextualized_embeddings,
+                attention_probs
+            ) = self.get_contextualized_embeddings(input_ids, attention_mask)
         logits = self.head(sequence_output)
 
         loss = None
         if labels is not None:
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-
+
+        if self.tree_output:
+            return (
+                sequence_output,
+                tree_values,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else []),
+            )
         if not return_dict:
             output = (
                 logits,
@@ -1174,12 +1237,20 @@ class LtgBertForQuestionAnswering(LtgBertModel):
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
-        (
-            sequence_output,
-            contextualized_embeddings,
-            attention_probs,
-        ) = self.get_contextualized_embeddings(input_ids, attention_mask)
-
+        tree_values = {} if self.tree_output else None
+        if self.tree_output:
+            (
+                sequence_output,
+                contextualized_embeddings,
+                attention_probs,
+                tree_values
+            ) = self.get_contextualized_embeddings(input_ids, attention_mask)
+        else:
+            (
+                sequence_output,
+                contextualized_embeddings,
+                attention_probs
+            ) = self.get_contextualized_embeddings(input_ids, attention_mask)
         logits = self.head(sequence_output)
 
         start_logits, end_logits = logits.split(1, dim=-1)
@@ -1203,7 +1274,14 @@ class LtgBertForQuestionAnswering(LtgBertModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-
+
+        if self.tree_output:
+            return (
+                sequence_output,
+                tree_values,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else []),
+            )
         if not return_dict:
             output = (
                 start_logits,
@@ -1264,12 +1342,20 @@ class LtgBertForMultipleChoice(LtgBertModel):
             if attention_mask is not None
             else None
         )
-        (
-            sequence_output,
-            contextualized_embeddings,
-            attention_probs,
-        ) = self.get_contextualized_embeddings(input_ids, attention_mask)
-
+        tree_values = {} if self.tree_output else None
+        if self.tree_output:
+            (
+                sequence_output,
+                contextualized_embeddings,
+                attention_probs,
+                tree_values
+            ) = self.get_contextualized_embeddings(input_ids, attention_mask)
+        else:
+            (
+                sequence_output,
+                contextualized_embeddings,
+                attention_probs
+            ) = self.get_contextualized_embeddings(input_ids, attention_mask)
         logits = self.head(sequence_output)
         reshaped_logits = logits.view(-1, num_choices)
 
@@ -1277,7 +1363,14 @@ class LtgBertForMultipleChoice(LtgBertModel):
         if labels is not None:
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
-
+
+        if self.tree_output:
+            return (
+                sequence_output,
+                tree_values,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else []),
+            )
         if not return_dict:
             output = (
                 reshaped_logits,
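The hunks above apply one repeated pattern to each task head (sequence classification, token classification, question answering, multiple choice): unpack the extra tree_values when the flag is on, then return early with them just before the return_dict handling. As an illustration of what the exposed tensors enable, assuming the StructFormer-style convention that a larger distance[i] marks a stronger constituent boundary between tokens i and i+1, a binary parse can be decoded by recursively splitting at the maximum distance; this decoder is not part of the commit:

def to_tree(tokens, distance):
    # distance[i] scores the boundary between tokens[i] and tokens[i+1];
    # split at the strongest boundary and recurse on both halves.
    if len(tokens) <= 1:
        return tokens[0] if tokens else None
    split = max(range(len(distance)), key=lambda i: distance[i])
    left = to_tree(tokens[: split + 1], distance[:split])
    right = to_tree(tokens[split + 1 :], distance[split + 1 :])
    return (left, right)

print(to_tree(["the", "cat", "sat"], [0.2, 0.9]))  # (('the', 'cat'), 'sat')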