| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import time |
| | from typing import Optional |
| |
|
| | from lightning.pytorch import Callback, LightningModule, Trainer |
| |
|
| | from nemo.utils import logging |
| |
|
| |
|
class BenchmarkCallback(Callback):
    """Measures and logs wall-clock timings for training and validation.

    Records per-epoch and per-step durations (via ``time.time`` deltas) for both
    the training and validation loops, restricted to the global-step window
    ``[start_benchmark_at_step, stop_benchmark_at_step]``, and logs averages at
    the end of ``fit``.

    Args:
        start_benchmark_at_step: First global step (inclusive) at which
            measurements are recorded. Defaults to 0 (benchmark from the start).
        stop_benchmark_at_step: Last global step (inclusive) at which
            measurements are recorded. ``None`` means benchmark until the end.
        log_every_n_steps: Per-step times are logged only when
            ``global_step`` is a multiple of this value (all steps inside the
            window are still *recorded* for the averages).
    """

    def __init__(
        self,
        start_benchmark_at_step: int = 0,
        stop_benchmark_at_step: Optional[int] = None,
        log_every_n_steps: int = 10,
    ):
        super().__init__()
        self.start_benchmark_at_step = start_benchmark_at_step
        self.stop_benchmark_at_step = stop_benchmark_at_step
        self.log_every_n_steps = log_every_n_steps
        # Accumulators for recorded durations, in seconds.
        self.train_times = []  # one entry per training epoch
        self.val_times = []  # one entry per validation epoch
        self.train_steps_times = []  # one entry per training batch
        self.val_steps_times = []  # one entry per validation batch

    def should_benchmark(self, trainer: Trainer) -> bool:
        """Return True when ``trainer.global_step`` is inside the benchmark window."""
        if self.stop_benchmark_at_step is None:
            return trainer.global_step >= self.start_benchmark_at_step
        return self.start_benchmark_at_step <= trainer.global_step <= self.stop_benchmark_at_step

    @staticmethod
    def _average(times) -> float:
        """Mean of recorded durations; 0.0 for an empty list.

        Guards against ZeroDivisionError when a loop never ran or the
        benchmark window was never entered.
        """
        return sum(times) / len(times) if times else 0.0

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        self.epoch_start_time = time.time()

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        if self.should_benchmark(trainer):
            epoch_time = time.time() - self.epoch_start_time
            self.train_times.append(epoch_time)
            logging.info(f'Training-Epoch-{trainer.current_epoch}-Time: {epoch_time} [sec]')

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int):
        self.step_start_time = time.time()

    def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int):
        if self.should_benchmark(trainer):
            step_time = time.time() - self.step_start_time
            self.train_steps_times.append(step_time)
            if trainer.global_step % self.log_every_n_steps == 0:
                logging.info(f'Training-Step-{trainer.global_step}-Time: {step_time} [sec]')

    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        self.val_start_time = time.time()

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        if self.should_benchmark(trainer):
            val_time = time.time() - self.val_start_time
            self.val_times.append(val_time)
            logging.info(f'Validation-Epoch-{trainer.current_epoch}-Time: {val_time} [sec]')

    def on_validation_batch_start(
        self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int, dataloader_idx: int
    ):
        self.val_step_start_time = time.time()

    def on_validation_batch_end(
        self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int, dataloader_idx: int
    ):
        if self.should_benchmark(trainer):
            val_step_time = time.time() - self.val_step_start_time
            self.val_steps_times.append(val_step_time)
            if trainer.global_step % self.log_every_n_steps == 0:
                logging.info(f'Validation-Step-{trainer.global_step}-Time: {val_step_time} [sec]')

    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule):
        """Log average epoch/step times collected during the benchmark window."""
        if self.should_benchmark(trainer):
            # _average returns 0.0 for empty lists, so a run with no validation
            # (or a never-entered window) no longer crashes with ZeroDivisionError.
            avg_train_time = self._average(self.train_times)
            avg_val_time = self._average(self.val_times)
            avg_train_step_time = self._average(self.train_steps_times)
            avg_val_step_time = self._average(self.val_steps_times)

            logging.info(f'Average-Training-Epoch-Time: {avg_train_time} [sec]')
            logging.info(f'Average-Validation-Epoch-Time: {avg_val_time} [sec]')
            logging.info(f'Average-Training-Step-Time: {avg_train_step_time} [sec]')
            logging.info(f'Average-Validation-Step-Time: {avg_val_step_time} [sec]')
| |
|