ifieryarrows's picture
Sync from GitHub (tests passed)
18d4089 verified
"""
Custom Lightning Callbacks for TFT-ASRO training.
CurriculumLossScheduler: Gradually shifts loss emphasis from calibration
to directional accuracy as training progresses.
SWACallback: Stochastic Weight Averaging — averages model weights over the
last portion of training to find flatter optima and improve generalisation.
References:
- Bengio et al. (2009) "Curriculum Learning" (ICML)
- Izmailov et al. (2018) "Averaging Weights Leads to Wider Optima" (UAI)
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
try:
import lightning.pytorch as pl
except ImportError:
import pytorch_lightning as pl # type: ignore[no-redef]
class CurriculumLossScheduler(pl.Callback):
    """
    Shift loss emphasis from calibration to direction over the first epochs.

    At the start of every training epoch the callback overwrites
    ``lambda_quantile`` (and ``lambda_madl``, when the loss has it) on
    ``pl_module.loss``:

    * Epochs ``0 .. warmup_epochs - 1``: both weights are linearly
      interpolated from their initial towards their target values, so
      epoch 0 uses exactly the initial values (high quantile weight, low
      directional weight).
    * Epochs ``>= warmup_epochs``: both weights are held at their targets.

    The intent is to let the model establish a correct prediction scale
    before the directional objective gains weight, instead of facing
    conflicting calibration and direction gradients from the first step.

    Modules without a ``loss`` attribute, or whose loss has no
    ``lambda_quantile`` attribute, are left untouched.

    Args:
        warmup_epochs: Number of epochs over which the weights are ramped.
            Values ``<= 0`` apply the target weights from the first epoch.
        initial_lambda_quantile: Quantile-loss weight used at epoch 0.
        target_lambda_quantile: Quantile-loss weight once the ramp is done.
        initial_lambda_madl: MADL weight used at epoch 0.
        target_lambda_madl: MADL weight once the ramp is done.
    """

    def __init__(
        self,
        warmup_epochs: int = 10,
        initial_lambda_quantile: float = 0.65,
        target_lambda_quantile: float = 0.35,
        initial_lambda_madl: float = 0.05,
        target_lambda_madl: float = 0.25,
    ):
        super().__init__()
        self.warmup_epochs = warmup_epochs
        self.initial_lq = initial_lambda_quantile
        self.target_lq = target_lambda_quantile
        self.initial_madl = initial_lambda_madl
        self.target_madl = target_lambda_madl

    def _weights_for_epoch(self, epoch: int) -> tuple[float, float]:
        """Return ``(lambda_quantile, lambda_madl)`` scheduled for *epoch*."""
        if epoch >= self.warmup_epochs:
            return self.target_lq, self.target_madl

        # Fraction of the ramp completed, in [0, 1). The max() keeps the
        # division safe for non-positive warmup lengths combined with
        # negative epoch values.
        progress = epoch / max(self.warmup_epochs, 1)
        lq = self.initial_lq + (self.target_lq - self.initial_lq) * progress
        lm = self.initial_madl + (self.target_madl - self.initial_madl) * progress
        return lq, lm

    def on_train_epoch_start(self, trainer, pl_module):
        """Write this epoch's scheduled weights onto ``pl_module.loss``."""
        # getattr: a module without any ``loss`` attribute is skipped in the
        # same way as a loss that does not expose ``lambda_quantile``.
        loss = getattr(pl_module, "loss", None)
        if loss is None or not hasattr(loss, "lambda_quantile"):
            return

        epoch = trainer.current_epoch
        lq, lm = self._weights_for_epoch(epoch)

        loss.lambda_quantile = lq
        if hasattr(loss, "lambda_madl"):
            loss.lambda_madl = lm

        # Log every 10th epoch plus the first epoch that uses the targets.
        if epoch % 10 == 0 or epoch == self.warmup_epochs:
            logger.info(
                "Curriculum epoch %d: lambda_quantile=%.3f (w_dir=%.3f) lambda_madl=%.3f",
                epoch, lq, 1.0 - lq, lm,
            )
class SWACallback(pl.Callback):
    """
    Stochastic Weight Averaging over the last part of training.

    From the epoch ``int(max_epochs * swa_start_pct)`` onwards, the module's
    ``state_dict`` is folded into a running mean at the end of every training
    epoch. When training ends, the averaged state is loaded back into the
    module, provided at least two epochs were collected.

    Only floating-point (or complex) tensors are averaged. Integer/bool
    buffers and non-tensor entries cannot be meaningfully averaged, so the
    most recent value is kept for them instead.

    NOTE(review): normalisation running statistics are averaged like any
    other float buffer and are not recomputed for the averaged weights —
    confirm this is acceptable for the model in use.

    Args:
        swa_start_pct: Fraction of ``trainer.max_epochs`` after which
            averaging starts (``0.75`` averages the last quarter).
    """

    # Epoch budget assumed when the trainer reports no usable max_epochs.
    _FALLBACK_MAX_EPOCHS = 100

    def __init__(self, swa_start_pct: float = 0.75):
        super().__init__()
        self.swa_start_pct = swa_start_pct
        self._swa_state: dict | None = None
        self._n_averaged: int = 0

    def _swa_start_epoch(self, trainer) -> int:
        """Return the first epoch index whose weights are averaged."""
        max_epochs = trainer.max_epochs
        # None/0 mean "unset"; negative values (e.g. -1 for unlimited
        # training) would otherwise yield a start epoch of 0 and average
        # the whole run.
        if max_epochs is None or max_epochs <= 0:
            max_epochs = self._FALLBACK_MAX_EPOCHS
        return int(max_epochs * self.swa_start_pct)

    @staticmethod
    def _is_averageable(value) -> bool:
        """Return True for tensors on which a running mean is well defined."""
        import torch  # local import, matching the file's lazy-import style

        return torch.is_tensor(value) and (
            value.is_floating_point() or value.is_complex()
        )

    @staticmethod
    def _snapshot(value):
        """Return a copy of *value* that no longer aliases the live model."""
        import copy

        import torch

        if torch.is_tensor(value):
            return value.detach().clone()
        return copy.deepcopy(value)

    def on_train_epoch_end(self, trainer, pl_module):
        """Fold the current weights into the running average once SWA is active."""
        if trainer.current_epoch < self._swa_start_epoch(trainer):
            return

        state = pl_module.state_dict()

        if self._swa_state is None:
            # First collected epoch: the average is just a copy of the state.
            self._swa_state = {
                key: self._snapshot(value) for key, value in state.items()
            }
            self._n_averaged = 1
            return

        self._n_averaged += 1
        for key, current in state.items():
            running = self._swa_state.get(key)
            if self._is_averageable(running) and self._is_averageable(current):
                current = current.detach().to(
                    device=running.device, dtype=running.dtype
                )
                # Incremental mean, in place: avg += (x - avg) / n.
                running.add_((current - running) / self._n_averaged)
            else:
                # Counters, masks and extra state: keep the latest value.
                self._swa_state[key] = self._snapshot(current)

    def on_train_end(self, trainer, pl_module):
        """Load the averaged weights into the module if any averaging happened."""
        # A single collected epoch is not an average; leave the model as-is.
        if self._swa_state is not None and self._n_averaged > 1:
            pl_module.load_state_dict(self._swa_state)
            logger.info(
                "SWA: averaged %d checkpoints from epoch %d onwards",
                self._n_averaged,
                self._swa_start_epoch(trainer),
            )