"""
Custom Lightning Callbacks for TFT-ASRO training.
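CurriculumLossScheduler: Gradually shifts loss emphasis from calibration
to directional accuracy over an initial warm-up period of training.
WeeklyLossComponentLogger: Logs weekly loss component means after each
validation epoch and warns about suspicious component imbalances.
SWACallback: Stochastic Weight Averaging — averages model weights over the
last portion of training to find flatter optima and improve generalisation.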
References:
- Bengio et al. (2009) "Curriculum Learning" (ICML)
- Izmailov et al. (2018) "Averaging Weights Leads to Wider Optima and Better Generalization" (UAI)
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
try:
import lightning.pytorch as pl
except ImportError:
import pytorch_lightning as pl # type: ignore[no-redef]
class CurriculumLossScheduler(pl.Callback):
    """
    Gradually shift loss weights from calibration towards direction.

    At the start of every training epoch this callback overwrites
    ``lambda_quantile`` (and ``lambda_madl`` when present) on
    ``pl_module.loss``.

    Phase 1 (epochs ``0 .. warmup_epochs - 1``): both weights are linearly
        interpolated from their initial to their target values. With the
        defaults the quantile (calibration) weight starts high and falls,
        while the MADL (directional) weight starts low and rises. The model
        therefore establishes a sensible prediction scale before it is asked
        to learn direction.
    Phase 2 (epoch ``warmup_epochs`` onwards): both weights are held at
        their target values.

    Easing the directional terms in gradually avoids confronting the model
    with the full, potentially conflicting gradients of every objective at
    once.

    The callback does nothing if the module has no ``loss`` attribute, or
    if that loss has no ``lambda_quantile`` attribute.

    Args:
        warmup_epochs: Number of epochs over which the weights are ramped.
        initial_lambda_quantile: ``lambda_quantile`` at epoch 0.
        target_lambda_quantile: ``lambda_quantile`` from ``warmup_epochs`` on.
        initial_lambda_madl: ``lambda_madl`` at epoch 0.
        target_lambda_madl: ``lambda_madl`` from ``warmup_epochs`` on.
    """

    def __init__(
        self,
        warmup_epochs: int = 10,
        initial_lambda_quantile: float = 0.65,
        target_lambda_quantile: float = 0.35,
        initial_lambda_madl: float = 0.05,
        target_lambda_madl: float = 0.25,
    ):
        super().__init__()
        self.warmup_epochs = warmup_epochs
        self.initial_lq = initial_lambda_quantile
        self.target_lq = target_lambda_quantile
        self.initial_madl = initial_lambda_madl
        self.target_madl = target_lambda_madl

    def _weights_for_epoch(self, epoch: int) -> tuple[float, float]:
        """Return the scheduled ``(lambda_quantile, lambda_madl)`` for *epoch*."""
        if epoch >= self.warmup_epochs:
            return self.target_lq, self.target_madl
        # Here 0 <= epoch < warmup_epochs, so warmup_epochs >= 1 and the
        # division is safe. progress would reach 1.0 at epoch == warmup_epochs,
        # which the branch above already maps to the targets.
        progress = epoch / self.warmup_epochs
        lq = self.initial_lq + (self.target_lq - self.initial_lq) * progress
        lm = self.initial_madl + (self.target_madl - self.initial_madl) * progress
        return lq, lm

    def on_train_epoch_start(self, trainer, pl_module):
        """Apply the scheduled loss weights for the epoch that is starting."""
        epoch = trainer.current_epoch
        # getattr (not attribute access) so modules without a ``loss``
        # attribute are skipped instead of raising AttributeError.
        loss = getattr(pl_module, "loss", None)
        if not hasattr(loss, "lambda_quantile"):
            return

        lq, lm = self._weights_for_epoch(epoch)
        loss.lambda_quantile = lq
        if hasattr(loss, "lambda_madl"):
            loss.lambda_madl = lm

        # Log every 10th epoch, plus the epoch at which the targets are reached.
        # NOTE(review): w_dir is logged as 1 - lambda_quantile; this assumes the
        # loss weights its directional part that way. Confirm against the loss.
        if epoch % 10 == 0 or epoch == self.warmup_epochs:
            logger.info(
                "Curriculum epoch %d: lambda_quantile=%.3f (w_dir=%.3f) lambda_madl=%.3f",
                epoch, lq, 1.0 - lq, lm,
            )
class WeeklyLossComponentLogger(pl.Callback):
    """
    Log weekly loss component means at validation epoch boundaries.

    The callback relies on two optional methods of ``pl_module.loss``:

    * ``reset_component_accumulators()``, called when a validation epoch
      starts.
    * ``component_means()``, called when it ends and expected to return a
      dict of component means.

    If a method is missing, the corresponding step is skipped. Nothing is
    logged when the returned dict reports no batches (falsy ``n_batches``).

    Besides the info-level summary, two heuristic warnings are emitted:

    * The dispersion loss mean exceeds ``dispersion_ratio_threshold`` times
      the weekly quantile loss mean.
    * ``loss.lambda_directional`` is positive but the directional loss mean
      is below ``directional_fraction_threshold`` of the total loss mean.

    Args:
        dispersion_ratio_threshold: Dispersion / weekly-quantile loss ratio
            above which a warning is logged (default 3.0).
        directional_fraction_threshold: Fraction of the total loss below
            which the directional loss is reported as too small (default
            0.05, i.e. 5%).
    """

    def __init__(
        self,
        dispersion_ratio_threshold: float = 3.0,
        directional_fraction_threshold: float = 0.05,
    ):
        super().__init__()
        self.dispersion_ratio_threshold = dispersion_ratio_threshold
        self.directional_fraction_threshold = directional_fraction_threshold

    def on_validation_epoch_start(self, trainer, pl_module):
        """Clear the loss's per-component accumulators, if it has any."""
        loss = getattr(pl_module, "loss", None)
        if hasattr(loss, "reset_component_accumulators"):
            loss.reset_component_accumulators()

    def on_validation_epoch_end(self, trainer, pl_module):
        """Log the accumulated component means and warn on imbalances."""
        loss = getattr(pl_module, "loss", None)
        if not hasattr(loss, "component_means"):
            return
        stats = loss.component_means()
        if not stats.get("n_batches"):
            return

        epoch = getattr(trainer, "current_epoch", 0)
        logger.info(
            "Weekly loss components | epoch=%s weekly_q=%.6f t1_q=%.6f "
            "dispersion=%.6f magnitude=%.6f naive=%.6f directional=%.6f "
            "total=%.6f dominant=%s",
            epoch,
            stats["weekly_q_loss_mean"],
            stats["t1_q_loss_mean"],
            stats["dispersion_loss_mean"],
            stats.get("magnitude_loss_mean", 0.0),
            stats.get("naive_loss_mean", 0.0),
            stats["directional_loss_mean"],
            stats["total_loss_mean"],
            stats["dominant_component"],
        )

        # The 1e-12 floors keep the comparison meaningful when the reference
        # loss is (close to) zero.
        dispersion_limit = self.dispersion_ratio_threshold * max(
            stats["weekly_q_loss_mean"], 1e-12
        )
        if stats["dispersion_loss_mean"] > dispersion_limit:
            logger.warning(
                "Weekly dispersion loss is dominating weekly quantile loss; "
                "lambda_dispersion may need to be reduced."
            )

        lambda_directional = float(getattr(loss, "lambda_directional", 0.0))
        directional_is_tiny = stats["directional_loss_mean"] < (
            self.directional_fraction_threshold * max(stats["total_loss_mean"], 1e-12)
        )
        if lambda_directional > 0.0 and directional_is_tiny:
            # The percentage is passed as an argument on purpose: logging only
            # %-formats the message when args are given, so a literal "%%" in
            # an argument-less message would be printed verbatim.
            logger.warning(
                "Weekly directional loss is below %g%% of total loss; "
                "lambda_directional may need to increase.",
                self.directional_fraction_threshold * 100.0,
            )
class SWACallback(pl.Callback):
    """
    Stochastic Weight Averaging over the last portion of training.

    Averaging starts at epoch ``int(max_epochs * swa_start_pct)``. From then
    on, a snapshot of ``pl_module.state_dict()`` is folded into a running
    average at the end of every training epoch. When training ends, the
    averaged weights are loaded back into the module. This only happens if
    at least two epochs were averaged; with a single snapshot the average
    equals the current weights.

    Only floating-point tensors are averaged. Integer/boolean buffers and
    any non-tensor entries keep their most recent value, since an
    arithmetic mean is not meaningful for them. Floating-point buffers,
    such as normalisation running statistics, are averaged like parameters
    rather than recomputed. The averaged weights are only loaded into the
    module; this callback does not save them anywhere.

    Args:
        swa_start_pct: Fraction of ``max_epochs`` after which averaging
            starts (default 0.75, i.e. the last quarter of training).
    """

    # Horizon used when trainer.max_epochs is unset or non-positive
    # (Lightning uses -1 to mean "no epoch limit").
    _DEFAULT_MAX_EPOCHS = 100

    def __init__(self, swa_start_pct: float = 0.75):
        super().__init__()
        self.swa_start_pct = swa_start_pct
        self._swa_state: dict | None = None
        self._n_averaged: int = 0

    def _swa_start_epoch(self, trainer) -> int:
        """Return the first epoch whose weights enter the average."""
        max_epochs = trainer.max_epochs
        if not max_epochs or max_epochs <= 0:
            max_epochs = self._DEFAULT_MAX_EPOCHS
        return int(max_epochs * self.swa_start_pct)

    @staticmethod
    def _is_averageable(value) -> bool:
        """True for floating-point tensors (duck-typed via ``is_floating_point``)."""
        is_fp = getattr(value, "is_floating_point", None)
        return callable(is_fp) and bool(is_fp())

    def on_train_start(self, trainer, pl_module):
        """Discard any average left over from a previous ``fit`` call."""
        self._swa_state = None
        self._n_averaged = 0

    def on_train_epoch_end(self, trainer, pl_module):
        """Fold this epoch's weights into the running average once SWA has started."""
        if trainer.current_epoch < self._swa_start_epoch(trainer):
            return

        import copy

        state = pl_module.state_dict()
        if self._swa_state is None:
            # deepcopy so the snapshot does not share storage with live params.
            self._swa_state = copy.deepcopy(state)
            self._n_averaged = 1
            return

        self._n_averaged += 1
        n = self._n_averaged
        for key, averaged in self._swa_state.items():
            current = state[key]
            if self._is_averageable(averaged):
                # Incremental running mean: avg_n = avg_{n-1} + (x_n - avg_{n-1}) / n
                self._swa_state[key] = averaged + (current - averaged) / n
            else:
                # Counters, masks and extra state: keep the latest value.
                self._swa_state[key] = copy.deepcopy(current)

    def on_train_end(self, trainer, pl_module):
        """Load the averaged weights into the module and log the summary."""
        if self._swa_state is None or self._n_averaged <= 1:
            return
        pl_module.load_state_dict(self._swa_state)
        logger.info(
            "SWA: averaged %d checkpoints from epoch %d onwards",
            self._n_averaged,
            self._swa_start_epoch(trainer),
        )
|