# WildDet3D / vis4d/engine/callbacks/logging.py
# weikaih - "WildDet3D Gradio demo" (commit f71ac1d, verified)
"""This module contains utilities for callbacks."""
from __future__ import annotations

from collections import defaultdict
from collections.abc import Mapping
from typing import Any

import lightning.pytorch as pl

from vis4d.common.logging import rank_zero_info
from vis4d.common.progress import compose_log_str
from vis4d.common.time import Timer
from vis4d.common.typing import ArgsType, MetricLogs

from .base import Callback
class LoggingCallback(Callback):
    """Callback that periodically logs training, validation and test progress.

    Every ``refresh_rate`` iterations a progress line is emitted on rank zero
    via ``rank_zero_info``. During training, the values found under the
    ``"metrics"`` key of the step outputs are averaged over the window since
    the last log line and also forwarded to ``pl_module.log`` with a
    ``train/`` prefix.

    Training iterations are counted per epoch (``batch_idx``) when
    ``self.epoch_based`` is true, otherwise globally via
    ``trainer.global_step``. ``epoch_based`` is not defined in this class;
    presumably it is provided by the base ``Callback`` (confirm there).
    """

    def __init__(
        self, *args: ArgsType, refresh_rate: int = 50, **kwargs: ArgsType
    ) -> None:
        """Init callback.

        Args:
            *args: Positional arguments forwarded to the base ``Callback``.
            refresh_rate: Emit a log line every ``refresh_rate`` iterations.
                Defaults to 50.
            **kwargs: Keyword arguments forwarded to the base ``Callback``.
        """
        super().__init__(*args, **kwargs)
        self._refresh_rate = refresh_rate
        # Metric name -> values collected since the last training log line.
        self._metrics: dict[str, list[float]] = defaultdict(list)
        self.train_timer = Timer()
        self.test_timer = Timer()
        # Iteration at which the last training log line was written; guards
        # against logging the same iteration count twice.
        self.last_step = 0

    @staticmethod
    def _resolve_total_iters(num_batches: int | float) -> int:
        """Return the batch count as an int, or -1 when it is not an int.

        An unbounded/unknown number of batches is reported as
        ``float("inf")``; it (and any other non-int) is resolved to -1.
        """
        if isinstance(num_batches, int):
            return int(num_batches)
        return -1

    def _log_eval_progress(
        self, prefix: str, batch_idx: int, num_batches: int | float
    ) -> None:
        """Log evaluation progress every ``refresh_rate`` batches.

        Args:
            prefix: Label for the log line (e.g. "Validation", "Testing").
            batch_idx: Zero-based index of the batch that just finished.
            num_batches: Total number of batches of the current dataloader,
                possibly ``float("inf")``.
        """
        cur_iter = batch_idx + 1
        if cur_iter % self._refresh_rate == 0:
            total_iters = self._resolve_total_iters(num_batches)
            rank_zero_info(
                compose_log_str(
                    prefix, cur_iter, total_iters, self.test_timer
                )
            )

    def on_train_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Hook to run at the start of a training epoch.

        In epoch-based mode the timer, ``last_step`` and the metric
        accumulator are reset every epoch. Otherwise the timer is only reset
        when training starts at global step 0.
        """
        if self.epoch_based:
            self.train_timer.reset()
            self.last_step = 0
            self._metrics.clear()
        elif trainer.global_step == 0:
            self.train_timer.reset()

    def on_train_batch_start(  # type: ignore
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Hook to run at the start of a training batch.

        Resumes the training timer if it was paused, which happens when a
        validation or test epoch starts.
        """
        if self.train_timer.paused:
            self.train_timer.resume()

    def on_train_batch_end(  # type: ignore
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Hook to run at the end of a training batch.

        Accumulates the step metrics and, every ``refresh_rate`` iterations,
        logs their average over the window, then clears the accumulator.

        Args:
            trainer: The Lightning trainer.
            pl_module: The module being trained; receives ``train/<name>``
                logs.
            outputs: Output of the training step. Only mappings that contain
                a ``"metrics"`` mapping contribute metrics; ``None`` or
                tensor outputs are ignored.
            batch: The current batch (unused).
            batch_idx: Zero-based index of the batch within the epoch.
        """
        # A training step may return None or a bare tensor; membership tests
        # are only meaningful (and safe) on mappings.
        if isinstance(outputs, Mapping) and "metrics" in outputs:
            for name, value in outputs["metrics"].items():
                self._metrics[name].append(value)

        if self.epoch_based:
            cur_iter = batch_idx + 1
            total_iters = self._resolve_total_iters(
                trainer.num_training_batches
            )
        else:
            cur_iter = trainer.global_step + 1
            total_iters = trainer.max_steps

        # The same iteration count can be seen on consecutive batches if
        # global_step does not advance every batch; last_step dedups logs.
        if cur_iter % self._refresh_rate == 0 and cur_iter != self.last_step:
            prefix = (
                f"Epoch {pl_module.current_epoch + 1}"
                if self.epoch_based
                else "Iter"
            )
            log_dict: MetricLogs = {
                name: sum(values) / len(values) if values else float("NaN")
                for name, values in self._metrics.items()
            }
            rank_zero_info(
                compose_log_str(
                    prefix, cur_iter, total_iters, self.train_timer, log_dict
                )
            )
            self._metrics.clear()
            self.last_step = cur_iter

            for name, value in log_dict.items():
                pl_module.log(f"train/{name}", value, rank_zero_only=True)

    def on_validation_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Hook to run at the start of a validation epoch.

        Restarts the evaluation timer and pauses the training timer so that
        validation time is not counted as training time.
        """
        self.test_timer.reset()
        self.train_timer.pause()

    def on_validation_batch_end(  # type: ignore
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to run at the end of a validation batch.

        Logs validation progress every ``refresh_rate`` batches of the
        dataloader at ``dataloader_idx``.
        """
        self._log_eval_progress(
            "Validation", batch_idx, trainer.num_val_batches[dataloader_idx]
        )

    def on_test_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Hook to run at the start of a testing epoch.

        Restarts the evaluation timer and pauses the training timer.
        """
        self.test_timer.reset()
        self.train_timer.pause()

    def on_test_batch_end(  # type: ignore
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to run at the end of a testing batch.

        Logs testing progress every ``refresh_rate`` batches of the
        dataloader at ``dataloader_idx``.
        """
        self._log_eval_progress(
            "Testing", batch_idx, trainer.num_test_batches[dataloader_idx]
        )