| from collections import deque |
| from time import time |
|
|
| import pytorch_lightning as pl |
| from pytorch_lightning.utilities import rank_zero_only |
|
|
|
|
class TrainSpeedTimer(pl.Callback):
    """Log running averages of training-loop timings.

    Two quantities are tracked, each averaged over the most recent
    ``N_avg`` samples:

    1. ``train_timer/data_waiting``: time between the end of the previous
       ``on_train_batch_end`` hook and the next ``on_train_batch_start``
       hook. This should be small; otherwise data loading should be
       improved.
    2. ``train_timer/single_batch``: time between ``on_train_batch_start``
       and ``on_train_batch_end`` for one batch (data waiting excluded).

    All hooks are decorated with ``rank_zero_only``. The state is reset at
    the end of every training epoch, so the first batch of an epoch never
    contributes a data-waiting sample.

    Args:
        N_avg: Size of the averaging window, in batches. ``None`` keeps
            every sample of the current epoch (unbounded window).

    Raises:
        ValueError: If ``N_avg`` is not ``None`` and is smaller than 1.
    """

    def __init__(self, N_avg=5):
        super().__init__()
        # A zero-length window can never hold a sample, which would turn the
        # averaging step into a division by zero; reject it up front.
        if N_avg is not None and N_avg < 1:
            raise ValueError(f"N_avg must be >= 1 or None, got {N_avg!r}")

        # Timestamps from time() (wall clock); None means "not recorded yet".
        self.last_batch_end = None
        self.this_batch_start = None

        # Sliding windows holding the most recent measurements, in seconds.
        self.data_waiting_time_queue = deque(maxlen=N_avg)
        self.single_batch_time_queue = deque(maxlen=N_avg)

    @staticmethod
    def _log_running_mean(pl_module, name, queue, sample, prog_bar):
        """Append ``sample`` to ``queue`` and log the queue's mean as ``name``."""
        queue.append(sample)
        pl_module.log(
            name,
            sum(queue) / len(queue),
            on_step=True,
            on_epoch=False,
            prog_bar=prog_bar,
            logger=True,
        )

    @rank_zero_only
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Record the data-waiting time and mark the start of this batch."""
        # No previous batch end is known for the first batch of an epoch,
        # so there is nothing to measure yet.
        if self.last_batch_end is not None:
            self._log_running_mean(
                pl_module,
                "train_timer/data_waiting",
                self.data_waiting_time_queue,
                time() - self.last_batch_end,
                prog_bar=True,
            )

        # Taken after logging so the logging cost is not billed to the batch.
        self.this_batch_start = time()

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Record the single-batch time and mark the end of this batch."""
        # Skip the measurement if no start timestamp was recorded, instead
        # of failing on ``time() - None``.
        if self.this_batch_start is not None:
            self._log_running_mean(
                pl_module,
                "train_timer/single_batch",
                self.single_batch_time_queue,
                time() - self.this_batch_start,
                prog_bar=False,
            )

        # Taken after logging so the next data-waiting sample starts here.
        self.last_batch_end = time()

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
        """Reset timestamps and averaging windows for the next epoch."""
        self.last_batch_end = None
        self.this_batch_start = None

        self.data_waiting_time_queue.clear()
        self.single_batch_time_queue.clear()
|
|