Commit
·
1f946e5
1
Parent(s):
e8bd3ea
degbu
Browse files- generic_ner.py +23 -1
generic_ner.py
CHANGED
|
@@ -8,7 +8,6 @@ nltk.download("averaged_perceptron_tagger_eng")
|
|
| 8 |
from nltk.chunk import conlltags2tree
|
| 9 |
from nltk import pos_tag
|
| 10 |
from nltk.tree import Tree
|
| 11 |
-
import string
|
| 12 |
import torch.nn.functional as F
|
| 13 |
import re, string
|
| 14 |
|
|
@@ -273,6 +272,26 @@ def remove_included_entities(entities):
|
|
| 273 |
return final_entities
|
| 274 |
|
| 275 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
class MultitaskTokenClassificationPipeline(Pipeline):
|
| 277 |
|
| 278 |
def _sanitize_parameters(self, **kwargs):
|
|
@@ -402,6 +421,9 @@ class MultitaskTokenClassificationPipeline(Pipeline):
|
|
| 402 |
)
|
| 403 |
pprint(all_entities)
|
| 404 |
|
|
|
|
|
|
|
|
|
|
| 405 |
# Attach "comp.function" entities to the closest non-"comp.function" entity
|
| 406 |
all_entities = attach_comp_to_closest(all_entities)
|
| 407 |
print("After attach_comp_to_closest:")
|
|
|
|
| 8 |
from nltk.chunk import conlltags2tree
|
| 9 |
from nltk import pos_tag
|
| 10 |
from nltk.tree import Tree
|
|
|
|
| 11 |
import torch.nn.functional as F
|
| 12 |
import re, string
|
| 13 |
|
|
|
|
| 272 |
return final_entities
|
| 273 |
|
| 274 |
|
| 275 |
+
from stopwordsiso import stopwords
|
| 276 |
+
|
| 277 |
+
# Stopword collection requested for the language codes "en", "fr" and "de";
# remove_trailing_stopwords() tests lowercased tokens for membership in it.
# NOTE(review): presumably a set of lowercase strings -- confirm against the
# stopwordsiso documentation.
stop_words = stopwords(["en", "fr", "de"])
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def remove_trailing_stopwords(entities):
    """Strip trailing stopwords from the ``"text"`` field of every entity.

    Each entity's text is split on whitespace, tokens are discarded from the
    right-hand end for as long as their lowercased form is found in the
    module-level ``stop_words`` collection, and the remaining tokens are
    re-joined with single spaces.

    Args:
        entities: Iterable of dict-like entities, each with a ``"text"``
            string. The entities are updated in place.

    Returns:
        The same ``entities`` object that was passed in.

    NOTE(review): a text made up only of stopwords ends up as ``""`` and the
    entity is still kept; no field other than ``"text"`` is modified here --
    confirm whether any position information stored with the entity needs to
    follow the shortened text.
    """
    for item in entities:
        tokens = item["text"].split()

        # Walk backwards to find how many leading tokens survive.
        keep = len(tokens)
        while keep > 0 and tokens[keep - 1].lower() in stop_words:
            keep -= 1

        # Re-joining also collapses any run of whitespace to a single space.
        item["text"] = " ".join(tokens[:keep])

    return entities
|
| 293 |
+
|
| 294 |
+
|
| 295 |
class MultitaskTokenClassificationPipeline(Pipeline):
|
| 296 |
|
| 297 |
def _sanitize_parameters(self, **kwargs):
|
|
|
|
| 421 |
)
|
| 422 |
pprint(all_entities)
|
| 423 |
|
| 424 |
+
all_entities = remove_trailing_stopwords(all_entities)
|
| 425 |
+
print("After remove_trailing_stopwords:")
|
| 426 |
+
pprint(all_entities)
|
| 427 |
# Attach "comp.function" entities to the closest non-"comp.function" entity
|
| 428 |
all_entities = attach_comp_to_closest(all_entities)
|
| 429 |
print("After attach_comp_to_closest:")
|