emanuelaboros committed on
Commit
d7edcb3
·
1 Parent(s): 3776ec2
Files changed (1) hide show
  1. generic_ner.py +2 -58
generic_ner.py CHANGED
@@ -712,19 +712,8 @@ class MultitaskTokenClassificationPipeline(Pipeline):
712
 
713
  def _forward(self, inputs):
714
  inputs, text_sentences, text = inputs
715
- input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
716
- self.model.device
717
- )
718
- attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
719
- self.model.device
720
- )
721
- # print(f"Let's check the model: {self.model}")
722
- # check get floret model
723
-
724
- with torch.no_grad():
725
- outputs = self.model(input_ids, attention_mask)
726
 
727
- return outputs, text_sentences, text
728
 
729
  def is_within(self, entity1, entity2):
730
  """Check if entity1 is fully within the bounds of entity2."""
@@ -733,58 +722,13 @@ class MultitaskTokenClassificationPipeline(Pipeline):
733
  and entity1["rOffset"] <= entity2["rOffset"]
734
  )
735
 
736
- def postprocess(self, outputs, **kwargs):
737
  """
738
  Postprocess the outputs of the model
739
  :param outputs:
740
  :param kwargs:
741
  :return:
742
  """
743
- tokens_result, text_sentence, text = outputs
744
-
745
- predictions = {}
746
- confidence_scores = {}
747
- for task, logits in tokens_result.logits.items():
748
- predictions[task] = torch.argmax(logits, dim=-1).tolist()[0]
749
- confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0]
750
-
751
- entities = {}
752
- for task in predictions.keys():
753
- words_list, preds_list, confidence_list = realign(
754
- text_sentence,
755
- predictions[task],
756
- confidence_scores[task],
757
- self.tokenizer,
758
- self.id2label[task],
759
- )
760
-
761
- entities[task] = get_entities(words_list, preds_list, confidence_list, text)
762
-
763
- all_entities = []
764
- coarse_entities = []
765
- for key in entities:
766
- if key in ["NE-COARSE-LIT"]:
767
- coarse_entities = entities[key]
768
- all_entities.extend(entities[key])
769
-
770
- if DEBUG:
771
- print(all_entities)
772
- # print("After remove_included_entities:")
773
- all_entities = remove_included_entities(all_entities)
774
- if DEBUG:
775
- print("After remove_included_entities:", all_entities)
776
- all_entities = remove_trailing_stopwords(all_entities)
777
- if DEBUG:
778
- print("After remove_trailing_stopwords:", all_entities)
779
- all_entities = postprocess_entities(all_entities)
780
- if DEBUG:
781
- print("After postprocess_entities:", all_entities)
782
- all_entities = refine_entities_with_coarse(all_entities, coarse_entities)
783
- if DEBUG:
784
- print("After refine_entities_with_coarse:", all_entities)
785
- # print("After attach_comp_to_closest:")
786
- # pprint(all_entities)
787
- # print("\n")
788
 
789
  print(f"Let's check the model: {self.model.get_floret_model()}")
790
  predictions, probabilities = self.model.get_floret_model().predict([text], k=1)
 
712
 
713
  def _forward(self, inputs):
714
  inputs, text_sentences, text = inputs
 
 
 
 
 
 
 
 
 
 
 
715
 
716
+ return text
717
 
718
  def is_within(self, entity1, entity2):
719
  """Check if entity1 is fully within the bounds of entity2."""
 
722
  and entity1["rOffset"] <= entity2["rOffset"]
723
  )
724
 
725
+ def postprocess(self, text, **kwargs):
726
  """
727
  Postprocess the outputs of the model
728
  :param outputs:
729
  :param kwargs:
730
  :return:
731
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
732
 
733
  print(f"Let's check the model: {self.model.get_floret_model()}")
734
  predictions, probabilities = self.model.get_floret_model().predict([text], k=1)