emanuelaboros commited on
Commit
ba34bb6
·
1 Parent(s): 734888b

add comp back

Browse files
Files changed (1) hide show
  1. generic_ner.py +43 -2
generic_ner.py CHANGED
@@ -486,6 +486,45 @@ def remove_included_entities(entities):
486
  return final_entities
487
 
488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
  def remove_trailing_stopwords(entities):
490
  """
491
  This function removes stopwords and punctuation from both the beginning and end of each entity's text
@@ -715,8 +754,10 @@ class MultitaskTokenClassificationPipeline(Pipeline):
715
  # pprint(entities)
716
 
717
  all_entities = []
 
718
  for key in entities:
719
- # if key not in ["NE-COARSE-LIT"]:
 
720
  all_entities.extend(entities[key])
721
 
722
  if DEBUG:
@@ -725,7 +766,7 @@ class MultitaskTokenClassificationPipeline(Pipeline):
725
  all_entities = remove_included_entities(all_entities)
726
  all_entities = remove_trailing_stopwords(all_entities)
727
  all_entities = postprocess_entities(all_entities)
728
-
729
  # print("After attach_comp_to_closest:")
730
  # pprint(all_entities)
731
  # print("\n")
 
486
  return final_entities
487
 
488
 
489
+ def refine_entities_with_coarse(all_entities, coarse_entities):
490
+ """
491
+ Looks through all entities and refines them based on the coarse entities.
492
+ If a surface match is found in the coarse entities and the types match,
493
+ the entity with the higher confidence_ner is kept.
494
+ """
495
+ refined_entities = []
496
+
497
+ # Create a dictionary for coarse entities based on surface and type for quick lookup
498
+ coarse_lookup = {}
499
+ for coarse_entity in coarse_entities:
500
+ key = (coarse_entity["surface"], coarse_entity["type"])
501
+ coarse_lookup[key] = coarse_entity
502
+
503
+ # Iterate through all entities and compare with the coarse entities
504
+ for entity in all_entities:
505
+ key = (
506
+ entity["surface"],
507
+ entity["type"].split(".")[0],
508
+ ) # Use the coarse type for comparison
509
+
510
+ if key in coarse_lookup:
511
+ # If the types match, compare confidence_ner and keep the one with the higher confidence
512
+ coarse_entity = coarse_lookup[key]
513
+ if entity["confidence_ner"] > coarse_entity["confidence_ner"]:
514
+ refined_entities.append(
515
+ entity
516
+ ) # Keep the current entity with higher confidence
517
+ else:
518
+ refined_entities.append(
519
+ coarse_entity
520
+ ) # Keep the coarse entity with higher confidence
521
+ else:
522
+ # If no match in coarse, just add the entity to refined entities
523
+ refined_entities.append(entity)
524
+
525
+ return refined_entities
526
+
527
+
528
  def remove_trailing_stopwords(entities):
529
  """
530
  This function removes stopwords and punctuation from both the beginning and end of each entity's text
 
754
  # pprint(entities)
755
 
756
  all_entities = []
757
+ coarse_entities = []
758
  for key in entities:
759
+ if key in ["NE-COARSE-LIT"]:
760
+ coarse_entities = entities[key]
761
  all_entities.extend(entities[key])
762
 
763
  if DEBUG:
 
766
  all_entities = remove_included_entities(all_entities)
767
  all_entities = remove_trailing_stopwords(all_entities)
768
  all_entities = postprocess_entities(all_entities)
769
+ all_entities = refine_entities_with_coarse(all_entities, coarse_entities)
770
  # print("After attach_comp_to_closest:")
771
  # pprint(all_entities)
772
  # print("\n")