Commit: Adding some debug logging for timings
Browse files — modeling.py (+16 −2)
modeling.py
CHANGED
@@ -2,6 +2,7 @@
 # This file is part of IceBERT POS model conversion.

 import logging
 from typing import List, Optional, Tuple

 import torch
@@ -379,9 +380,19 @@ class IceBertPosForTokenClassification(PreTrainedModel):
         Returns:
             List of sequences, each containing (category, [attributes]) per word
         """
         cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)

-

     def predict_labels_from_text(
         self, sentences: List[List[str]], tokenizer, truncate: bool = False
@@ -434,11 +445,14 @@ class IceBertPosForTokenClassification(PreTrainedModel):
         # Get model predictions in (category, [attributes]) format
         predictions = self.predict_labels_from_text(sentences, tokenizer, truncate)

-        #
         ifd_predictions = []
         for sentence_predictions in predictions:
             ifd_labels = convert_predictions_to_ifd(sentence_predictions)  # (Ws,)
             ifd_predictions.append(ifd_labels)

         return ifd_predictions
 # This file is part of IceBERT POS model conversion.

 import logging
+import time
 from typing import List, Optional, Tuple

 import torch

         Returns:
             List of sequences, each containing (category, [attributes]) per word
         """
+        # Time the forward pass
+        start_time = time.perf_counter()
         cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)
+        forward_time = time.perf_counter() - start_time
+        logger.debug(f"Forward pass took {forward_time:.4f} seconds")

+        # Time the logits to labels conversion
+        start_time = time.perf_counter()
+        result = self._logits_to_labels(cat_logits, attr_logits, word_mask)
+        logits_to_labels_time = time.perf_counter() - start_time
+        logger.debug(f"Logits to labels conversion took {logits_to_labels_time:.4f} seconds")
+
+        return result

     def predict_labels_from_text(
         self, sentences: List[List[str]], tokenizer, truncate: bool = False

         # Get model predictions in (category, [attributes]) format
         predictions = self.predict_labels_from_text(sentences, tokenizer, truncate)

+        # Time the IFD conversion
+        start_time = time.perf_counter()
         ifd_predictions = []
         for sentence_predictions in predictions:
             ifd_labels = convert_predictions_to_ifd(sentence_predictions)  # (Ws,)
             ifd_predictions.append(ifd_labels)
+        ifd_conversion_time = time.perf_counter() - start_time
+        logger.debug(f"IFD conversion took {ifd_conversion_time:.4f} seconds")

         return ifd_predictions