File size: 6,314 Bytes
18d4089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e411cee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00044fd
 
e411cee
 
 
 
00044fd
 
e411cee
 
 
 
 
 
 
 
 
00044fd
 
 
 
 
e411cee
 
 
 
 
 
18d4089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""
Custom Lightning Callbacks for TFT-ASRO training.

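CurriculumLossScheduler: Gradually shifts loss emphasis from calibration
to directional accuracy as training progresses.

WeeklyLossComponentLogger: Logs the mean weekly loss components at the end
of each validation epoch and warns when they look imbalanced.

SWACallback (Stochastic Weight Averaging): Averages model weights over the
last portion of training to find flatter optima and improve generalisation.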

References:
    - Bengio et al. (2009) "Curriculum Learning" (ICML)
    - Izmailov et al. (2018) "Averaging Weights Leads to Wider Optima" (UAI)
"""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

try:
    import lightning.pytorch as pl
except ImportError:
    import pytorch_lightning as pl  # type: ignore[no-redef]


class CurriculumLossScheduler(pl.Callback):
    """
    Linearly shift the loss weights from calibration towards direction.

    At the start of every training epoch the callback overwrites
    ``lambda_quantile`` (and ``lambda_madl``, when the loss has one) on
    ``pl_module.loss``:

    Warm-up (epochs ``0 .. warmup_epochs - 1``): both weights are linearly
        interpolated from their initial values towards their targets, so
        training begins with a high quantile weight and a low MADL weight
        and moves steadily towards the target mix.

    After warm-up (epoch ``>= warmup_epochs``): both weights are held at
        their target values.

    The intent is to let the model establish a sensible prediction scale
    before the directional terms carry their full weight, rather than
    facing all objectives at full strength from the first epoch.

    Modules without a ``loss`` attribute, or whose loss has no
    ``lambda_quantile`` attribute, are left untouched.

    Args:
        warmup_epochs: Number of epochs over which the weights are ramped.
        initial_lambda_quantile: Quantile weight applied at epoch 0.
        target_lambda_quantile: Quantile weight applied once warm-up ends.
        initial_lambda_madl: MADL weight applied at epoch 0.
        target_lambda_madl: MADL weight applied once warm-up ends.
    """

    def __init__(
        self,
        warmup_epochs: int = 10,
        initial_lambda_quantile: float = 0.65,
        target_lambda_quantile: float = 0.35,
        initial_lambda_madl: float = 0.05,
        target_lambda_madl: float = 0.25,
    ):
        super().__init__()
        self.warmup_epochs = warmup_epochs
        self.initial_lq = initial_lambda_quantile
        self.target_lq = target_lambda_quantile
        self.initial_madl = initial_lambda_madl
        self.target_madl = target_lambda_madl

    def _weights_for_epoch(self, epoch: int) -> tuple[float, float]:
        """Return ``(lambda_quantile, lambda_madl)`` scheduled for *epoch*."""
        if epoch >= self.warmup_epochs:
            return self.target_lq, self.target_madl

        # Fraction of the warm-up completed; max() guards the division even
        # though this branch is only reachable when warmup_epochs > epoch.
        progress = epoch / max(self.warmup_epochs, 1)
        lq = self.initial_lq + (self.target_lq - self.initial_lq) * progress
        lm = self.initial_madl + (self.target_madl - self.initial_madl) * progress
        return lq, lm

    def on_train_epoch_start(self, trainer, pl_module):
        """Write this epoch's scheduled weights onto ``pl_module.loss``."""
        # getattr with a default: a module that exposes no ``loss`` is skipped
        # instead of raising AttributeError, matching the best-effort handling
        # of a loss that lacks ``lambda_quantile``.
        loss = getattr(pl_module, "loss", None)
        if not hasattr(loss, "lambda_quantile"):
            return

        epoch = trainer.current_epoch
        lq, lm = self._weights_for_epoch(epoch)

        loss.lambda_quantile = lq
        if hasattr(loss, "lambda_madl"):
            loss.lambda_madl = lm

        # Log sparingly: every 10th epoch plus the epoch where warm-up ends.
        if epoch % 10 == 0 or epoch == self.warmup_epochs:
            # w_dir is reported as the complement of lambda_quantile;
            # NOTE(review): presumably the loss derives its directional
            # weight the same way -- confirm against the loss class.
            logger.info(
                "Curriculum epoch %d: lambda_quantile=%.3f (w_dir=%.3f) lambda_madl=%.3f",
                epoch, lq, 1.0 - lq, lm,
            )


class WeeklyLossComponentLogger(pl.Callback):
    """Log weekly loss component scales at validation epoch boundaries.

    The callback is a no-op unless ``pl_module.loss`` provides the optional
    hooks it looks for: ``reset_component_accumulators()`` is called when a
    validation epoch starts, and ``component_means()`` is read when it ends.
    The returned mapping is logged, and warnings are emitted when the
    dispersion term dwarfs the weekly quantile term or when an enabled
    directional term contributes only a tiny share of the total loss.
    """

    # Warn when dispersion loss exceeds this multiple of the weekly quantile loss.
    _DISPERSION_DOMINANCE_RATIO = 3.0
    # Warn when directional loss is below this fraction of the total loss.
    _DIRECTIONAL_FLOOR_FRACTION = 0.05
    # Lower bound applied to the reference loss so the thresholds stay positive.
    _REFERENCE_FLOOR = 1e-12

    def on_validation_epoch_start(self, trainer, pl_module):
        """Reset the loss's component accumulators, if it has any."""
        loss = getattr(pl_module, "loss", None)
        if hasattr(loss, "reset_component_accumulators"):
            loss.reset_component_accumulators()

    def on_validation_epoch_end(self, trainer, pl_module):
        """Log the accumulated component means and flag imbalances."""
        loss = getattr(pl_module, "loss", None)
        if not hasattr(loss, "component_means"):
            return

        stats = loss.component_means()
        # Nothing was accumulated this epoch (missing or zero batch count).
        if not stats.get("n_batches"):
            return

        epoch = getattr(trainer, "current_epoch", 0)
        logger.info(
            "Weekly loss components | epoch=%s weekly_q=%.6f t1_q=%.6f "
            "dispersion=%.6f magnitude=%.6f naive=%.6f directional=%.6f "
            "total=%.6f dominant=%s",
            epoch,
            stats["weekly_q_loss_mean"],
            stats["t1_q_loss_mean"],
            stats["dispersion_loss_mean"],
            # magnitude / naive are treated as optional components.
            stats.get("magnitude_loss_mean", 0.0),
            stats.get("naive_loss_mean", 0.0),
            stats["directional_loss_mean"],
            stats["total_loss_mean"],
            stats["dominant_component"],
        )

        dispersion_limit = self._DISPERSION_DOMINANCE_RATIO * max(
            stats["weekly_q_loss_mean"], self._REFERENCE_FLOOR
        )
        if stats["dispersion_loss_mean"] > dispersion_limit:
            logger.warning(
                "Weekly dispersion loss is dominating weekly quantile loss; "
                "lambda_dispersion may need to be reduced."
            )

        lambda_directional = float(getattr(loss, "lambda_directional", 0.0))
        directional_floor = self._DIRECTIONAL_FLOOR_FRACTION * max(
            stats["total_loss_mean"], self._REFERENCE_FLOOR
        )
        directional_is_tiny = stats["directional_loss_mean"] < directional_floor
        # Only meaningful when the directional term is actually switched on.
        if lambda_directional > 0.0 and directional_is_tiny:
            # The percentage is passed as a logging argument: logging applies
            # %-formatting only when arguments are supplied, so a bare "%%"
            # in an argument-less call would be printed literally.
            logger.warning(
                "Weekly directional loss is below %.0f%% of total loss; "
                "lambda_directional may need to increase.",
                100.0 * self._DIRECTIONAL_FLOOR_FRACTION,
            )


class SWACallback(pl.Callback):
    """
    Stochastic Weight Averaging over the tail of training.

    From the epoch at ``swa_start_pct`` of ``trainer.max_epochs`` onwards,
    the module's ``state_dict()`` is folded into a running mean at the end
    of every training epoch.  When training ends, the averaged state is
    loaded back into the module, provided at least two snapshots were
    collected.  Averaging weights this way is intended to land the model in
    a flatter region of the loss landscape with better generalisation.

    Only floating-point (or complex) tensors are averaged.  Any other entry
    (integer buffers such as step counters, non-tensor extra state) simply
    tracks the most recent value, because averaging would silently turn
    integer tensors into floats.

    Args:
        swa_start_pct: Fraction of ``max_epochs`` after which snapshots are
            collected (``0.75`` averages over the last quarter of training).
    """

    # Epoch budget assumed when the trainer has no finite ``max_epochs``.
    _FALLBACK_MAX_EPOCHS = 100

    def __init__(self, swa_start_pct: float = 0.75):
        super().__init__()
        self.swa_start_pct = swa_start_pct
        self._swa_state: dict | None = None
        self._n_averaged: int = 0

    def _swa_start_epoch(self, trainer) -> int:
        """Return the first epoch whose weights take part in the average."""
        max_epochs = trainer.max_epochs
        # None / 0 mean "unset"; Lightning uses -1 for unbounded training.
        # A plain ``max_epochs or 100`` would let -1 through and start
        # averaging at epoch 0.
        if max_epochs is None or max_epochs <= 0:
            max_epochs = self._FALLBACK_MAX_EPOCHS
        return int(max_epochs * self.swa_start_pct)

    @staticmethod
    def _is_averageable(value) -> bool:
        """Return True when *value* is a floating-point or complex tensor."""
        for probe in ("is_floating_point", "is_complex"):
            check = getattr(value, probe, None)
            if callable(check) and check():
                return True
        return False

    def on_train_epoch_end(self, trainer, pl_module):
        """Fold the current weights into the running average."""
        import copy

        if trainer.current_epoch < self._swa_start_epoch(trainer):
            return

        state = pl_module.state_dict()
        if self._swa_state is None:
            # First snapshot: deep-copy so later optimiser steps cannot
            # mutate the stored tensors.
            self._swa_state = copy.deepcopy(state)
            self._n_averaged = 1
            return

        self._n_averaged += 1
        count = self._n_averaged
        for key in list(self._swa_state):
            current = state[key]
            if self._is_averageable(current):
                # Incremental mean: new = (old * (n - 1) + x) / n
                self._swa_state[key] = (
                    self._swa_state[key] * (count - 1) + current
                ) / count
            else:
                # Keep dtype and meaning intact for non-float entries.
                self._swa_state[key] = copy.deepcopy(current)

    def on_train_end(self, trainer, pl_module):
        """Load the averaged weights into the module and report it."""
        # A single snapshot is just the final weights; nothing to load.
        if self._swa_state is not None and self._n_averaged > 1:
            pl_module.load_state_dict(self._swa_state)
            logger.info(
                "SWA: averaged %d checkpoints from epoch %d onwards",
                self._n_averaged,
                self._swa_start_epoch(trainer),
            )