Spaces:
Running
Running
| """ | |
| Custom Lightning Callbacks for TFT-ASRO training. | |
| CurriculumLossScheduler: Gradually shifts loss emphasis from calibration | |
| to directional accuracy as training progresses. | |
| SWACallback: Stochastic Weight Averaging — averages model weights over the | |
| last portion of training to find flatter optima and improve generalisation. | |
| References: | |
| - Bengio et al. (2009) "Curriculum Learning" (ICML) | |
| - Izmailov et al. (2018) "Averaging Weights Leads to Wider Optima" (UAI) | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| try: | |
| import lightning.pytorch as pl | |
| except ImportError: | |
| import pytorch_lightning as pl # type: ignore[no-redef] | |
class CurriculumLossScheduler(pl.Callback):
    """
    Anneal the loss-mixing weights over the first ``warmup_epochs`` epochs.

    At the start of every training epoch the callback writes two attributes
    on ``pl_module.loss``:

    * ``lambda_quantile`` moves linearly from ``initial_lambda_quantile`` to
      ``target_lambda_quantile``.
    * ``lambda_madl`` (only if the loss exposes it) moves linearly from
      ``initial_lambda_madl`` to ``target_lambda_madl``.

    While ``current_epoch < warmup_epochs`` the interpolation fraction is
    ``current_epoch / warmup_epochs``; from ``warmup_epochs`` onwards both
    weights are pinned at their targets. With the defaults this starts
    training with a calibration-heavy mix and ends with a larger
    directional share, the intent being that the model settles its
    prediction scale before directional terms dominate the gradient.

    If the loss has no ``lambda_quantile`` attribute the callback does
    nothing.

    Args:
        warmup_epochs: Number of epochs over which the weights are ramped.
        initial_lambda_quantile: ``lambda_quantile`` at epoch 0.
        target_lambda_quantile: ``lambda_quantile`` once warmup is over.
        initial_lambda_madl: ``lambda_madl`` at epoch 0.
        target_lambda_madl: ``lambda_madl`` once warmup is over.
    """

    def __init__(
        self,
        warmup_epochs: int = 10,
        initial_lambda_quantile: float = 0.65,
        target_lambda_quantile: float = 0.35,
        initial_lambda_madl: float = 0.05,
        target_lambda_madl: float = 0.25,
    ):
        super().__init__()
        self.warmup_epochs = warmup_epochs
        # Start / end points of the two linear schedules.
        self.initial_lq, self.target_lq = (
            initial_lambda_quantile,
            target_lambda_quantile,
        )
        self.initial_madl, self.target_madl = (
            initial_lambda_madl,
            target_lambda_madl,
        )

    def on_train_epoch_start(self, trainer, pl_module):
        """Set this epoch's ``lambda_quantile`` / ``lambda_madl`` on the loss."""
        current = trainer.current_epoch
        criterion = pl_module.loss

        # Losses without a tunable quantile weight are left untouched.
        if not hasattr(criterion, "lambda_quantile"):
            return

        if current < self.warmup_epochs:
            # Still ramping: interpolate between the initial and target values.
            frac = current / max(self.warmup_epochs, 1)
            quantile_weight = (
                self.initial_lq + (self.target_lq - self.initial_lq) * frac
            )
            madl_weight = (
                self.initial_madl + (self.target_madl - self.initial_madl) * frac
            )
        else:
            # Warmup finished: hold the targets for the rest of training.
            quantile_weight = self.target_lq
            madl_weight = self.target_madl

        criterion.lambda_quantile = quantile_weight
        if hasattr(criterion, "lambda_madl"):
            criterion.lambda_madl = madl_weight

        # Log every tenth epoch and once at the moment the ramp completes.
        if current % 10 == 0 or current == self.warmup_epochs:
            logger.info(
                "Curriculum epoch %d: lambda_quantile=%.3f (w_dir=%.3f) lambda_madl=%.3f",
                current,
                quantile_weight,
                1.0 - quantile_weight,
                madl_weight,
            )
class SWACallback(pl.Callback):
    """
    Stochastic Weight Averaging over the tail of training.

    From epoch ``int(max_epochs * swa_start_pct)`` onwards, the module's
    ``state_dict`` is folded into a running mean at the end of every training
    epoch. When training ends, the averaged state is loaded back into the
    module, provided at least two epochs were collected.

    Only floating-point (and complex) tensors are averaged. Other entries,
    such as integer counters, boolean masks or non-tensor extra state, have
    no meaningful mean, so the most recently observed value is kept instead
    of being silently promoted to float and averaged.

    Args:
        swa_start_pct: Fraction of ``trainer.max_epochs`` after which
            averaging begins (default ``0.75``, i.e. the last quarter).
    """

    # Epoch budget assumed when the trainer does not expose a positive
    # ``max_epochs`` (``None``, ``0`` or the ``-1`` "unlimited" sentinel).
    _FALLBACK_MAX_EPOCHS = 100

    def __init__(self, swa_start_pct: float = 0.75):
        super().__init__()
        self.swa_start_pct = swa_start_pct
        self._swa_state: dict | None = None  # running average of the state_dict
        self._n_averaged: int = 0  # number of epochs folded into the average

    @staticmethod
    def _is_averageable(value) -> bool:
        """Return True if *value* is a floating-point or complex tensor."""
        # Duck-typed so this module does not need to import torch itself.
        for probe in ("is_floating_point", "is_complex"):
            check = getattr(value, probe, None)
            if callable(check) and check():
                return True
        return False

    def _swa_start_epoch(self, trainer) -> int:
        """Return the first epoch index whose weights enter the average."""
        max_epochs = trainer.max_epochs
        # A plain ``or`` fallback would let ``-1`` through and start averaging
        # at epoch 0; treat every non-positive budget as "unknown" instead.
        if max_epochs is None or max_epochs <= 0:
            max_epochs = self._FALLBACK_MAX_EPOCHS
        return int(max_epochs * self.swa_start_pct)

    def on_train_epoch_end(self, trainer, pl_module):
        """Fold the current weights into the running average once SWA starts."""
        import copy

        if trainer.current_epoch < self._swa_start_epoch(trainer):
            return

        state = pl_module.state_dict()
        if self._swa_state is None:
            # First collected epoch: take a private copy, because the tensors
            # in ``state`` share storage with the live, still-training module.
            self._swa_state = copy.deepcopy(state)
            self._n_averaged = 1
            return

        self._n_averaged += 1
        count = self._n_averaged
        for key, average in self._swa_state.items():
            current = state[key]
            if self._is_averageable(average):
                # Incremental mean: new_avg = (old_avg * (n - 1) + x_n) / n
                self._swa_state[key] = (average * (count - 1) + current) / count
            else:
                # Not averageable: keep a snapshot of the latest value.
                self._swa_state[key] = copy.deepcopy(current)

    def on_train_end(self, trainer, pl_module):
        """Load the averaged weights into the module if any averaging occurred."""
        # A single collected epoch is just that epoch's weights: nothing to do.
        if self._swa_state is None or self._n_averaged <= 1:
            return

        pl_module.load_state_dict(self._swa_state)
        logger.info(
            "SWA: averaged %d checkpoints from epoch %d onwards",
            self._n_averaged,
            self._swa_start_epoch(trainer),
        )