emanuelaboros committed on
Commit
6deb831
·
1 Parent(s): ca874c3

move to pregenerated tokens - some bug with word ids -- move to the initial ones

Browse files
Files changed (1) hide show
  1. generic_ner.py +9 -6
generic_ner.py CHANGED
@@ -262,10 +262,11 @@ def get_entities(tokens, tags, confidences, text):
262
 
263
 
264
  def realign(
265
- tokens, out_label_preds, softmax_scores, tokenizer, reverted_label_map
266
  ):
267
  preds_list, words_list, confidence_list = [], [], []
268
- word_ids = tokenizer(tokens, is_split_into_words=True).word_ids()
 
269
  for idx, word in enumerate(tokens):
270
  beginning_index = word_ids.index(idx)
271
  try:
@@ -701,11 +702,12 @@ class MultitaskTokenClassificationPipeline(Pipeline):
701
  truncation=True,
702
  max_length=512,
703
  )
 
704
 
705
- return tokenized_inputs, text, tokens
706
 
707
  def _forward(self, inputs):
708
- inputs, text, tokens = inputs
709
 
710
  input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
711
  self.model.device
@@ -715,7 +717,7 @@ class MultitaskTokenClassificationPipeline(Pipeline):
715
  )
716
  with torch.no_grad():
717
  outputs = self.model(input_ids, attention_mask)
718
- return outputs, text, tokens
719
 
720
  def is_within(self, entity1, entity2):
721
  """Check if entity1 is fully within the bounds of entity2."""
@@ -731,7 +733,7 @@ class MultitaskTokenClassificationPipeline(Pipeline):
731
  :param kwargs:
732
  :return:
733
  """
734
- tokens_result, text, tokens = outputs
735
 
736
  predictions = {}
737
  confidence_scores = {}
@@ -742,6 +744,7 @@ class MultitaskTokenClassificationPipeline(Pipeline):
742
  entities = {}
743
  for task in predictions.keys():
744
  words_list, preds_list, confidence_list = realign(
 
745
  tokens,
746
  predictions[task],
747
  confidence_scores[task],
 
262
 
263
 
264
  def realign(
265
+ word_ids, tokens, out_label_preds, softmax_scores, tokenizer, reverted_label_map
266
  ):
267
  preds_list, words_list, confidence_list = [], [], []
268
+ # word_ids = tokenizer(tokens, is_split_into_words=True).word_ids()
269
+
270
  for idx, word in enumerate(tokens):
271
  beginning_index = word_ids.index(idx)
272
  try:
 
702
  truncation=True,
703
  max_length=512,
704
  )
705
+ word_ids = tokenized_inputs.word_ids()
706
 
707
+ return tokenized_inputs, word_ids, text, tokens
708
 
709
  def _forward(self, inputs):
710
+ inputs, word_ids, text, tokens = inputs
711
 
712
  input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
713
  self.model.device
 
717
  )
718
  with torch.no_grad():
719
  outputs = self.model(input_ids, attention_mask)
720
+ return outputs, word_ids, text, tokens
721
 
722
  def is_within(self, entity1, entity2):
723
  """Check if entity1 is fully within the bounds of entity2."""
 
733
  :param kwargs:
734
  :return:
735
  """
736
+ tokens_result, word_ids, text, tokens = outputs
737
 
738
  predictions = {}
739
  confidence_scores = {}
 
744
  entities = {}
745
  for task in predictions.keys():
746
  words_list, preds_list, confidence_list = realign(
747
+ word_ids,
748
  tokens,
749
  predictions[task],
750
  confidence_scores[task],