emanuelaboros committed on
Commit
1f946e5
·
1 Parent(s): e8bd3ea
Files changed (1) hide show
  1. generic_ner.py +23 -1
generic_ner.py CHANGED
@@ -8,7 +8,6 @@ nltk.download("averaged_perceptron_tagger_eng")
8
  from nltk.chunk import conlltags2tree
9
  from nltk import pos_tag
10
  from nltk.tree import Tree
11
- import string
12
  import torch.nn.functional as F
13
  import re, string
14
 
@@ -273,6 +272,26 @@ def remove_included_entities(entities):
273
  return final_entities
274
 
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  class MultitaskTokenClassificationPipeline(Pipeline):
277
 
278
  def _sanitize_parameters(self, **kwargs):
@@ -402,6 +421,9 @@ class MultitaskTokenClassificationPipeline(Pipeline):
402
  )
403
  pprint(all_entities)
404
 
 
 
 
405
  # Attach "comp.function" entities to the closest non-"comp.function" entity
406
  all_entities = attach_comp_to_closest(all_entities)
407
  print("After attach_comp_to_closest:")
 
8
  from nltk.chunk import conlltags2tree
9
  from nltk import pos_tag
10
  from nltk.tree import Tree
 
11
  import torch.nn.functional as F
12
  import re, string
13
 
 
272
  return final_entities
273
 
274
 
275
+ from stopwordsiso import stopwords
276
+
277
+ stop_words = stopwords(["en", "fr", "de"])
278
+
279
+
280
def remove_trailing_stopwords(entities, *, stopword_set=None):
    """Strip stopwords from the end of every entity's ``"text"`` field.

    Each entity's text is split on whitespace, and words are dropped from the
    end for as long as the last word (compared in lowercase) is a stopword.
    The remaining words are re-joined with single spaces and written back to
    ``entity["text"]``.

    Notes:
        * Entities are modified in place and the same ``entities`` object is
          returned.
        * Only the ``"text"`` key is touched; every other field of an entity
          is left as it is.
        * Because the text is rebuilt with ``" ".join(...)``, runs of
          whitespace inside the text are collapsed to a single space even
          when no stopword is removed.
        * An entity whose text consists solely of stopwords ends up with an
          empty string as its text; it is not removed from the list.

    Args:
        entities: Iterable of dict-like entities, each with a ``"text"`` key
            holding a string.
        stopword_set: Optional container of lowercase stopwords to test
            against. When ``None`` (the default), the module-level
            ``stop_words`` collection is used.

    Returns:
        The ``entities`` object that was passed in, with trimmed texts.

    Raises:
        KeyError: If an entity has no ``"text"`` key.
    """
    if stopword_set is None:
        # Default to the module-wide stopword collection so existing
        # callers keep their behaviour.
        stopword_set = stop_words

    for entity in entities:
        words = entity["text"].split()

        # Only trailing words are examined: stopwords in the middle of an
        # entity (e.g. "Bank of England") are kept.
        while words and words[-1].lower() in stopword_set:
            words.pop()

        entity["text"] = " ".join(words)

    return entities
293
+
294
+
295
  class MultitaskTokenClassificationPipeline(Pipeline):
296
 
297
  def _sanitize_parameters(self, **kwargs):
 
421
  )
422
  pprint(all_entities)
423
 
424
+ all_entities = remove_trailing_stopwords(all_entities)
425
+ print("After remove_trailing_stopwords:")
426
+ pprint(all_entities)
427
  # Attach "comp.function" entities to the closest non-"comp.function" entity
428
  all_entities = attach_comp_to_closest(all_entities)
429
  print("After attach_comp_to_closest:")