Commit
·
d7edcb3
1
Parent(s):
3776ec2
dd
Browse files- generic_ner.py +2 -58
generic_ner.py
CHANGED
|
@@ -712,19 +712,8 @@ class MultitaskTokenClassificationPipeline(Pipeline):
|
|
| 712 |
|
| 713 |
def _forward(self, inputs):
|
| 714 |
inputs, text_sentences, text = inputs
|
| 715 |
-
input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
|
| 716 |
-
self.model.device
|
| 717 |
-
)
|
| 718 |
-
attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
|
| 719 |
-
self.model.device
|
| 720 |
-
)
|
| 721 |
-
# print(f"Let's check the model: {self.model}")
|
| 722 |
-
# check get floret model
|
| 723 |
-
|
| 724 |
-
with torch.no_grad():
|
| 725 |
-
outputs = self.model(input_ids, attention_mask)
|
| 726 |
|
| 727 |
-
return
|
| 728 |
|
| 729 |
def is_within(self, entity1, entity2):
|
| 730 |
"""Check if entity1 is fully within the bounds of entity2."""
|
|
@@ -733,58 +722,13 @@ class MultitaskTokenClassificationPipeline(Pipeline):
|
|
| 733 |
and entity1["rOffset"] <= entity2["rOffset"]
|
| 734 |
)
|
| 735 |
|
| 736 |
-
def postprocess(self,
|
| 737 |
"""
|
| 738 |
Postprocess the outputs of the model
|
| 739 |
:param outputs:
|
| 740 |
:param kwargs:
|
| 741 |
:return:
|
| 742 |
"""
|
| 743 |
-
tokens_result, text_sentence, text = outputs
|
| 744 |
-
|
| 745 |
-
predictions = {}
|
| 746 |
-
confidence_scores = {}
|
| 747 |
-
for task, logits in tokens_result.logits.items():
|
| 748 |
-
predictions[task] = torch.argmax(logits, dim=-1).tolist()[0]
|
| 749 |
-
confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0]
|
| 750 |
-
|
| 751 |
-
entities = {}
|
| 752 |
-
for task in predictions.keys():
|
| 753 |
-
words_list, preds_list, confidence_list = realign(
|
| 754 |
-
text_sentence,
|
| 755 |
-
predictions[task],
|
| 756 |
-
confidence_scores[task],
|
| 757 |
-
self.tokenizer,
|
| 758 |
-
self.id2label[task],
|
| 759 |
-
)
|
| 760 |
-
|
| 761 |
-
entities[task] = get_entities(words_list, preds_list, confidence_list, text)
|
| 762 |
-
|
| 763 |
-
all_entities = []
|
| 764 |
-
coarse_entities = []
|
| 765 |
-
for key in entities:
|
| 766 |
-
if key in ["NE-COARSE-LIT"]:
|
| 767 |
-
coarse_entities = entities[key]
|
| 768 |
-
all_entities.extend(entities[key])
|
| 769 |
-
|
| 770 |
-
if DEBUG:
|
| 771 |
-
print(all_entities)
|
| 772 |
-
# print("After remove_included_entities:")
|
| 773 |
-
all_entities = remove_included_entities(all_entities)
|
| 774 |
-
if DEBUG:
|
| 775 |
-
print("After remove_included_entities:", all_entities)
|
| 776 |
-
all_entities = remove_trailing_stopwords(all_entities)
|
| 777 |
-
if DEBUG:
|
| 778 |
-
print("After remove_trailing_stopwords:", all_entities)
|
| 779 |
-
all_entities = postprocess_entities(all_entities)
|
| 780 |
-
if DEBUG:
|
| 781 |
-
print("After postprocess_entities:", all_entities)
|
| 782 |
-
all_entities = refine_entities_with_coarse(all_entities, coarse_entities)
|
| 783 |
-
if DEBUG:
|
| 784 |
-
print("After refine_entities_with_coarse:", all_entities)
|
| 785 |
-
# print("After attach_comp_to_closest:")
|
| 786 |
-
# pprint(all_entities)
|
| 787 |
-
# print("\n")
|
| 788 |
|
| 789 |
print(f"Let's check the model: {self.model.get_floret_model()}")
|
| 790 |
predictions, probabilities = self.model.get_floret_model().predict([text], k=1)
|
|
|
|
| 712 |
|
| 713 |
def _forward(self, inputs):
|
| 714 |
inputs, text_sentences, text = inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 715 |
|
| 716 |
+
return text
|
| 717 |
|
| 718 |
def is_within(self, entity1, entity2):
|
| 719 |
"""Check if entity1 is fully within the bounds of entity2."""
|
|
|
|
| 722 |
and entity1["rOffset"] <= entity2["rOffset"]
|
| 723 |
)
|
| 724 |
|
| 725 |
+
def postprocess(self, text, **kwargs):
|
| 726 |
"""
|
| 727 |
Postprocess the outputs of the model
|
| 728 |
:param outputs:
|
| 729 |
:param kwargs:
|
| 730 |
:return:
|
| 731 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 732 |
|
| 733 |
print(f"Let's check the model: {self.model.get_floret_model()}")
|
| 734 |
predictions, probabilities = self.model.get_floret_model().predict([text], k=1)
|