haukurpj committed on
Commit
aa55a31
·
1 Parent(s): 4ede3eb

Adding the IFD conversion as a method on the model

Browse files
Files changed (1) hide show
  1. modeling.py +20 -11
modeling.py CHANGED
@@ -426,6 +426,25 @@ class IceBertPosForTokenClassification(PreTrainedModel):
426
 
427
  return self.predict_labels(batch_input_ids, batch_attention_mask, batch_word_mask)
428
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  def predict_ifd_labels_from_text(
430
  self, sentences: List[List[str]], tokenizer, truncate: bool = False
431
  ) -> List[List[str]]:
@@ -444,17 +463,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
444
  """
445
  # Get model predictions in (category, [attributes]) format
446
  predictions = self.predict_labels_from_text(sentences, tokenizer, truncate)
447
-
448
- # Time the IFD conversion
449
- start_time = time.perf_counter()
450
- ifd_predictions = []
451
- for sentence_predictions in predictions:
452
- ifd_labels = convert_predictions_to_ifd(sentence_predictions) # (Ws,)
453
- ifd_predictions.append(ifd_labels)
454
- ifd_conversion_time = time.perf_counter() - start_time
455
- logger.debug(f"IFD conversion took {ifd_conversion_time:.4f} seconds")
456
-
457
- return ifd_predictions
458
 
459
  def _word_ids_to_word_mask(self, word_ids: List[int]) -> torch.Tensor:
460
  """
 
426
 
427
  return self.predict_labels(batch_input_ids, batch_attention_mask, batch_word_mask)
428
 
429
+ def convert_labels_to_ifd(self, predictions: List[List[Tuple[str, List[str]]]]) -> List[List[str]]:
430
+ """
431
+ Convert model predictions to IFD format labels.
432
+
433
+ Args:
434
+ predictions: List of sequences, each containing (category, [attributes]) per word
435
+
436
+ Returns:
437
+ List of IFD format labels per sentence
438
+ """
439
+ # Time the IFD conversion
440
+ start_time = time.perf_counter()
441
+ ifd_labels = []
442
+ for sentence_predictions in predictions:
443
+ ifd_labels.append(convert_predictions_to_ifd(sentence_predictions))
444
+ ifd_conversion_time = time.perf_counter() - start_time
445
+ logger.debug(f"IFD conversion took {ifd_conversion_time:.4f} seconds")
446
+ return ifd_labels
447
+
448
  def predict_ifd_labels_from_text(
449
  self, sentences: List[List[str]], tokenizer, truncate: bool = False
450
  ) -> List[List[str]]:
 
463
  """
464
  # Get model predictions in (category, [attributes]) format
465
  predictions = self.predict_labels_from_text(sentences, tokenizer, truncate)
466
+ return self.convert_labels_to_ifd(predictions)
 
 
 
 
 
 
 
 
 
 
467
 
468
  def _word_ids_to_word_mask(self, word_ids: List[int]) -> torch.Tensor:
469
  """