Spaces:
Running
Running
| """ | |
| TFT-ASRO Model for Copper Futures Prediction. | |
| Wraps pytorch_forecasting's TemporalFusionTransformer with: | |
| - ASRO (Adaptive Sharpe Ratio Optimization) loss | |
| - 7-quantile probabilistic output | |
| - Variable Selection Network for dynamic feature weighting | |
| - Interpretable attention for temporal pattern analysis | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional, Sequence | |
| import torch | |
| import numpy as np | |
| from deep_learning.config import TFTASROConfig, get_tft_config | |
| from deep_learning.models.losses import AdaptiveSharpeRatioLoss, CombinedQuantileLoss | |
| logger = logging.getLogger(__name__) | |
# ---------------------------------------------------------------------------
# Module-level ASRO loss class (must be at module level for pickle / checkpoint)
# ---------------------------------------------------------------------------
try:
    from pytorch_forecasting.metrics import QuantileLoss as _PFQuantileLoss

    class ASROPFLoss(_PFQuantileLoss):
        """
        pytorch_forecasting >= 1.0 compatible ASRO loss.

        Inherits from ``QuantileLoss`` (a proper torchmetrics ``Metric``) so
        that ``TemporalFusionTransformer.from_dataset()`` accepts it.
        Defined at module level so Lightning checkpoints can pickle it.

        Total loss blends two objectives (see ``loss``):
            lambda_quantile * (pinball + lambda_vol * (vol_cal + amplitude))
            + (1 - lambda_quantile) * (negative Sharpe of a tanh trading signal)
        """

        def __init__(
            self,
            quantiles: list,
            lambda_vol: float = 0.3,
            lambda_quantile: float = 0.2,
            risk_free_rate: float = 0.0,
            sharpe_eps: float = 1e-6,
        ):
            # Parent stores the quantile levels and implements the pinball loss.
            super().__init__(quantiles=quantiles)
            self.lambda_vol = lambda_vol
            self.lambda_quantile = lambda_quantile
            self.rf = risk_free_rate
            # Numerical floor so std-divisions below never hit zero.
            self.sharpe_eps = sharpe_eps
            # Middle index — assumes an odd-length, ascending quantile list.
            self.median_idx = len(quantiles) // 2
            q = list(quantiles)
            # Locate Q10/Q90 for the spread term; fall back to the 2nd and
            # 2nd-to-last positions when those exact levels are absent.
            self._q10_idx = q.index(0.10) if 0.10 in q else 1
            self._q90_idx = q.index(0.90) if 0.90 in q else len(q) - 2

        def loss(self, y_pred: torch.Tensor, target) -> torch.Tensor:  # type: ignore[override]
            # target may arrive as (values, weights/lengths) from
            # pytorch_forecasting; only the values enter the Sharpe terms.
            if isinstance(target, (list, tuple)):
                y_actual = target[0]
            else:
                y_actual = target
            y_actual = y_actual.float()
            median_pred = y_pred[..., self.median_idx]
            # Mirrors losses.AdaptiveSharpeRatioLoss exactly.
            # scale=20 keeps gradients alive through the full return distribution;
            # previous scale=100 saturated above pred=0.015, killing amplitude learning.
            _TANH_SCALE = 20.0
            signal = torch.tanh(median_pred * _TANH_SCALE)
            # Strategy return: position (tanh signal in [-1, 1]) times realised return.
            strategy_returns = signal * y_actual.float() - self.rf
            # Negative Sharpe ratio: maximising Sharpe == minimising this term.
            sharpe_loss = -(strategy_returns.mean() / (strategy_returns.std() + self.sharpe_eps))
            # Volatility calibration: match Q90-Q10 spread to 2× actual σ
            pred_spread = (
                y_pred[..., self._q90_idx] - y_pred[..., self._q10_idx]
            ).mean()
            actual_std = y_actual.std() + self.sharpe_eps
            vol_loss = torch.abs(pred_spread - 2.0 * actual_std)
            # Median amplitude: penalise if median pred variance < actual variance
            median_std = median_pred.std() + self.sharpe_eps
            vr = median_std / actual_std  # variance ratio (predicted σ / actual σ)
            amplitude_loss = (
                torch.relu(1.0 - vr)  # under-variance: VR < 1 → strong penalty
                + 0.25 * torch.relu(vr - 1.5)  # over-variance: VR > 1.5 → gentle penalty
            )
            # Quantile (pinball) loss via parent — covers all 7 quantile bands
            q_loss = super().loss(y_pred, target)
            w_sharpe = 1.0 - self.lambda_quantile
            calibration = q_loss + self.lambda_vol * (vol_loss + amplitude_loss)
            return self.lambda_quantile * calibration + w_sharpe * sharpe_loss
except ImportError:
    # pytorch_forecasting missing: create_tft_model() falls back to plain QuantileLoss.
    ASROPFLoss = None  # type: ignore[assignment,misc]
def create_tft_model(
    training_dataset,
    cfg: Optional[TFTASROConfig] = None,
    use_asro: bool = True,
):
    """
    Build a TemporalFusionTransformer from a training dataset and config.

    Args:
        training_dataset: pytorch_forecasting.TimeSeriesDataSet
        cfg: TFT-ASRO configuration; defaults to ``get_tft_config()``.
        use_asro: if True (and the ASRO loss is importable), train with the
            ASRO objective; otherwise fall back to a standard QuantileLoss.

    Returns:
        TemporalFusionTransformer instance
    """
    from pytorch_forecasting import TemporalFusionTransformer
    from pytorch_forecasting.metrics import QuantileLoss

    active_cfg = cfg if cfg is not None else get_tft_config()
    q_levels = list(active_cfg.model.quantiles)

    # Pick the training objective: ASRO when requested and available.
    if use_asro and ASROPFLoss is not None:
        loss_fn = ASROPFLoss(
            quantiles=q_levels,
            lambda_vol=active_cfg.asro.lambda_vol,
            lambda_quantile=active_cfg.asro.lambda_quantile,
            risk_free_rate=active_cfg.asro.risk_free_rate,
        )
        logger.info(
            "Using ASRO loss | w_quantile=%.2f w_sharpe=%.2f lambda_vol=%.2f",
            active_cfg.asro.lambda_quantile,
            1.0 - active_cfg.asro.lambda_quantile,
            active_cfg.asro.lambda_vol,
        )
    else:
        loss_fn = QuantileLoss(quantiles=q_levels)
        logger.info("Using standard QuantileLoss with %d quantiles", len(q_levels))

    mcfg = active_cfg.model
    model = TemporalFusionTransformer.from_dataset(
        training_dataset,
        learning_rate=mcfg.learning_rate,
        hidden_size=mcfg.hidden_size,
        attention_head_size=mcfg.attention_head_size,
        dropout=mcfg.dropout,
        hidden_continuous_size=mcfg.hidden_continuous_size,
        output_size=len(q_levels),  # one output head per quantile
        loss=loss_fn,
        reduce_on_plateau_patience=mcfg.reduce_on_plateau_patience,
        log_interval=10,
        log_val_interval=1,
    )
    # Exclude unpicklable/heavy objects from the checkpointed hparams.
    model.save_hyperparameters(ignore=['loss', 'logging_metrics'])

    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    logger.info("TFT model created: %d total params, %d trainable", total_params, trainable_params)
    return model
def load_tft_model(
    checkpoint_path: str,
    map_location: str = "cpu",
):
    """Restore a trained TFT model (eval mode) from a Lightning checkpoint."""
    from pytorch_forecasting import TemporalFusionTransformer

    ckpt = Path(checkpoint_path)
    if not ckpt.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
    model = TemporalFusionTransformer.load_from_checkpoint(
        str(ckpt), map_location=map_location
    )
    # Inference-only: disable dropout etc.
    model.eval()
    logger.info("Loaded TFT model from %s", ckpt)
    return model
# ---------------------------------------------------------------------------
# Interpretation helpers
# ---------------------------------------------------------------------------
def get_variable_importance(model, val_dataloader=None) -> Dict[str, float]:
    """
    Extract learned variable importance from the TFT's Variable Selection Networks.

    Returns a dict mapping feature name -> normalised importance score
    (sorted descending), or {} on any failure.

    val_dataloader must be passed explicitly (model.val_dataloader() only works
    inside a Lightning Trainer context and raises an error otherwise).
    """
    if val_dataloader is None:
        return {}
    try:
        # BUG FIX: interpret_output() needs the *raw* network output (attention
        # and VSN weights). The previous call omitted mode="raw", so the
        # interpretation step always raised and this function returned {}.
        raw = model.predict(val_dataloader, mode="raw", return_x=True)
        # pf >= 1.0 returns a Prediction namedtuple; unwrap its .output field.
        output = getattr(raw, "output", raw)
        interpretation = model.interpret_output(output, reduction="sum")
        enc = interpretation.get("encoder_variables")
        if enc is None:
            return {}
        if isinstance(enc, dict):
            importance = {k: float(v) for k, v in enc.items()}
        else:
            # encoder_variables is a tensor of summed selection weights, one
            # entry per feature, aligned with model.encoder_variables.
            values = enc.detach().cpu().flatten().tolist()
            importance = dict(zip(model.encoder_variables, values))
        if not importance:
            return {}
        total = sum(importance.values())
        if total == 0:
            return importance
        return {
            k: v / total
            for k, v in sorted(importance.items(), key=lambda kv: -kv[1])
        }
    except Exception as exc:
        logger.warning("Could not extract variable importance: %s", exc)
        return {}
def get_attention_weights(model, dataloader) -> Optional[np.ndarray]:
    """
    Extract temporal self-attention weights for interpretability.

    Returns the raw attention tensor as a numpy array, or None if extraction
    fails. (Exact shape depends on the pytorch_forecasting version —
    presumably (n_samples, heads, decoder_len, encoder_len); verify against
    the installed release.)
    """
    try:
        raw = model.predict(dataloader, mode="raw", return_x=True)
        # BUG FIX: with return_x=True, predict() returns a namedtuple/tuple,
        # not a dict — the previous `out.get("attention")` always raised
        # AttributeError and this helper always returned None. Unwrap the
        # raw network output first, then look up the attention entry.
        output = getattr(raw, "output", raw)
        attn = output.get("attention") if hasattr(output, "get") else None
        if attn is not None:
            return attn.detach().cpu().numpy()
    except Exception as exc:
        logger.warning("Could not extract attention weights: %s", exc)
    return None
# ---------------------------------------------------------------------------
# Prediction formatting
# ---------------------------------------------------------------------------
def format_prediction(
    raw_prediction: torch.Tensor,
    quantiles: Sequence[float] = (0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98),
    baseline_price: float = 1.0,
) -> Dict[str, Any]:
    """
    Convert raw TFT quantile output to a structured prediction dict.

    Args:
        raw_prediction: tensor or array-like of shape (prediction_length, n_quantiles)
        quantiles: quantile levels, ascending; the median is taken at the middle index
        baseline_price: current price for return-to-price conversion

    Returns:
        Dict with per-day forecasts, confidence bands, and volatility estimate.
        Top-level fields use the T+1 (first-day) forecast; ``weekly_*`` fields
        use the final day of the horizon for backward compat. All numeric
        outputs are plain Python floats (JSON-safe).

    Raises:
        ValueError: if ``raw_prediction`` contains no timesteps.
    """
    import math as _math

    # Accept tensors, ndarrays, or nested lists. detach() guards against
    # grad-attached tensors, which .numpy() would reject.
    if isinstance(raw_prediction, torch.Tensor):
        pred = raw_prediction.detach().cpu().numpy()
    else:
        pred = np.asarray(raw_prediction)
    n_days = pred.shape[0]
    if n_days == 0:
        # Previously fell through to an opaque IndexError at daily_forecasts[0].
        raise ValueError("raw_prediction has no timesteps")
    median_idx = len(quantiles) // 2

    # Guard: log if baseline_price is invalid (NaN prices will be sanitised
    # to null by the API layer's _sanitize_floats, keeping the chart clean).
    if _math.isnan(baseline_price) or _math.isinf(baseline_price) or baseline_price <= 0:
        logger.warning(
            "format_prediction: invalid baseline_price=%s — price fields will be null",
            baseline_price,
        )

    # Hard clamp: prevents overconfident models (VR >> 1) from producing
    # absurd compound prices. Copper's actual daily σ ≈ 0.024; capping at
    # ~1.25σ keeps the 5-day compound under ≈16 %. The clamp is inactive
    # once the model is retrained with a healthy VR (0.5–1.5).
    _MAX_DAILY_RET = 0.03

    # T+1 quantile spreads (return-space distance from median).
    # Used as the base width for confidence bands; scaled by sqrt(d) for
    # later days so uncertainty grows realistically instead of compounding
    # tail quantiles exponentially (which would produce absurd bands).
    raw_med_0 = float(pred[0, median_idx])
    if len(quantiles) > 2:
        spread_q10 = float(np.clip(float(pred[0, 1]) - raw_med_0, -_MAX_DAILY_RET, 0))
        spread_q90 = float(np.clip(float(pred[0, -2]) - raw_med_0, 0, _MAX_DAILY_RET))
    else:
        spread_q10 = 0.0
        spread_q90 = 0.0
    spread_q02 = float(np.clip(float(pred[0, 0]) - raw_med_0, -_MAX_DAILY_RET * 1.5, 0))
    spread_q98 = float(np.clip(float(pred[0, -1]) - raw_med_0, 0, _MAX_DAILY_RET * 1.5))

    daily_forecasts = []
    cum_price_med = float(baseline_price)
    for d in range(n_days):
        med = float(np.clip(pred[d, median_idx], -_MAX_DAILY_RET, _MAX_DAILY_RET))
        cum_price_med *= (1 + med)
        cum_return = (cum_price_med / baseline_price) - 1.0
        scale = (d + 1) ** 0.5  # sqrt-of-time widening of the T+1 band
        daily_forecasts.append({
            "day": d + 1,
            "daily_return": med,
            "cumulative_return": cum_return,
            "price_median": cum_price_med,
            "price_q10": cum_price_med * (1 + spread_q10 * scale),
            "price_q90": cum_price_med * (1 + spread_q90 * scale),
            "price_q02": cum_price_med * (1 + spread_q02 * scale),
            "price_q98": cum_price_med * (1 + spread_q98 * scale),
        })

    # T+1 is the primary signal (most reliable, highest signal-to-noise).
    first = daily_forecasts[0]
    last = daily_forecasts[-1]
    vol_estimate = (first["price_q90"] - first["price_q10"]) / (2.0 * baseline_price)

    # Raw reported quantiles get a looser (2×) clamp than the price path.
    wide = _MAX_DAILY_RET * 2
    return {
        "predicted_return_median": first["daily_return"],
        "predicted_return_q10": float(np.clip(pred[0, 1], -wide, wide)) if len(quantiles) > 2 else first["daily_return"],
        "predicted_return_q90": float(np.clip(pred[0, -2], -wide, wide)) if len(quantiles) > 2 else first["daily_return"],
        "predicted_price_median": first["price_median"],
        "predicted_price_q10": first["price_q10"],
        "predicted_price_q90": first["price_q90"],
        "confidence_band_96": (first["price_q02"], first["price_q98"]),
        "volatility_estimate": vol_estimate,
        "quantiles": {f"q{q:.2f}": float(pred[0, i]) for i, q in enumerate(quantiles)},
        "weekly_return": last["cumulative_return"],
        "weekly_price": last["price_median"],
        "prediction_horizon_days": n_days,
        "daily_forecasts": daily_forecasts,
    }