haukurpj commited on
Commit
4ede3eb
·
1 Parent(s): 3d90a1f

adding some debug logging for timings

Browse files
Files changed (1) hide show
  1. modeling.py +16 -2
modeling.py CHANGED
@@ -2,6 +2,7 @@
2
  # This file is part of IceBERT POS model conversion.
3
 
4
  import logging
 
5
  from typing import List, Optional, Tuple
6
 
7
  import torch
@@ -379,9 +380,19 @@ class IceBertPosForTokenClassification(PreTrainedModel):
379
  Returns:
380
  List of sequences, each containing (category, [attributes]) per word
381
  """
 
 
382
  cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)
 
 
383
 
384
- return self._logits_to_labels(cat_logits, attr_logits, word_mask)
 
 
 
 
 
 
385
 
386
  def predict_labels_from_text(
387
  self, sentences: List[List[str]], tokenizer, truncate: bool = False
@@ -434,11 +445,14 @@ class IceBertPosForTokenClassification(PreTrainedModel):
434
  # Get model predictions in (category, [attributes]) format
435
  predictions = self.predict_labels_from_text(sentences, tokenizer, truncate)
436
 
437
- # Convert each sentence's predictions to IFD format
 
438
  ifd_predictions = []
439
  for sentence_predictions in predictions:
440
  ifd_labels = convert_predictions_to_ifd(sentence_predictions) # (Ws,)
441
  ifd_predictions.append(ifd_labels)
 
 
442
 
443
  return ifd_predictions
444
 
 
2
  # This file is part of IceBERT POS model conversion.
3
 
4
  import logging
5
+ import time
6
  from typing import List, Optional, Tuple
7
 
8
  import torch
 
380
  Returns:
381
  List of sequences, each containing (category, [attributes]) per word
382
  """
383
+ # Time the forward pass
384
+ start_time = time.perf_counter()
385
  cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)
386
+ forward_time = time.perf_counter() - start_time
387
+ logger.debug(f"Forward pass took {forward_time:.4f} seconds")
388
 
389
+ # Time the logits to labels conversion
390
+ start_time = time.perf_counter()
391
+ result = self._logits_to_labels(cat_logits, attr_logits, word_mask)
392
+ logits_to_labels_time = time.perf_counter() - start_time
393
+ logger.debug(f"Logits to labels conversion took {logits_to_labels_time:.4f} seconds")
394
+
395
+ return result
396
 
397
  def predict_labels_from_text(
398
  self, sentences: List[List[str]], tokenizer, truncate: bool = False
 
445
  # Get model predictions in (category, [attributes]) format
446
  predictions = self.predict_labels_from_text(sentences, tokenizer, truncate)
447
 
448
+ # Time the IFD conversion
449
+ start_time = time.perf_counter()
450
  ifd_predictions = []
451
  for sentence_predictions in predictions:
452
  ifd_labels = convert_predictions_to_ifd(sentence_predictions) # (Ws,)
453
  ifd_predictions.append(ifd_labels)
454
+ ifd_conversion_time = time.perf_counter() - start_time
455
+ logger.debug(f"IFD conversion took {ifd_conversion_time:.4f} seconds")
456
 
457
  return ifd_predictions
458