Commit
·
c549c79
1
Parent(s):
4efcbf3
update handler
Browse files — generic_ner.py (+48 −4)
generic_ner.py
CHANGED
|
@@ -2,8 +2,9 @@ from transformers import Pipeline
|
|
| 2 |
import numpy as np
|
| 3 |
import torch
|
| 4 |
import nltk
|
| 5 |
-
|
| 6 |
-
nltk.download(
|
|
|
|
| 7 |
from nltk.chunk import conlltags2tree
|
| 8 |
from nltk import pos_tag
|
| 9 |
from nltk.tree import Tree
|
|
@@ -107,9 +108,13 @@ def get_entities(tokens, tags, confidences, text):
|
|
| 107 |
entities.append(
|
| 108 |
{
|
| 109 |
"entity": original_label,
|
| 110 |
-
"score": round(
|
|
|
|
|
|
|
| 111 |
"index": (idx, idx + len(subtree)),
|
| 112 |
-
"word": text[
|
|
|
|
|
|
|
| 113 |
"start": entity_start_position,
|
| 114 |
"end": entity_end_position,
|
| 115 |
}
|
|
@@ -221,6 +226,44 @@ class MultitaskTokenClassificationPipeline(Pipeline):
|
|
| 221 |
outputs = self.model(input_ids, attention_mask)
|
| 222 |
return outputs, text_sentences, text
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
def postprocess(self, outputs, **kwargs):
|
| 226 |
"""
|
|
@@ -249,4 +292,5 @@ class MultitaskTokenClassificationPipeline(Pipeline):
|
|
| 249 |
|
| 250 |
entities[task] = get_entities(words_list, preds_list, confidence_list, text)
|
| 251 |
|
|
|
|
| 252 |
return entities
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import torch
|
| 4 |
import nltk
|
| 5 |
+
|
| 6 |
+
nltk.download("averaged_perceptron_tagger")
|
| 7 |
+
nltk.download("averaged_perceptron_tagger_eng")
|
| 8 |
from nltk.chunk import conlltags2tree
|
| 9 |
from nltk import pos_tag
|
| 10 |
from nltk.tree import Tree
|
|
|
|
| 108 |
entities.append(
|
| 109 |
{
|
| 110 |
"entity": original_label,
|
| 111 |
+
"score": round(
|
| 112 |
+
np.average(confidences[idx : idx + len(subtree)]) * 100, 2
|
| 113 |
+
),
|
| 114 |
"index": (idx, idx + len(subtree)),
|
| 115 |
+
"word": text[
|
| 116 |
+
entity_start_position:entity_end_position
|
| 117 |
+
], # original_string,
|
| 118 |
"start": entity_start_position,
|
| 119 |
"end": entity_end_position,
|
| 120 |
}
|
|
|
|
| 226 |
outputs = self.model(input_ids, attention_mask)
|
| 227 |
return outputs, text_sentences, text
|
| 228 |
|
| 229 |
+
def is_within(self, entity1, entity2):
    """Return True when *entity1*'s span lies entirely inside *entity2*'s.

    Both arguments are entity dicts carrying integer "start" and "end"
    character offsets; containment is inclusive at both ends.
    """
    inner_start, inner_end = entity1["start"], entity1["end"]
    outer_start, outer_end = entity2["start"], entity2["end"]
    return outer_start <= inner_start and inner_end <= outer_end
| 232 |
+
|
| 233 |
+
def postprocess_entities(self, ner_results):
    """Flatten per-task NER results and fold nested entities into their parents.

    Parameters
    ----------
    ner_results : dict
        Maps a task name to a list of entity dicts; each entity dict carries
        at least "entity" (a dotted label string) plus integer "start" and
        "end" character offsets.

    Returns
    -------
    list
        The outermost entities, ordered by (start ascending, end descending).
        Any entity whose span is contained in an earlier-kept entity is
        attached to that parent under a key derived from the last component
        of its dotted label, instead of appearing at the top level.

    Notes
    -----
    Parent entity dicts are mutated in place, so the dicts inside
    ``ner_results`` gain the nested-entity list fields as a side effect.
    """
    # Collect every entity from every task into a single working list.
    # (Iterate the values directly instead of indexing by key.)
    all_entities = []
    for task_entities in ner_results.values():
        all_entities.extend(task_entities)

    # Sort by start ascending and end descending so that any containing
    # entity is processed before the entities nested inside it.
    all_entities.sort(key=lambda ent: (ent["start"], -ent["end"]))

    final_entities = []
    for entity in all_entities:
        nested = False

        # Attach the entity to the first already-kept entity containing it.
        for parent_entity in final_entities:
            if self.is_within(entity, parent_entity):
                # The last part of the dotted label names the parent field,
                # e.g. "pers.name" -> "name".
                field_name = entity["entity"].split(".")[-1]
                parent_entity.setdefault(field_name, []).append(entity)
                nested = True
                break

        if not nested:
            # Not contained in any kept entity: a new outermost entity.
            final_entities.append(entity)

    return final_entities
| 267 |
|
| 268 |
def postprocess(self, outputs, **kwargs):
|
| 269 |
"""
|
|
|
|
| 292 |
|
| 293 |
entities[task] = get_entities(words_list, preds_list, confidence_list, text)
|
| 294 |
|
| 295 |
+
print(self.postprocess_entities(entities))
|
| 296 |
return entities
|