|
|
from collections import deque
from time import perf_counter, time

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
|
|
|
|
|
|
|
|
class TrainSpeedTimer(pl.Callback):
    """Log rolling averages of per-batch training timings.

    Two metrics are logged on every training step. All hooks are wrapped in
    ``rank_zero_only``, so they run on global rank 0 only:

    1. ``train_timer/data_waiting``: the time between the end of the previous
       training batch and the start of the current one. This should be small;
       if it is not, data loading should be improved. It is also shown on the
       progress bar.
    2. ``train_timer/single_batch``: the time from ``on_train_batch_start`` to
       ``on_train_batch_end`` for one batch (excluding data waiting).

    Each logged value is the mean of the most recent ``N_avg`` measurements.

    NOTE(review): these are host-side timings. With asynchronous CUDA
    execution, the single-batch time may not reflect GPU work unless
    something in the training step synchronizes. Confirm this for your setup.
    """

    def __init__(self, N_avg=5):
        """
        Args:
            N_avg: number of most recent measurements to average over. Must be
                at least 1. ``None`` keeps an unbounded window, so the average
                covers all measurements since the last reset.

        Raises:
            ValueError: if ``N_avg`` is less than 1. A zero-length window
                would otherwise cause a division by zero on the first batch.
        """
        super().__init__()
        if N_avg is not None and N_avg < 1:
            raise ValueError(f"N_avg must be >= 1, got {N_avg}")

        # perf_counter() timestamps; None means "no reference point yet".
        self.last_batch_end = None
        self.this_batch_start = None

        # Fixed-size windows: appending past maxlen discards the oldest value.
        self.data_waiting_time_queue = deque(maxlen=N_avg)
        self.single_batch_time_queue = deque(maxlen=N_avg)

    @staticmethod
    def _push_and_average(queue, value):
        """Append ``value`` to ``queue`` and return the mean of its contents."""
        queue.append(value)
        return sum(queue) / len(queue)

    @rank_zero_only
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Log the averaged data-waiting time, then start timing this batch."""
        # Without a previous batch end (first batch of an epoch, or right
        # after validation), there is no waiting interval to measure.
        if self.last_batch_end is not None:
            data_waiting = perf_counter() - self.last_batch_end
            average_time = self._push_and_average(
                self.data_waiting_time_queue, data_waiting
            )
            pl_module.log(
                "train_timer/data_waiting",
                average_time,
                on_step=True,
                on_epoch=False,
                prog_bar=True,
                logger=True,
            )

        # Taken after logging, so the logging call is not counted in the
        # single-batch time.
        self.this_batch_start = perf_counter()

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log the averaged single-batch time, then mark the end of this batch."""
        # Defensive: only measure if the matching start hook recorded a time.
        if self.this_batch_start is not None:
            single_batch = perf_counter() - self.this_batch_start
            average_time = self._push_and_average(
                self.single_batch_time_queue, single_batch
            )
            pl_module.log(
                "train_timer/single_batch",
                average_time,
                on_step=True,
                on_epoch=False,
                prog_bar=False,
                logger=True,
            )

        # Taken after logging, so the logging call is not counted as data
        # waiting for the next batch.
        self.last_batch_end = perf_counter()

    @rank_zero_only
    def on_validation_end(self, trainer, pl_module):
        """Drop the previous batch-end timestamp once validation finishes.

        When validation runs between two training batches (for example with
        ``val_check_interval < 1``), its duration would otherwise be logged as
        data-waiting time for the next training batch.
        """
        self.last_batch_end = None

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
        """Reset timestamps and averaging windows at the end of each epoch."""
        # Epoch boundaries (validation, checkpointing, loader restart) should
        # not be measured as data waiting for the next epoch's first batch.
        self.last_batch_end = None
        self.this_batch_start = None

        self.data_waiting_time_queue.clear()
        self.single_batch_time_queue.clear()
|
|
|