Upload model
Browse files
- config.json +1 -0
- modeling_relik.py +65 -57
- pytorch_model.bin +2 -2
config.json
CHANGED
@@ -1,4 +1,5 @@
 {
+  "_name_or_path": "models/hf_test/hf_test",
   "activation": "gelu",
   "add_entity_embedding": null,
   "additional_special_symbols": 101,
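The added `_name_or_path` entry is the bookkeeping field that `transformers` writes into a config to record where it was loaded from or saved to; here it points at the local export directory models/hf_test/hf_test. A minimal sketch of that behaviour, with a generic checkpoint standing in for the actual ReLiK model:

from transformers import AutoConfig

# `_name_or_path` records the model id or local path the config came from;
# "bert-base-uncased" is only a stand-in for the real checkpoint.
config = AutoConfig.from_pretrained("bert-base-uncased")
print(config._name_or_path)  # "bert-base-uncased"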
modeling_relik.py
CHANGED
@@ -32,6 +32,7 @@ class RelikReaderSample:
             self._d[key] = value
         else:
             super().__setattr__(key, value)
+            self._d[key] = value


 activation2functions = {
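The added line makes `RelikReaderSample` mirror every newly set attribute into its backing dict `_d`; previously only keys already present in `_d` were kept in sync. A self-contained sketch of the pattern (the rest of the real class is assumed):

class Sample:
    # Attributes are mirrored into `_d` so the sample can also be read
    # like a plain dict; details of RelikReaderSample are assumed.
    def __init__(self, **kwargs):
        super().__setattr__("_d", {})
        self._d.update(kwargs)

    def __getattr__(self, key):
        try:
            return self._d[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        if key in self._d:
            self._d[key] = value
        else:
            super().__setattr__(key, value)
            self._d[key] = value  # the new line: mirror fresh attributes too


sample = Sample(text="Rome is in Italy")
sample.spans = [(0, 4)]
print(sample._d["spans"])  # [(0, 4)] -- visible through the dict as well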
@@ -321,20 +322,40 @@ class RelikReaderSpanModel(PreTrainedModel):
         # flattening end predictions
         # (flattening can happen only if the
         # end boundaries were not predicted using the gold labels)
-        if not self.training:
-            flattened_end_predictions = torch.…
-            …
+        if not self.training and ned_end_logits is not None:
+            flattened_end_predictions = torch.zeros_like(ned_start_predictions)
+
+            row_indices, start_positions = torch.where(ned_start_predictions > 0)
+            ned_end_predictions[
+                ned_end_predictions < start_positions
+            ] = start_positions[ned_end_predictions < start_positions]
+
+            end_spans_repeated = (row_indices + 1) * seq_len + ned_end_predictions
+            cummax_values, _ = end_spans_repeated.cummax(dim=0)
+
+            end_spans_repeated = end_spans_repeated > torch.cat(
+                (end_spans_repeated[:1], cummax_values[:-1])
+            )
+            end_spans_repeated[0] = True
+
+            ned_start_predictions[
+                row_indices[~end_spans_repeated],
+                start_positions[~end_spans_repeated],
+            ] = 0
+
+            row_indices, start_positions, ned_end_predictions = (
+                row_indices[end_spans_repeated],
+                start_positions[end_spans_repeated],
+                ned_end_predictions[end_spans_repeated],
+            )
+
+            flattened_end_predictions[row_indices, ned_end_predictions] = 1
+
+            total_start_predictions, total_end_predictions = (
+                ned_start_predictions.sum(),
+                flattened_end_predictions.sum(),
+            )
 
-            # check that the total number of start predictions
-            # is equal to the end predictions
-            total_start_predictions = sum(map(len, batch_start_predictions))
-            total_end_predictions = len(ned_end_predictions)
             assert (
                 total_start_predictions == 0
                 or total_start_predictions == total_end_predictions
@@ -342,23 +363,9 @@ class RelikReaderSpanModel(PreTrainedModel):
                 f"Total number of start predictions = {total_start_predictions}. "
                 f"Total number of end predictions = {total_end_predictions}"
             )
-
-            curr_end_pred_num = 0
-            for elem_idx, bsp in enumerate(batch_start_predictions):
-                for sp in bsp:
-                    ep = ned_end_predictions[curr_end_pred_num].item()
-                    if ep < sp:
-                        ep = sp
-
-                    # if we already set this span throw it (no overlap)
-                    if flattened_end_predictions[elem_idx, ep] == 1:
-                        ned_start_predictions[elem_idx, sp] = 0
-                    else:
-                        flattened_end_predictions[elem_idx, ep] = 1
-
-                    curr_end_pred_num += 1
-
             ned_end_predictions = flattened_end_predictions
+        else:
+            ned_end_predictions = torch.zeros_like(ned_start_predictions)
 
         start_position, end_position = (
             (start_labels, end_labels)
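The old per-sample Python loop that flattened end predictions and threw away overlapping spans is replaced by a vectorized equivalent: every predicted end is offset by (row + 1) * seq_len so spans from different batch rows can never collide, and a running cummax keeps a span only when its end advances past every span already seen. A toy reproduction of the trick (all values are illustrative):

import torch

seq_len = 10
row_indices = torch.tensor([0, 0, 0, 1])      # batch row of each start prediction
start_positions = torch.tensor([2, 3, 6, 1])  # predicted start tokens
end_predictions = torch.tensor([4, 4, 8, 3])  # one predicted end per start

# clamp ends so no span ends before it starts (mirrors the first block above)
end_predictions = torch.maximum(end_predictions, start_positions)

# offset per row so spans of different rows live in disjoint integer ranges
flat_ends = (row_indices + 1) * seq_len + end_predictions
cummax_values, _ = flat_ends.cummax(dim=0)

# keep a span only if it ends strictly after everything seen before it
keep = flat_ends > torch.cat((flat_ends[:1], cummax_values[:-1]))
keep[0] = True

print(keep)  # tensor([ True, False,  True,  True]): span (3, 4) overlaps (2, 4)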
@@ -461,7 +468,7 @@ class RelikReaderREModel(PreTrainedModel):
         self.transformer_model.resize_token_embeddings(
             self.transformer_model.config.vocab_size
             + config.additional_special_symbols
-            + config.additional_special_symbols_types
+            + config.additional_special_symbols_types,
         )
 
         # named entity detection layers
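Only a trailing comma changes here, but the surrounding call is why both special-symbol counts exist: the embedding matrix must grow to cover the entity symbols and the entity-type symbols. A sketch of the same idea on a generic backbone (model id, token names, and counts are illustrative):

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

additional_special_symbols = 101       # matches the config above
additional_special_symbols_types = 5   # illustrative count of type symbols

# token names are made up; ReLiK's actual symbols may differ
new_tokens = [f"[E-{i}]" for i in range(additional_special_symbols)]
new_tokens += [f"[T-{i}]" for i in range(additional_special_symbols_types)]
tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})

model.resize_token_embeddings(
    model.config.vocab_size
    + additional_special_symbols
    + additional_special_symbols_types,
)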
@@ -478,17 +485,21 @@ class RelikReaderREModel(PreTrainedModel):
         )
 
         if self.config.entity_type_loss and self.config.add_entity_embedding:
-            input_hidden_ents = 3 * self.…
+            input_hidden_ents = 3 * self.config.linears_hidden_size
         else:
-            input_hidden_ents = 2 * self.…
-
-        self.…
-            config.activation,
-        )
-
-        …
-        )
-        self.re_relation_projector = self._get_projection_layer(config.activation)
+            input_hidden_ents = 2 * self.config.linears_hidden_size
+
+        self.re_projector = self._get_projection_layer(
+            config.activation,
+            input_hidden=2 * self.transformer_model.config.hidden_size,
+            hidden=input_hidden_ents,
+            last_hidden=2 * self.config.linears_hidden_size,
+        )
+
+        self.re_relation_projector = self._get_projection_layer(
+            config.activation,
+            input_hidden=self.transformer_model.config.hidden_size,
+        )
 
         if self.config.entity_type_loss or self.relation_disambiguation_loss:
             self.re_entities_projector = self._get_projection_layer(
@@ -516,6 +527,7 @@ class RelikReaderREModel(PreTrainedModel):
         self,
         activation: str,
         last_hidden: Optional[int] = None,
+        hidden: Optional[int] = None,
         input_hidden=None,
         layer_norm: bool = True,
     ) -> torch.nn.Sequential:
@@ -528,12 +540,12 @@ class RelikReaderREModel(PreTrainedModel):
                     if input_hidden is None
                     else input_hidden
                 ),
-                self.config.linears_hidden_size,
+                self.config.linears_hidden_size if hidden is None else hidden,
             ),
             activation2functions[activation],
             torch.nn.Dropout(0.1),
             torch.nn.Linear(
-                self.config.linears_hidden_size,
+                self.config.linears_hidden_size if hidden is None else hidden,
                 self.config.linears_hidden_size if last_hidden is None else last_hidden,
             ),
         ]
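`_get_projection_layer` gains a `hidden` argument so the projection block's intermediate width can differ from `linears_hidden_size` (the new `re_projector` uses this to emit a double-width feature). Expanded by hand, the block looks roughly like this, with GELU standing in for `activation2functions[activation]` and layer-norm handling omitted; sizes are illustrative:

import torch

input_hidden, hidden, last_hidden = 1536, 768, 512

projector = torch.nn.Sequential(
    torch.nn.Linear(input_hidden, hidden),
    torch.nn.GELU(),
    torch.nn.Dropout(0.1),
    torch.nn.Linear(hidden, last_hidden),
)

x = torch.randn(2, 4, input_hidden)
print(projector(x).shape)  # torch.Size([2, 4, 512])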
@@ -635,8 +647,13 @@ class RelikReaderREModel(PreTrainedModel):
         model_entity_features,
         special_symbols_features,
     ) -> torch.Tensor:
-        …
+        model_subject_object_features = self.re_projector(model_entity_features)
+        model_subject_features = model_subject_object_features[
+            :, :, : model_subject_object_features.shape[-1] // 2
+        ]
+        model_object_features = model_subject_object_features[
+            :, :, model_subject_object_features.shape[-1] // 2 :
+        ]
         special_symbols_start_representation = self.re_relation_projector(
             special_symbols_features
         )
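Where the old code projected subject and object representations separately, the single `re_projector` now produces one feature of twice the hidden width, and its two halves become the subject and object views. A toy illustration (shapes are made up):

import torch

batch, seq, hidden = 2, 6, 8
re_projector = torch.nn.Linear(hidden, 2 * hidden)  # stand-in for the real block

entity_features = torch.randn(batch, seq, hidden)
subject_object = re_projector(entity_features)

half = subject_object.shape[-1] // 2
subject_features = subject_object[:, :, :half]
object_features = subject_object[:, :, half:]
print(subject_features.shape, object_features.shape)  # both (2, 6, 8)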
@@ -720,13 +737,17 @@ class RelikReaderREModel(PreTrainedModel):
         end_labels: Optional[torch.Tensor] = None,
         disambiguation_labels: Optional[torch.Tensor] = None,
         relation_labels: Optional[torch.Tensor] = None,
-        relation_threshold: float = …
+        relation_threshold: float = None,
         is_validation: bool = False,
         is_prediction: bool = False,
         use_predefined_spans: bool = False,
         *args,
         **kwargs,
     ) -> Dict[str, Any]:
+        relation_threshold = (
+            self.config.threshold if relation_threshold is None else relation_threshold
+        )
+
         batch_size = input_ids.shape[0]
 
         model_features = self._get_model_features(
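`relation_threshold` is now optional: when the caller omits it, the value stored in the model config is used instead. The pattern in isolation (the default of 0.5 below is assumed, not taken from the diff):

def forward(relation_threshold=None, config_threshold=0.5):
    relation_threshold = (
        config_threshold if relation_threshold is None else relation_threshold
    )
    return relation_threshold

print(forward())     # 0.5 -- falls back to the config value
print(forward(0.9))  # 0.9 -- an explicit argument wins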
@@ -898,19 +919,7 @@ class RelikReaderREModel(PreTrainedModel):
             re_probabilities = torch.softmax(re_logits, dim=-1)
             # we set a threshold instead of argmax because it needs to be tweaked
             re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
-            # re_predictions = re_probabilities.argmax(dim=-1)
             re_probabilities = re_probabilities[:, :, :, :, 1]
-            # re_logits, re_probabilities, re_predictions = (
-            #     torch.zeros(
-            #         [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
-            #     ).to(input_ids.device),
-            #     torch.zeros(
-            #         [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
-            #     ).to(input_ids.device),
-            #     torch.zeros(
-            #         [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
-            #     ).to(input_ids.device),
-            # )
 
         else:
             (
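As the comment says, relation decoding thresholds the positive-class probability rather than taking an argmax, which lets the operating point be tuned after training; the commented-out zero-tensor scaffolding is simply deleted. A toy version of the decoding step (shapes are illustrative):

import torch

re_logits = torch.randn(2, 3, 3, 4, 2)  # (batch, subj, obj, relation, {neg, pos})
re_probabilities = torch.softmax(re_logits, dim=-1)

relation_threshold = 0.5
re_predictions = re_probabilities[..., 1] > relation_threshold

# lowering the threshold trades precision for recall
print(re_predictions.sum(), (re_probabilities[..., 1] > 0.3).sum())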
@@ -981,10 +990,9 @@ class RelikReaderREModel(PreTrainedModel):
             ) / 4
             output_dict["ned_type_loss"] = ned_type_loss
         else:
-            output_dict["loss"] = ((1 / …
-                (…
+            output_dict["loss"] = ((1 / 20) * (ned_start_loss + ned_end_loss)) + (
+                (9 / 10) * relation_loss
             )
-
             output_dict["ned_start_loss"] = ned_start_loss
             output_dict["ned_end_loss"] = ned_end_loss
             output_dict["re_loss"] = relation_loss
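In the branch without the entity-type loss, the total loss is now a fixed-weight combination: 1/20 on each boundary loss and 9/10 on the relation loss, so the weights sum to 1. With toy scalars:

ned_start_loss, ned_end_loss, relation_loss = 0.8, 0.6, 0.4

loss = ((1 / 20) * (ned_start_loss + ned_end_loss)) + ((9 / 10) * relation_loss)
print(loss)  # ~0.43 = 0.05*0.8 + 0.05*0.6 + 0.9*0.4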
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
-size …
+oid sha256:91ead9afa1a4b1d95a4d8b3997606b937616a928be232b9f13e01aa6cd766473
+size 747280506
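pytorch_model.bin is stored as a Git LFS pointer, so only the oid (the SHA-256 of the blob) and size change when new weights are uploaded. A sketch for checking a downloaded file against the pointer (the local path is illustrative):

import hashlib

path = "pytorch_model.bin"
expected_oid = "91ead9afa1a4b1d95a4d8b3997606b937616a928be232b9f13e01aa6cd766473"
expected_size = 747280506

h = hashlib.sha256()
size = 0
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
        size += len(chunk)

assert size == expected_size and h.hexdigest() == expected_oid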