# hmr-dataset / genmo/callbacks/train_speed_timer.py
# zirobtc's picture
# Upload folder using huggingface_hub
# fbb20ff verified
from collections import deque
from time import perf_counter, time

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
class TrainSpeedTimer(pl.Callback):
    """Log running averages of training-loop timings.

    Two quantities are tracked, each averaged over the most recent
    ``N_avg`` training batches of the current epoch:

    1. ``train_timer/data_waiting``: time between the end of one batch and
       the start of the next. This should be small; otherwise data loading
       should be improved.
    2. ``train_timer/single_batch``: time between the start and the end of
       one training batch (excluding data waiting).

    All hooks are wrapped in ``rank_zero_only``.
    """

    def __init__(self, N_avg=5):
        """Create the timer.

        Args:
            N_avg: Number of most recent batches to average over. ``None``
                keeps every sample of the current epoch (unbounded window).

        Raises:
            ValueError: If ``N_avg`` is smaller than 1. An empty window
                would otherwise cause a division by zero at the first
                logged batch.
        """
        super().__init__()
        if N_avg is not None and N_avg < 1:
            raise ValueError(f"N_avg must be a positive integer, got {N_avg!r}")

        # Timestamps from a monotonic clock; None means "not measured yet".
        self.last_batch_end = None
        self.this_batch_start = None

        # Sliding windows of the most recent measurements (in seconds).
        self.data_waiting_time_queue = deque(maxlen=N_avg)
        self.single_batch_time_queue = deque(maxlen=N_avg)

    @staticmethod
    def _log_window_mean(pl_module, name, window, prog_bar):
        """Log the mean of the non-empty ``window`` under ``name`` (per step)."""
        pl_module.log(
            name,
            sum(window) / len(window),
            on_step=True,
            on_epoch=False,
            prog_bar=prog_bar,
            logger=True,
        )

    @rank_zero_only
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Measure the data waiting time and mark the start of the batch."""
        # No previous batch end is known for the first batch of an epoch,
        # so there is nothing to measure yet.
        if self.last_batch_end is not None:
            self.data_waiting_time_queue.append(
                perf_counter() - self.last_batch_end
            )
            # Shown on the progress bar as well as sent to the logger.
            self._log_window_mean(
                pl_module,
                "train_timer/data_waiting",
                self.data_waiting_time_queue,
                prog_bar=True,
            )

        # Taken after logging so that logging is not counted as batch time.
        self.this_batch_start = perf_counter()

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Measure the batch time and mark the end of the batch."""
        # Defensive guard: without a recorded start there is no valid
        # duration to report for this batch.
        if self.this_batch_start is not None:
            self.single_batch_time_queue.append(
                perf_counter() - self.this_batch_start
            )
            # Sent to the logger only (not the progress bar).
            self._log_window_mean(
                pl_module,
                "train_timer/single_batch",
                self.single_batch_time_queue,
                prog_bar=False,
            )

        # Reference point for the next batch's data waiting time.
        self.last_batch_end = perf_counter()

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
        """Reset all timing state at the end of a training epoch."""
        # Dropping the timestamps means the gap before the first batch of
        # the next epoch is not reported as data waiting time.
        self.last_batch_end = None
        self.this_batch_start = None

        # Averages start from scratch in the next epoch.
        self.data_waiting_time_queue.clear()
        self.single_batch_time_queue.clear()