# TerraMind-HYPERVIEW / callback_hooks / loss_logging_callback.py
# KPLabs's picture
# Upload folder using huggingface_hub
# 87904b0 verified
"""Rich progress bar callback with epoch-level train/val losses."""
from __future__ import annotations
from typing import Any
import torch
from lightning.pytorch.callbacks import RichProgressBar
class LossLoggingCallback(RichProgressBar):
    """Rich progress bar that shows epoch-level mean train/val losses.

    Per-batch loss values are accumulated during each epoch and averaged at
    epoch end; the averages are injected into the progress-bar metric row
    via :meth:`get_metrics`. When no per-batch loss could be extracted, the
    callback falls back to ``trainer.callback_metrics``.
    """

    def __init__(self) -> None:
        """Initialize callback state used to aggregate losses per epoch."""
        super().__init__()
        self._train_loss_sum = 0.0
        self._train_loss_count = 0
        self._val_loss_sum = 0.0
        self._val_loss_count = 0
        self._latest_train_loss: float | None = None
        self._latest_val_loss: float | None = None

    @staticmethod
    def _as_float(value: Any) -> float | None:
        """Convert a tensor or numeric scalar to ``float``.

        Returns ``None`` for anything that is not a scalar tensor or a
        plain ``int``/``float`` (including ``None`` itself).
        """
        if value is None:
            return None
        if torch.is_tensor(value):
            # .item() raises on non-scalar tensors; treat those as
            # unconvertible rather than crashing the progress bar.
            if value.numel() != 1:
                return None
            return float(value.detach().cpu().item())
        if isinstance(value, (int, float)):
            return float(value)
        return None

    @staticmethod
    def _extract_loss_from_outputs(outputs: Any) -> float | None:
        """Extract a loss scalar from Lightning step outputs.

        Handles the three shapes Lightning steps commonly return: a dict
        with a loss-like key, a bare tensor/number, or an object with a
        ``loss`` attribute. Returns ``None`` when no loss is found.
        """
        if outputs is None:
            return None
        if isinstance(outputs, dict):
            for key in ("loss", "train_loss", "val_loss"):
                loss = LossLoggingCallback._as_float(outputs.get(key))
                if loss is not None:
                    return loss
        loss = LossLoggingCallback._as_float(outputs)
        if loss is not None:
            return loss
        if hasattr(outputs, "loss"):
            return LossLoggingCallback._as_float(outputs.loss)
        return None

    @staticmethod
    def _extract_train_loss_fallback(metrics: dict[str, Any]) -> float | None:
        """Fallback: get the train loss from ``trainer.callback_metrics``."""
        for key in ("train_loss_epoch", "train_loss", "train/loss"):
            loss = LossLoggingCallback._as_float(metrics.get(key))
            if loss is not None:
                return loss
        return None

    @staticmethod
    def _extract_val_loss_fallback(metrics: dict[str, Any]) -> float | None:
        """Fallback: get the validation loss from ``trainer.callback_metrics``."""
        for key in ("val_loss_epoch", "val_loss", "val/loss"):
            loss = LossLoggingCallback._as_float(metrics.get(key))
            if loss is not None:
                return loss
        return None

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        """Reset train accumulators at epoch start."""
        super().on_train_epoch_start(trainer, pl_module)
        self._train_loss_sum = 0.0
        self._train_loss_count = 0

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        """Reset validation accumulators at epoch start."""
        super().on_validation_epoch_start(trainer, pl_module)
        self._val_loss_sum = 0.0
        self._val_loss_count = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        """Accumulate the per-batch train loss when one can be extracted."""
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        loss = self._extract_loss_from_outputs(outputs)
        if loss is None:
            return
        self._train_loss_sum += loss
        self._train_loss_count += 1

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None:
        """Accumulate the per-batch validation loss (skipped during sanity check)."""
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if trainer.sanity_checking:
            # Sanity-check batches would pollute the epoch average.
            return
        loss = self._extract_loss_from_outputs(outputs)
        if loss is None:
            return
        self._val_loss_sum += loss
        self._val_loss_count += 1

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Store the mean train loss at train epoch end."""
        super().on_train_epoch_end(trainer, pl_module)
        train_loss = (
            self._train_loss_sum / self._train_loss_count
            if self._train_loss_count > 0
            else None
        )
        if train_loss is None:
            # No per-batch loss captured; try the logged metrics instead.
            train_loss = self._extract_train_loss_fallback(dict(trainer.callback_metrics))
        self._latest_train_loss = train_loss

    def on_validation_epoch_end(self, trainer, pl_module) -> None:
        """Store the mean val loss at validation epoch end (skipped during sanity check)."""
        super().on_validation_epoch_end(trainer, pl_module)
        if trainer.sanity_checking:
            return
        val_loss = (
            self._val_loss_sum / self._val_loss_count
            if self._val_loss_count > 0
            else None
        )
        if val_loss is None:
            # No per-batch loss captured; try the logged metrics instead.
            val_loss = self._extract_val_loss_fallback(dict(trainer.callback_metrics))
        self._latest_val_loss = val_loss

    def get_metrics(self, trainer, pl_module) -> dict[str, float]:
        """Append the latest epoch-level losses to the Rich progress-bar metrics."""
        metrics = super().get_metrics(trainer, pl_module)
        if self._latest_train_loss is not None:
            metrics["train_loss"] = round(float(self._latest_train_loss), 6)
        if self._latest_val_loss is not None:
            metrics["val_loss"] = round(float(self._latest_val_loss), 6)
        return metrics