emanuelaboros commited on
Commit
af2b95e
·
1 Parent(s): 88de6a3
Files changed (1) hide show
  1. generic_ner.py +64 -46
generic_ner.py CHANGED
@@ -236,6 +236,21 @@ def attach_comp_to_closest(entities):
236
  return other_entities
237
 
238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  class MultitaskTokenClassificationPipeline(Pipeline):
240
 
241
  def _sanitize_parameters(self, **kwargs):
@@ -274,48 +289,48 @@ class MultitaskTokenClassificationPipeline(Pipeline):
274
  """Check if entity1 is fully within the bounds of entity2."""
275
  return entity1["start"] >= entity2["start"] and entity1["end"] <= entity2["end"]
276
 
277
- def postprocess_entities(self, all_entities):
278
-
279
- # Sort entities by start position, then by end position (to handle nested structures)
280
- all_entities.sort(key=lambda x: (x["start"], -x["end"]))
281
-
282
- # Create a new list for final processed entities
283
- final_entities = []
284
-
285
- # Process each entity and check for nesting
286
- for i, entity in enumerate(all_entities):
287
- nested = False
288
-
289
- # Compare the current entity with already processed entities
290
- for parent_entity in final_entities:
291
- if self.is_within(entity, parent_entity):
292
-
293
- # If the current entity is nested, add it as a field in the parent entity
294
- main_field_name = entity["entity"].split(".")[0]
295
- field_name = entity["entity"].split(".")[
296
- -1
297
- ] # Last part of the label as the field
298
- if main_field_name not in parent_entity["entity"]:
299
- # print(
300
- # "main_field_name:",
301
- # main_field_name,
302
- # "parent_entity:",
303
- # parent_entity["entity"],
304
- # )
305
- parent_entity[field_name] = entity["word"]
306
- nested = True
307
- break
308
- else:
309
- nested = True
310
- if "comp" in entity["entity"]:
311
- nested = True
312
- if not nested:
313
- # If not nested, add the entity as a new outermost entity
314
- entity["text"] = entity["word"]
315
- entity.pop("word")
316
- final_entities.append(entity)
317
-
318
- return final_entities
319
 
320
  def postprocess(self, outputs, **kwargs):
321
  """
@@ -355,10 +370,13 @@ class MultitaskTokenClassificationPipeline(Pipeline):
355
  if key not in ["NE-COARSE-LIT"]:
356
  all_entities.extend(entities[key])
357
 
358
- print("Skipping 1")
359
- # all_entities = self.postprocess_entities(all_entities, text_sentence)
360
- # print("After 1:")
361
- # pprint(all_entities)
 
 
 
362
  # Attach "comp.function" entities to the closest non-"comp.function" entity
363
  all_entities = attach_comp_to_closest(all_entities)
364
  print("After 2:")
 
236
  return other_entities
237
 
238
 
239
+ def postprocess_entities(entities):
240
+ # Step 1: Filter entities with the same text, keeping those with the more specific label (contains a dot)
241
+ filtered_entities = []
242
+ entity_map = {}
243
+
244
+ # Loop over the entities and prioritize the more specific ones
245
+ for entity in entities:
246
+ entity_text = entity["text"]
247
+ # If this entity text hasn't been processed, or we find a more specific label, update it
248
+ if entity_text not in entity_map or "." in entity["entity"]:
249
+ entity_map[entity_text] = entity
250
+
251
+ return entity_map
252
+
253
+
254
  class MultitaskTokenClassificationPipeline(Pipeline):
255
 
256
  def _sanitize_parameters(self, **kwargs):
 
289
  """Check if entity1 is fully within the bounds of entity2."""
290
  return entity1["start"] >= entity2["start"] and entity1["end"] <= entity2["end"]
291
 
292
+ # def postprocess_entities(self, all_entities):
293
+ #
294
+ # # Sort entities by start position, then by end position (to handle nested structures)
295
+ # all_entities.sort(key=lambda x: (x["start"], -x["end"]))
296
+ #
297
+ # # Create a new list for final processed entities
298
+ # final_entities = []
299
+ #
300
+ # # Process each entity and check for nesting
301
+ # for i, entity in enumerate(all_entities):
302
+ # nested = False
303
+ #
304
+ # # Compare the current entity with already processed entities
305
+ # for parent_entity in final_entities:
306
+ # if self.is_within(entity, parent_entity):
307
+ #
308
+ # # If the current entity is nested, add it as a field in the parent entity
309
+ # main_field_name = entity["entity"].split(".")[0]
310
+ # field_name = entity["entity"].split(".")[
311
+ # -1
312
+ # ] # Last part of the label as the field
313
+ # if main_field_name not in parent_entity["entity"]:
314
+ # # print(
315
+ # # "main_field_name:",
316
+ # # main_field_name,
317
+ # # "parent_entity:",
318
+ # # parent_entity["entity"],
319
+ # # )
320
+ # parent_entity[field_name] = entity["word"]
321
+ # nested = True
322
+ # break
323
+ # else:
324
+ # nested = True
325
+ # if "comp" in entity["entity"]:
326
+ # nested = True
327
+ # if not nested:
328
+ # # If not nested, add the entity as a new outermost entity
329
+ # entity["text"] = entity["word"]
330
+ # entity.pop("word")
331
+ # final_entities.append(entity)
332
+ #
333
+ # return final_entities
334
 
335
  def postprocess(self, outputs, **kwargs):
336
  """
 
370
  if key not in ["NE-COARSE-LIT"]:
371
  all_entities.extend(entities[key])
372
 
373
+ # print("Skipping 1")
374
+ all_entities = postprocess_entities(
375
+ all_entities,
376
+ )
377
+
378
+ print("After 1:")
379
+ pprint(all_entities)
380
  # Attach "comp.function" entities to the closest non-"comp.function" entity
381
  all_entities = attach_comp_to_closest(all_entities)
382
  print("After 2:")