| """This module contains utilities for callbacks.""" |
|
|
| from __future__ import annotations |
|
|
| from collections import defaultdict |
| from typing import Any |
|
|
| import lightning.pytorch as pl |
|
|
| from vis4d.common.logging import rank_zero_info |
| from vis4d.common.progress import compose_log_str |
| from vis4d.common.time import Timer |
| from vis4d.common.typing import ArgsType, MetricLogs |
|
|
| from .base import Callback |
|
|
|
|
class LoggingCallback(Callback):
    """Callback for logging training and evaluation progress.

    Every ``refresh_rate`` iterations a log line with timing information
    (and, during training, the running average of the training metrics) is
    composed via ``compose_log_str`` and emitted on rank zero only.
    """

    def __init__(
        self, *args: ArgsType, refresh_rate: int = 50, **kwargs: ArgsType
    ) -> None:
        """Init callback.

        Args:
            refresh_rate (int): Number of iterations between two consecutive
                log messages. Defaults to 50.
        """
        super().__init__(*args, **kwargs)
        self._refresh_rate = refresh_rate
        # Metric values accumulated since the last emitted log line; they are
        # averaged and cleared each time a line is logged.
        self._metrics: dict[str, list[float]] = defaultdict(list)
        self.train_timer = Timer()
        self.test_timer = Timer()
        # Last training iteration at which a line was logged; used to avoid
        # emitting duplicate output for the same iteration.
        self.last_step = 0

    def on_train_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Hook to run at the start of a training epoch.

        Resets timing and metric state per epoch in epoch-based mode, or
        once at the very first step in iteration-based mode.
        """
        # NOTE(review): self.epoch_based is assumed to be provided by the
        # Callback base class -- not visible in this file.
        if self.epoch_based:
            self.train_timer.reset()
            self.last_step = 0
            self._metrics.clear()
        elif trainer.global_step == 0:
            self.train_timer.reset()

    def on_train_batch_start(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Hook to run at the start of a training batch."""
        # Resume the training timer if it was paused for validation/testing.
        if self.train_timer.paused:
            self.train_timer.resume()

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Hook to run at the end of a training batch.

        Accumulates the metrics returned by the training step and, every
        ``refresh_rate`` iterations, logs their running average together
        with timing information, then forwards the averages to the
        LightningModule logger.
        """
        # training_step may return None (e.g. when a batch is skipped), in
        # which case the membership test would raise a TypeError.
        if outputs is not None and "metrics" in outputs:
            for k, v in outputs["metrics"].items():
                self._metrics[k].append(v)

        if self.epoch_based:
            cur_iter = batch_idx + 1
            # A float number of batches (inf) means the dataloader length is
            # unknown; -1 signals that to compose_log_str.
            if isinstance(trainer.num_training_batches, float):
                total_iters = -1
            else:
                total_iters = trainer.num_training_batches
        else:
            cur_iter = trainer.global_step + 1
            total_iters = trainer.max_steps

        if cur_iter % self._refresh_rate == 0 and cur_iter != self.last_step:
            prefix = (
                f"Epoch {pl_module.current_epoch + 1}"
                if self.epoch_based
                else "Iter"
            )

            # Average each metric over the values gathered since the last
            # log line; NaN marks a metric with no recorded values.
            log_dict: MetricLogs = {
                k: sum(v) / len(v) if len(v) > 0 else float("NaN")
                for k, v in self._metrics.items()
            }

            rank_zero_info(
                compose_log_str(
                    prefix, cur_iter, total_iters, self.train_timer, log_dict
                )
            )

            self._metrics.clear()
            self.last_step = cur_iter

            for k, v in log_dict.items():
                pl_module.log(f"train/{k}", v, rank_zero_only=True)

    def _log_eval_progress(
        self, prefix: str, num_batches: int | float, batch_idx: int
    ) -> None:
        """Log evaluation progress at the configured refresh rate.

        Args:
            prefix (str): Log line prefix ("Validation" or "Testing").
            num_batches (int | float): Number of batches in the current
                dataloader; a float (inf) indicates an unknown length,
                reported as -1.
            batch_idx (int): Zero-based index of the current batch.
        """
        cur_iter = batch_idx + 1
        total_iters = int(num_batches) if isinstance(num_batches, int) else -1

        if cur_iter % self._refresh_rate == 0:
            rank_zero_info(
                compose_log_str(prefix, cur_iter, total_iters, self.test_timer)
            )

    def on_validation_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Hook to run at the start of a validation epoch."""
        self.test_timer.reset()
        # Exclude validation time from the training timer.
        self.train_timer.pause()

    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to run at the end of a validation batch."""
        self._log_eval_progress(
            "Validation", trainer.num_val_batches[dataloader_idx], batch_idx
        )

    def on_test_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Hook to run at the start of a testing epoch."""
        self.test_timer.reset()
        # Exclude testing time from the training timer.
        self.train_timer.pause()

    def on_test_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to run at the end of a testing batch."""
        self._log_eval_progress(
            "Testing", trainer.num_test_batches[dataloader_idx], batch_idx
        )
|
|