"""Rich progress bar callback with epoch-level train/val losses.""" from __future__ import annotations from typing import Any import torch from lightning.pytorch.callbacks import RichProgressBar class LossLoggingCallback(RichProgressBar): """Show epoch-level mean train/val losses directly in Rich progress bar.""" def __init__(self) -> None: """Initialize callback state used to aggregate losses per epoch.""" super().__init__() self._train_loss_sum = 0.0 self._train_loss_count = 0 self._val_loss_sum = 0.0 self._val_loss_count = 0 self._latest_train_loss: float | None = None self._latest_val_loss: float | None = None @staticmethod def _as_float(value: Any) -> float | None: """Convert tensor/scalar values to float for logging.""" if value is None: return None if torch.is_tensor(value): return float(value.detach().cpu().item()) if isinstance(value, (int, float)): return float(value) return None @staticmethod def _extract_loss_from_outputs(outputs: Any) -> float | None: """Extract loss scalar from Lightning step outputs.""" if outputs is None: return None if isinstance(outputs, dict): for key in ("loss", "train_loss", "val_loss"): loss = LossLoggingCallback._as_float(outputs.get(key)) if loss is not None: return loss loss = LossLoggingCallback._as_float(outputs) if loss is not None: return loss if hasattr(outputs, "loss"): return LossLoggingCallback._as_float(outputs.loss) return None @staticmethod def _extract_train_loss_fallback(metrics: dict[str, Any]) -> float | None: """Fallback: get train loss from callback metrics.""" for key in ("train_loss_epoch", "train_loss", "train/loss"): loss = LossLoggingCallback._as_float(metrics.get(key)) if loss is not None: return loss return None @staticmethod def _extract_val_loss_fallback(metrics: dict[str, Any]) -> float | None: """Fallback: get validation loss from callback metrics.""" for key in ("val_loss_epoch", "val_loss", "val/loss"): loss = LossLoggingCallback._as_float(metrics.get(key)) if loss is not None: return loss return None def on_train_epoch_start(self, trainer, pl_module) -> None: """Reset train accumulators at epoch start.""" super().on_train_epoch_start(trainer, pl_module) self._train_loss_sum = 0.0 self._train_loss_count = 0 def on_validation_epoch_start(self, trainer, pl_module) -> None: """Reset validation accumulators at epoch start.""" super().on_validation_epoch_start(trainer, pl_module) self._val_loss_sum = 0.0 self._val_loss_count = 0 def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None: """Accumulate per-batch train loss.""" super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) loss = self._extract_loss_from_outputs(outputs) if loss is None: return self._train_loss_sum += loss self._train_loss_count += 1 def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: """Accumulate per-batch validation loss.""" super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if trainer.sanity_checking: return loss = self._extract_loss_from_outputs(outputs) if loss is None: return self._val_loss_sum += loss self._val_loss_count += 1 def on_train_epoch_end(self, trainer, pl_module) -> None: """Store mean train loss at train epoch end.""" super().on_train_epoch_end(trainer, pl_module) train_loss = ( self._train_loss_sum / self._train_loss_count if self._train_loss_count > 0 else None ) if train_loss is None: metrics = dict(trainer.callback_metrics) train_loss = self._extract_train_loss_fallback(metrics) 
self._latest_train_loss = train_loss def on_validation_epoch_end(self, trainer, pl_module) -> None: """Store mean val loss at validation epoch end.""" super().on_validation_epoch_end(trainer, pl_module) if trainer.sanity_checking: return val_loss = ( self._val_loss_sum / self._val_loss_count if self._val_loss_count > 0 else None ) metrics = dict(trainer.callback_metrics) if val_loss is None: val_loss = self._extract_val_loss_fallback(metrics) self._latest_val_loss = val_loss def get_metrics(self, trainer, pl_module) -> dict[str, float]: """Append epoch-level losses to Rich progress bar metrics.""" metrics = super().get_metrics(trainer, pl_module) if self._latest_train_loss is not None: metrics["train_loss"] = round(float(self._latest_train_loss), 6) if self._latest_val_loss is not None: metrics["val_loss"] = round(float(self._latest_val_loss), 6) return metrics
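

# Minimal usage sketch. ``DemoModule`` and the Trainer arguments below are
# illustrative stand-ins, not part of this module. Lightning allows only one
# progress-bar callback, so passing this subclass replaces the default
# RichProgressBar rather than adding a second bar:
#
#     from lightning.pytorch import Trainer
#
#     trainer = Trainer(max_epochs=10, callbacks=[LossLoggingCallback()])
#     trainer.fit(DemoModule())
#
# Per-batch aggregation relies on ``training_step``/``validation_step``
# returning the loss (or a dict with a "loss" key); otherwise the callback
# falls back to ``trainer.callback_metrics`` at epoch end.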