File size: 5,432 Bytes
87904b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
"""Rich progress bar callback with epoch-level train/val losses."""
from __future__ import annotations
from typing import Any
import torch
from lightning.pytorch.callbacks import RichProgressBar
class LossLoggingCallback(RichProgressBar):
    """Rich progress bar showing epoch-level mean train/val losses.

    Per-batch losses are summed during each epoch; at epoch end their mean
    is cached and appended to the progress-bar metrics via ``get_metrics``.
    When step outputs carry no usable loss (e.g. the step returns ``None``),
    the epoch-level value is recovered from ``trainer.callback_metrics``.
    """

    # Metric keys probed, in order, when falling back to callback_metrics.
    # Class-level constants so train/val fallbacks stay in sync.
    _TRAIN_FALLBACK_KEYS: tuple[str, ...] = ("train_loss_epoch", "train_loss", "train/loss")
    _VAL_FALLBACK_KEYS: tuple[str, ...] = ("val_loss_epoch", "val_loss", "val/loss")

    def __init__(self) -> None:
        """Initialize per-epoch accumulators and the latest-loss cache."""
        super().__init__()
        self._train_loss_sum = 0.0
        self._train_loss_count = 0
        self._val_loss_sum = 0.0
        self._val_loss_count = 0
        self._latest_train_loss: float | None = None
        self._latest_val_loss: float | None = None

    @staticmethod
    def _as_float(value: Any) -> float | None:
        """Convert a scalar tensor or Python number to ``float``.

        Returns ``None`` for unsupported types and for non-scalar tensors:
        ``.item()`` raises on multi-element tensors, and a progress-bar
        callback must never crash training.
        """
        if value is None:
            return None
        if torch.is_tensor(value):
            # Guard: .item() is only defined for single-element tensors.
            if value.numel() != 1:
                return None
            return float(value.detach().cpu().item())
        if isinstance(value, (int, float)):
            return float(value)
        return None

    @staticmethod
    def _extract_loss_from_outputs(outputs: Any) -> float | None:
        """Extract a loss scalar from Lightning step outputs.

        Supports dict outputs (probing common loss keys), bare tensors or
        numbers, and objects exposing a ``loss`` attribute. Returns ``None``
        when no loss can be recovered.
        """
        if outputs is None:
            return None
        if isinstance(outputs, dict):
            for key in ("loss", "train_loss", "val_loss"):
                loss = LossLoggingCallback._as_float(outputs.get(key))
                if loss is not None:
                    return loss
        loss = LossLoggingCallback._as_float(outputs)
        if loss is not None:
            return loss
        if hasattr(outputs, "loss"):
            return LossLoggingCallback._as_float(outputs.loss)
        return None

    @staticmethod
    def _extract_train_loss_fallback(metrics: dict[str, Any]) -> float | None:
        """Fallback: read the train loss from logged callback metrics."""
        for key in LossLoggingCallback._TRAIN_FALLBACK_KEYS:
            loss = LossLoggingCallback._as_float(metrics.get(key))
            if loss is not None:
                return loss
        return None

    @staticmethod
    def _extract_val_loss_fallback(metrics: dict[str, Any]) -> float | None:
        """Fallback: read the validation loss from logged callback metrics."""
        for key in LossLoggingCallback._VAL_FALLBACK_KEYS:
            loss = LossLoggingCallback._as_float(metrics.get(key))
            if loss is not None:
                return loss
        return None

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        """Reset train accumulators at epoch start."""
        super().on_train_epoch_start(trainer, pl_module)
        self._train_loss_sum = 0.0
        self._train_loss_count = 0

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        """Reset validation accumulators at epoch start."""
        super().on_validation_epoch_start(trainer, pl_module)
        self._val_loss_sum = 0.0
        self._val_loss_count = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        """Accumulate the per-batch train loss (skipping unusable outputs)."""
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        loss = self._extract_loss_from_outputs(outputs)
        if loss is None:
            return
        self._train_loss_sum += loss
        self._train_loss_count += 1

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None:
        """Accumulate the per-batch validation loss (skipping sanity checks)."""
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        # Sanity-check batches would pollute the epoch mean with warm-up values.
        if trainer.sanity_checking:
            return
        loss = self._extract_loss_from_outputs(outputs)
        if loss is None:
            return
        self._val_loss_sum += loss
        self._val_loss_count += 1

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Cache the mean train loss; fall back to callback metrics if empty."""
        super().on_train_epoch_end(trainer, pl_module)
        if self._train_loss_count > 0:
            train_loss = self._train_loss_sum / self._train_loss_count
        else:
            # No usable per-batch losses; recover the logged epoch value.
            # callback_metrics is read directly — no copy needed for .get().
            train_loss = self._extract_train_loss_fallback(trainer.callback_metrics)
        self._latest_train_loss = train_loss

    def on_validation_epoch_end(self, trainer, pl_module) -> None:
        """Cache the mean val loss; fall back to callback metrics if empty."""
        super().on_validation_epoch_end(trainer, pl_module)
        if trainer.sanity_checking:
            return
        if self._val_loss_count > 0:
            val_loss = self._val_loss_sum / self._val_loss_count
        else:
            # No usable per-batch losses; recover the logged epoch value.
            val_loss = self._extract_val_loss_fallback(trainer.callback_metrics)
        self._latest_val_loss = val_loss

    def get_metrics(self, trainer, pl_module) -> dict[str, float]:
        """Append the cached epoch-level losses to the Rich progress-bar metrics."""
        metrics = super().get_metrics(trainer, pl_module)
        if self._latest_train_loss is not None:
            metrics["train_loss"] = round(float(self._latest_train_loss), 6)
        if self._latest_val_loss is not None:
            metrics["val_loss"] = round(float(self._latest_val_loss), 6)
        return metrics
|