|
|
"""Rich progress bar callback with epoch-level train/val losses.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Any |
|
|
|
|
|
import torch |
|
|
from lightning.pytorch.callbacks import RichProgressBar |
|
|
|
|
|
|
|
|
class LossLoggingCallback(RichProgressBar):
    """Show epoch-level mean train/val losses directly in Rich progress bar.

    Accumulates the per-batch losses reported by the LightningModule during
    each epoch and exposes their means as ``train_loss`` / ``val_loss``
    entries in the Rich progress-bar metrics. When no per-batch loss can be
    extracted from the step outputs, falls back to well-known keys in
    ``trainer.callback_metrics``.
    """

    def __init__(self) -> None:
        """Initialize callback state used to aggregate losses per epoch."""
        super().__init__()
        # Running sum/count of per-batch train losses for the current epoch.
        self._train_loss_sum: float = 0.0
        self._train_loss_count: int = 0
        # Running sum/count of per-batch validation losses for the current epoch.
        self._val_loss_sum: float = 0.0
        self._val_loss_count: int = 0
        # Most recent epoch means; rendered by get_metrics().
        self._latest_train_loss: float | None = None
        self._latest_val_loss: float | None = None

    @staticmethod
    def _as_float(value: Any) -> float | None:
        """Convert tensor/scalar values to float for logging.

        Returns ``None`` for ``None`` input, non-scalar tensors, and any
        unsupported type, so callers can treat "not convertible" uniformly.
        """
        if value is None:
            return None
        if torch.is_tensor(value):
            # Guard: .item() raises on non-scalar tensors; treat those as
            # "no loss available" rather than crashing the progress bar.
            if value.numel() != 1:
                return None
            return float(value.detach().cpu().item())
        if isinstance(value, (int, float)):
            return float(value)
        return None

    @staticmethod
    def _extract_loss_from_outputs(outputs: Any) -> float | None:
        """Extract loss scalar from Lightning step outputs.

        Handles dict outputs (checking ``loss``/``train_loss``/``val_loss``
        keys), bare tensors/numbers, and objects exposing a ``loss`` attribute.
        """
        if outputs is None:
            return None
        if isinstance(outputs, dict):
            for key in ("loss", "train_loss", "val_loss"):
                loss = LossLoggingCallback._as_float(outputs.get(key))
                if loss is not None:
                    return loss
        loss = LossLoggingCallback._as_float(outputs)
        if loss is not None:
            return loss
        if hasattr(outputs, "loss"):
            return LossLoggingCallback._as_float(outputs.loss)
        return None

    @staticmethod
    def _first_loss_in_metrics(
        metrics: dict[str, Any], keys: tuple[str, ...]
    ) -> float | None:
        """Return the first convertible loss found in *metrics* among *keys*."""
        for key in keys:
            loss = LossLoggingCallback._as_float(metrics.get(key))
            if loss is not None:
                return loss
        return None

    @staticmethod
    def _extract_train_loss_fallback(metrics: dict[str, Any]) -> float | None:
        """Fallback: get train loss from callback metrics."""
        return LossLoggingCallback._first_loss_in_metrics(
            metrics, ("train_loss_epoch", "train_loss", "train/loss")
        )

    @staticmethod
    def _extract_val_loss_fallback(metrics: dict[str, Any]) -> float | None:
        """Fallback: get validation loss from callback metrics."""
        return LossLoggingCallback._first_loss_in_metrics(
            metrics, ("val_loss_epoch", "val_loss", "val/loss")
        )

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        """Reset train accumulators at epoch start."""
        super().on_train_epoch_start(trainer, pl_module)
        self._train_loss_sum = 0.0
        self._train_loss_count = 0

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        """Reset validation accumulators at epoch start."""
        super().on_validation_epoch_start(trainer, pl_module)
        self._val_loss_sum = 0.0
        self._val_loss_count = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        """Accumulate per-batch train loss."""
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        loss = self._extract_loss_from_outputs(outputs)
        if loss is None:
            # Step produced no extractable loss; skip this batch.
            return
        self._train_loss_sum += loss
        self._train_loss_count += 1

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None:
        """Accumulate per-batch validation loss."""
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        # Ignore the pre-fit sanity-check pass so it doesn't skew epoch means.
        if trainer.sanity_checking:
            return
        loss = self._extract_loss_from_outputs(outputs)
        if loss is None:
            return
        self._val_loss_sum += loss
        self._val_loss_count += 1

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Store mean train loss at train epoch end."""
        super().on_train_epoch_end(trainer, pl_module)
        train_loss = (
            self._train_loss_sum / self._train_loss_count
            if self._train_loss_count > 0
            else None
        )
        if train_loss is None:
            # No per-batch losses were captured; try logged callback metrics.
            metrics = dict(trainer.callback_metrics)
            train_loss = self._extract_train_loss_fallback(metrics)
        self._latest_train_loss = train_loss

    def on_validation_epoch_end(self, trainer, pl_module) -> None:
        """Store mean val loss at validation epoch end."""
        super().on_validation_epoch_end(trainer, pl_module)
        if trainer.sanity_checking:
            return

        val_loss = (
            self._val_loss_sum / self._val_loss_count
            if self._val_loss_count > 0
            else None
        )
        metrics = dict(trainer.callback_metrics)
        if val_loss is None:
            # No per-batch losses were captured; try logged callback metrics.
            val_loss = self._extract_val_loss_fallback(metrics)
        self._latest_val_loss = val_loss

    def get_metrics(self, trainer, pl_module) -> dict[str, Any]:
        """Append epoch-level losses to Rich progress bar metrics.

        Returns the parent's metrics dict (which may contain non-float values
        such as the version string) extended with rounded epoch-mean losses.
        """
        metrics = super().get_metrics(trainer, pl_module)
        if self._latest_train_loss is not None:
            metrics["train_loss"] = round(float(self._latest_train_loss), 6)
        if self._latest_val_loss is not None:
            metrics["val_loss"] = round(float(self._latest_val_loss), 6)
        return metrics
|
|
|