File size: 5,432 Bytes
87904b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""Rich progress bar callback with epoch-level train/val losses."""

from __future__ import annotations

from typing import Any

import torch
from lightning.pytorch.callbacks import RichProgressBar


class LossLoggingCallback(RichProgressBar):
    """Show epoch-level mean train/val losses directly in the Rich progress bar.

    Accumulates per-batch loss scalars during each train/validation epoch,
    stores the epoch mean at epoch end, and injects ``train_loss`` /
    ``val_loss`` into the progress-bar metrics via :meth:`get_metrics`.
    Falls back to ``trainer.callback_metrics`` when step outputs carried
    no extractable loss.
    """

    def __init__(self) -> None:
        """Initialize callback state used to aggregate losses per epoch."""
        super().__init__()
        self._train_loss_sum = 0.0  # running sum of per-batch train losses this epoch
        self._train_loss_count = 0  # number of train batches accumulated this epoch
        self._val_loss_sum = 0.0  # running sum of per-batch val losses this epoch
        self._val_loss_count = 0  # number of val batches accumulated this epoch
        self._latest_train_loss: float | None = None  # mean of last finished train epoch
        self._latest_val_loss: float | None = None  # mean of last finished val epoch

    @staticmethod
    def _as_float(value: Any) -> float | None:
        """Convert a tensor or numeric scalar to ``float``.

        Returns ``None`` for anything that is not a usable scalar:
        ``None`` itself, non-scalar tensors (``Tensor.item()`` would raise
        ``RuntimeError`` on them), booleans (a flag is not a loss, and
        ``bool`` is a subclass of ``int``), and arbitrary objects.
        """
        if value is None:
            return None
        if torch.is_tensor(value):
            # Guard: .item() raises on tensors with more than one element.
            if value.numel() != 1:
                return None
            return float(value.detach().cpu().item())
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return float(value)
        return None

    @staticmethod
    def _extract_loss_from_outputs(outputs: Any) -> float | None:
        """Extract a loss scalar from Lightning step outputs.

        Handles, in order: dict outputs (checked under the ``loss``,
        ``train_loss`` and ``val_loss`` keys), bare tensor/number outputs,
        and objects exposing a ``loss`` attribute. Returns ``None`` when
        no scalar loss can be found.
        """
        if outputs is None:
            return None
        if isinstance(outputs, dict):
            for key in ("loss", "train_loss", "val_loss"):
                loss = LossLoggingCallback._as_float(outputs.get(key))
                if loss is not None:
                    return loss
        loss = LossLoggingCallback._as_float(outputs)
        if loss is not None:
            return loss
        if hasattr(outputs, "loss"):
            return LossLoggingCallback._as_float(outputs.loss)
        return None

    @staticmethod
    def _extract_train_loss_fallback(metrics: dict[str, Any]) -> float | None:
        """Fallback: get train loss from callback metrics.

        Tries common logging key spellings in priority order; returns
        ``None`` when none of them holds a usable scalar.
        """
        for key in ("train_loss_epoch", "train_loss", "train/loss"):
            loss = LossLoggingCallback._as_float(metrics.get(key))
            if loss is not None:
                return loss
        return None

    @staticmethod
    def _extract_val_loss_fallback(metrics: dict[str, Any]) -> float | None:
        """Fallback: get validation loss from callback metrics.

        Tries common logging key spellings in priority order; returns
        ``None`` when none of them holds a usable scalar.
        """
        for key in ("val_loss_epoch", "val_loss", "val/loss"):
            loss = LossLoggingCallback._as_float(metrics.get(key))
            if loss is not None:
                return loss
        return None

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        """Reset train accumulators at epoch start."""
        super().on_train_epoch_start(trainer, pl_module)
        self._train_loss_sum = 0.0
        self._train_loss_count = 0

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        """Reset validation accumulators at epoch start."""
        super().on_validation_epoch_start(trainer, pl_module)
        self._val_loss_sum = 0.0
        self._val_loss_count = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        """Accumulate per-batch train loss (batches with no loss are skipped)."""
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        loss = self._extract_loss_from_outputs(outputs)
        if loss is None:
            return
        self._train_loss_sum += loss
        self._train_loss_count += 1

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None:
        """Accumulate per-batch validation loss, ignoring sanity-check batches."""
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        # Sanity-check losses would pollute the epoch mean shown to the user.
        if trainer.sanity_checking:
            return
        loss = self._extract_loss_from_outputs(outputs)
        if loss is None:
            return
        self._val_loss_sum += loss
        self._val_loss_count += 1

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Store mean train loss at train epoch end.

        Unweighted mean over accumulated batches; falls back to
        ``trainer.callback_metrics`` when nothing was accumulated.
        """
        super().on_train_epoch_end(trainer, pl_module)
        train_loss = (
            self._train_loss_sum / self._train_loss_count
            if self._train_loss_count > 0
            else None
        )
        if train_loss is None:
            metrics = dict(trainer.callback_metrics)
            train_loss = self._extract_train_loss_fallback(metrics)
        self._latest_train_loss = train_loss

    def on_validation_epoch_end(self, trainer, pl_module) -> None:
        """Store mean val loss at validation epoch end.

        Skipped during the sanity check; falls back to
        ``trainer.callback_metrics`` when nothing was accumulated.
        """
        super().on_validation_epoch_end(trainer, pl_module)
        if trainer.sanity_checking:
            return

        val_loss = (
            self._val_loss_sum / self._val_loss_count
            if self._val_loss_count > 0
            else None
        )
        metrics = dict(trainer.callback_metrics)
        if val_loss is None:
            val_loss = self._extract_val_loss_fallback(metrics)
        self._latest_val_loss = val_loss

    def get_metrics(self, trainer, pl_module) -> dict[str, Any]:
        """Append epoch-level losses to the Rich progress-bar metrics.

        Return type is ``dict[str, Any]``: the superclass dict already
        carries non-float entries (e.g. the ``v_num`` string).
        """
        metrics = super().get_metrics(trainer, pl_module)
        if self._latest_train_loss is not None:
            metrics["train_loss"] = round(float(self._latest_train_loss), 6)
        if self._latest_val_loss is not None:
            metrics["val_loss"] = round(float(self._latest_val_loss), 6)
        return metrics