Update modeling_deberta.py
modeling_deberta.py  CHANGED  +36 -55
@@ -1058,8 +1058,6 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
         )
         encoded_layers = list(encoder_outputs[1])

-        # print(self.z_steps)
-
         if self.z_steps > 0:
             hidden_states = encoded_layers[-2]
             layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
@@ -1100,8 +1098,6 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
         self.deberta = DebertaV2Model(config)
         self.cls = DebertaV2OnlyMLMHead(config)

-        self.verbose = False
-
         # Initialize weights and apply final processing
         self.post_init()

@@ -1132,19 +1128,6 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if self.verbose:
-            for i in input_ids[0, :].tolist():
-                print(i, end=", ")
-            print()
-            if attention_mask is not None:
-                for i in attention_mask[0, :].tolist():
-                    print(i, end=", ")
-                print()
-            if position_ids is not None:
-                for i in position_ids[0, :].tolist():
-                    print(i, end=", ")
-                print()
-
         outputs = self.deberta(
             input_ids,
             attention_mask=attention_mask,
@@ -1183,6 +1166,7 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
         super().__init__(config)
         config.is_decoder = True
         self.mask_token_id = config.mask_token_id
+        self.cls_token_id = config.cls_token_id
         self.sep_token_id = config.sep_token_id
         self.n_masks = 3

@@ -1200,12 +1184,39 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
     ):
         position_ids = kwargs.get("position_ids", None)

-
-
-
-
-
-
+        assert input_ids[0, 0] != self.cls_token_id, "`add_special_tokens` should be set to `False`, but `[CLS]` token was detected"
+        assert input_ids[0, -1] != self.sep_token_id, "`add_special_tokens` should be set to `False`, but `[SEP]` token was detected"
+
+        batch_size, seq_length = input_ids.shape
+        input_ids = torch.cat(
+            [
+                torch.full((batch_size, 1), self.cls_token_id, device=input_ids.device),
+                input_ids,
+                torch.full((batch_size, self.n_masks), self.mask_token_id, device=input_ids.device),
+                torch.full((batch_size, 1), self.sep_token_id, device=input_ids.device)
+            ],
+            dim=-1
+        )
+
+        if attention_mask is not None:
+            attention_mask = torch.cat(
+                [
+                    torch.full((batch_size, 1), attention_mask[0, 0], device=attention_mask.device),
+                    attention_mask,
+                    torch.full((batch_size, self.n_masks + 1), attention_mask[0, -1], device=attention_mask.device),
+                ],
+                dim=-1
+            )
+
+        if position_ids is not None:
+            position_ids = torch.cat(
+                [
+                    torch.zeros(batch_size, 1, device=position_ids.device),
+                    position_ids + 1,
+                    torch.arange(0, self.n_masks + 1, device=position_ids.device).unsqueeze(0) + position_ids[:, -1:] + 1,
+                ],
+                dim=-1
+            )

         # Omit tokens covered by past_key_values
         if past_key_values is not None:
@@ -1228,7 +1239,7 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
             {
                 "position_ids": position_ids,
                 "past_key_values": past_key_values,
-                "use_cache":
+                "use_cache": None,
                 "attention_mask": attention_mask,
             }
         )
@@ -1255,36 +1266,6 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
         assert past_key_values is None, "past_key_values is not supported for now"
         assert use_cache is None, "use_cache is not supported for now"

-        assert input_ids[0, -1] != self.sep_token_id, "remove the last token if it is a sep token"
-
-        batch_size, seq_length = input_ids.shape
-        input_ids = torch.cat(
-            [
-                input_ids,
-                torch.full((batch_size, self.n_masks), self.mask_token_id, device=input_ids.device),
-                torch.full((batch_size, 1), self.sep_token_id, device=input_ids.device)
-            ],
-            dim=-1
-        )
-
-        if attention_mask is not None:
-            attention_mask = torch.cat(
-                [
-                    attention_mask,
-                    torch.full((batch_size, self.n_masks + 1), attention_mask[0, -1], device=attention_mask.device),
-                ],
-                dim=-1
-            )
-
-        if position_ids is not None:
-            position_ids = torch.cat(
-                [
-                    position_ids,
-                    torch.arange(0, self.n_masks + 1, device=position_ids.device).unsqueeze(0) + position_ids[:, -1:],
-                ],
-                dim=-1
-            )
-
         outputs = super().forward(
             input_ids,
             attention_mask=attention_mask,
@@ -1297,7 +1278,7 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
         )

         # shift the outputs and skip excess masks
-        logits = outputs.logits[:,
+        logits = outputs.logits[:, 2:-self.n_masks, :].contiguous()

         loss = None
         if labels is not None: