Commit 6deb831
1 parent: ca874c3
move to pregenerated tokens - some bug with word ids -- move to the initial ones
generic_ner.py CHANGED (+9 -6)
@@ -262,10 +262,11 @@ def get_entities(tokens, tags, confidences, text):
 
 
 def realign(
-    tokens, out_label_preds, softmax_scores, tokenizer, reverted_label_map
+    word_ids, tokens, out_label_preds, softmax_scores, tokenizer, reverted_label_map
 ):
     preds_list, words_list, confidence_list = [], [], []
-    word_ids = tokenizer(tokens, is_split_into_words=True).word_ids()
+    # word_ids = tokenizer(tokens, is_split_into_words=True).word_ids()
+
     for idx, word in enumerate(tokens):
         beginning_index = word_ids.index(idx)
         try:
@@ -701,11 +702,12 @@ class MultitaskTokenClassificationPipeline(Pipeline):
             truncation=True,
             max_length=512,
         )
+        word_ids = tokenized_inputs.word_ids()
 
-        return tokenized_inputs, text, tokens
+        return tokenized_inputs, word_ids, text, tokens
 
     def _forward(self, inputs):
-        inputs, text, tokens = inputs
+        inputs, word_ids, text, tokens = inputs
 
         input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
             self.model.device
@@ -715,7 +717,7 @@ class MultitaskTokenClassificationPipeline(Pipeline):
         )
         with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
-        return outputs, text, tokens
+        return outputs, word_ids, text, tokens
 
     def is_within(self, entity1, entity2):
         """Check if entity1 is fully within the bounds of entity2."""
@@ -731,7 +733,7 @@ class MultitaskTokenClassificationPipeline(Pipeline):
         :param kwargs:
         :return:
         """
-        tokens_result, text, tokens = outputs
+        tokens_result, word_ids, text, tokens = outputs
 
         predictions = {}
         confidence_scores = {}
@@ -742,6 +744,7 @@ class MultitaskTokenClassificationPipeline(Pipeline):
         entities = {}
         for task in predictions.keys():
             words_list, preds_list, confidence_list = realign(
+                word_ids,
                 tokens,
                 predictions[task],
                 confidence_scores[task],
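Why pregenerated word ids matter here: a minimal sketch, assuming a Hugging Face fast tokenizer. The model name, the sample words, and max_length=6 below are illustrative stand-ins, not values from this repository (the pipeline itself truncates at 512). preprocess() encodes the input once, with truncation, and that encoding becomes the model's input_ids; the old realign() re-tokenized the word list from scratch without those settings, so on long inputs word_ids.index(idx) could address subtoken positions the model never saw.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
tokens = ["Paris", "is", "the", "capital", "of", "France", "."]

# What preprocess() keeps now: word ids taken from the same truncated
# encoding that is turned into input_ids (max_length=6 exaggerates the
# 512-token cutoff so the mismatch shows on a short sentence).
encoded = tokenizer(tokens, is_split_into_words=True, truncation=True, max_length=6)
word_ids = encoded.word_ids()  # e.g. [None, 0, 1, 2, 3, None]

# What the old realign() recomputed: a fresh, untruncated encoding.
word_ids_again = tokenizer(tokens, is_split_into_words=True).word_ids()
# e.g. [None, 0, 1, 2, 3, 4, 5, 6, None] -- longer, so .index(idx) can
# point at positions that were cut off before reaching the model.

print(len(word_ids), len(word_ids_again))

Threading word_ids from preprocess() through _forward() to postprocess(), as the diff does, keeps realign() indexing the exact encoding the predictions were computed from.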