Commit ·
ba34bb6
1
Parent(s): 734888b
add comp back
Browse files — generic_ner.py (+43 −2)
generic_ner.py
CHANGED
|
@@ -486,6 +486,45 @@ def remove_included_entities(entities):
|
|
| 486 |
return final_entities
|
| 487 |
|
| 488 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
def remove_trailing_stopwords(entities):
|
| 490 |
"""
|
| 491 |
This function removes stopwords and punctuation from both the beginning and end of each entity's text
|
|
@@ -715,8 +754,10 @@ class MultitaskTokenClassificationPipeline(Pipeline):
|
|
| 715 |
# pprint(entities)
|
| 716 |
|
| 717 |
all_entities = []
|
|
|
|
| 718 |
for key in entities:
|
| 719 |
-
|
|
|
|
| 720 |
all_entities.extend(entities[key])
|
| 721 |
|
| 722 |
if DEBUG:
|
|
@@ -725,7 +766,7 @@ class MultitaskTokenClassificationPipeline(Pipeline):
|
|
| 725 |
all_entities = remove_included_entities(all_entities)
|
| 726 |
all_entities = remove_trailing_stopwords(all_entities)
|
| 727 |
all_entities = postprocess_entities(all_entities)
|
| 728 |
-
|
| 729 |
# print("After attach_comp_to_closest:")
|
| 730 |
# pprint(all_entities)
|
| 731 |
# print("\n")
|
|
|
|
| 486 |
return final_entities
|
| 487 |
|
| 488 |
|
| 489 |
+
def refine_entities_with_coarse(all_entities, coarse_entities):
    """Refine fine-grained entities against the coarse-grained entity list.

    Each entity is matched on its surface form plus the coarse portion of
    its type (the part before the first ``.``). When a coarse entity with
    the same (surface, type) key exists, the candidate with the higher
    ``confidence_ner`` is kept — ties go to the coarse entity. Entities
    without a coarse counterpart pass through unchanged.
    """
    # Index coarse entities once for O(1) lookups; a later coarse entity
    # with the same (surface, type) key overrides an earlier one.
    by_surface_type = {
        (ce["surface"], ce["type"]): ce for ce in coarse_entities
    }

    refined = []
    for ent in all_entities:
        # Truncate the fine-grained type to its coarse prefix for matching.
        match = by_surface_type.get(
            (ent["surface"], ent["type"].split(".")[0])
        )
        if match is None:
            # No coarse counterpart: keep the entity as-is.
            refined.append(ent)
        elif ent["confidence_ner"] > match["confidence_ner"]:
            refined.append(ent)
        else:
            # Coarse entity wins on ties or higher confidence.
            refined.append(match)
    return refined
|
| 526 |
+
|
| 527 |
+
|
| 528 |
def remove_trailing_stopwords(entities):
|
| 529 |
"""
|
| 530 |
This function removes stopwords and punctuation from both the beginning and end of each entity's text
|
|
|
|
| 754 |
# pprint(entities)
|
| 755 |
|
| 756 |
all_entities = []
|
| 757 |
+
coarse_entities = []
|
| 758 |
for key in entities:
|
| 759 |
+
if key in ["NE-COARSE-LIT"]:
|
| 760 |
+
coarse_entities = entities[key]
|
| 761 |
all_entities.extend(entities[key])
|
| 762 |
|
| 763 |
if DEBUG:
|
|
|
|
| 766 |
all_entities = remove_included_entities(all_entities)
|
| 767 |
all_entities = remove_trailing_stopwords(all_entities)
|
| 768 |
all_entities = postprocess_entities(all_entities)
|
| 769 |
+
all_entities = refine_entities_with_coarse(all_entities, coarse_entities)
|
| 770 |
# print("After attach_comp_to_closest:")
|
| 771 |
# pprint(all_entities)
|
| 772 |
# print("\n")
|