File size: 12,818 Bytes
e57e9d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f6d359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f4b46f
 
 
 
 
 
 
1f6d359
 
 
06d9b9c
52f45ce
 
 
06d9b9c
 
e5838c5
1f6d359
e5838c5
1f6d359
 
 
e5838c5
 
1f6d359
52f45ce
 
a1bedd7
 
 
 
 
52f45ce
06d9b9c
1f6d359
 
06d9b9c
52f45ce
06d9b9c
1f6d359
 
 
5f4b46f
 
e57e9d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f6d359
 
 
 
 
 
 
 
06d9b9c
1f6d359
06d9b9c
 
1f6d359
e57e9d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0b9a02
 
e57e9d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dff0b7c
e57e9d1
 
 
 
dff0b7c
 
e57e9d1
dff0b7c
 
e57e9d1
 
dff0b7c
e57e9d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a76a49
 
e57e9d1
b09a163
 
e57e9d1
9a76a49
e57e9d1
 
f9561ea
 
b09a163
 
f9561ea
b09a163
 
 
622298f
 
 
 
 
 
 
 
 
 
 
225d4a2
 
 
 
 
fec6c24
9a76a49
af722fe
 
9a76a49
622298f
af722fe
 
 
fec6c24
 
9a76a49
 
af722fe
 
 
fec6c24
 
 
 
9a76a49
 
c081d45
 
9a76a49
af722fe
e57e9d1
 
af722fe
225d4a2
 
c081d45
 
 
 
e57e9d1
c081d45
af722fe
c081d45
9a76a49
 
e57e9d1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
"""
TFT-ASRO Model for Copper Futures Prediction.

Wraps pytorch_forecasting's TemporalFusionTransformer with:
- ASRO (Adaptive Sharpe Ratio Optimization) loss
- 7-quantile probabilistic output
- Variable Selection Network for dynamic feature weighting
- Interpretable attention for temporal pattern analysis
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, Dict, Optional, Sequence

import torch
import numpy as np

from deep_learning.config import TFTASROConfig, get_tft_config
from deep_learning.models.losses import AdaptiveSharpeRatioLoss, CombinedQuantileLoss

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Module-level ASRO loss class (must be at module level for pickle / checkpoint)
# ---------------------------------------------------------------------------

try:
    from pytorch_forecasting.metrics import QuantileLoss as _PFQuantileLoss

    class ASROPFLoss(_PFQuantileLoss):
        """
        pytorch_forecasting >= 1.0 compatible ASRO loss.

        Inherits from ``QuantileLoss`` (a proper torchmetrics ``Metric``) so
        that ``TemporalFusionTransformer.from_dataset()`` accepts it.
        Defined at module level so Lightning checkpoints can pickle it.

        The total loss blends four terms (see :meth:`loss`):

        * a negative Sharpe ratio of tanh-signal strategy returns,
        * the parent pinball (quantile) loss over all quantile bands,
        * a volatility-calibration term matching the Q90-Q10 spread to
          twice the realised std of the target, and
        * an amplitude term penalising a median-prediction variance ratio
          outside roughly [1.0, 1.5].
        """

        def __init__(
            self,
            quantiles: list,
            lambda_vol: float = 0.3,
            lambda_quantile: float = 0.2,
            risk_free_rate: float = 0.0,
            sharpe_eps: float = 1e-6,
        ):
            """
            Args:
                quantiles: ascending quantile levels the model outputs.
                lambda_vol: weight of the vol + amplitude calibration terms
                    inside the calibration bundle.
                lambda_quantile: weight of the calibration bundle; the Sharpe
                    term gets ``1 - lambda_quantile``.
                risk_free_rate: subtracted from strategy returns.
                sharpe_eps: numeric floor to avoid division by zero.
            """
            super().__init__(quantiles=quantiles)
            self.lambda_vol = lambda_vol
            self.lambda_quantile = lambda_quantile
            self.rf = risk_free_rate
            self.sharpe_eps = sharpe_eps
            # Middle column is the median (assumes an odd quantile count).
            self.median_idx = len(quantiles) // 2
            q = list(quantiles)
            # Fall back to the 2nd / 2nd-to-last columns when 0.10 / 0.90
            # are not among the configured quantiles.
            self._q10_idx = q.index(0.10) if 0.10 in q else 1
            self._q90_idx = q.index(0.90) if 0.90 in q else len(q) - 2

        def loss(self, y_pred: torch.Tensor, target) -> torch.Tensor:  # type: ignore[override]
            # pytorch_forecasting may pass target as (values, weights);
            # only the values are used here.
            if isinstance(target, (list, tuple)):
                y_actual = target[0]
            else:
                y_actual = target

            y_actual = y_actual.float()
            median_pred = y_pred[..., self.median_idx]

            # Mirrors losses.AdaptiveSharpeRatioLoss exactly.
            # scale=20 keeps gradients alive through the full return distribution;
            # previous scale=100 saturated above pred=0.015, killing amplitude learning.
            _TANH_SCALE = 20.0
            signal = torch.tanh(median_pred * _TANH_SCALE)
            strategy_returns = signal * y_actual.float() - self.rf
            # Negative Sharpe ratio: minimising this maximises mean/std.
            sharpe_loss = -(strategy_returns.mean() / (strategy_returns.std() + self.sharpe_eps))

            # Volatility calibration: match the mean Q90-Q10 spread to twice
            # the realised std of the target.
            pred_spread = (
                y_pred[..., self._q90_idx] - y_pred[..., self._q10_idx]
            ).mean()
            actual_std = y_actual.std() + self.sharpe_eps
            vol_loss = torch.abs(pred_spread - 2.0 * actual_std)

            # Median amplitude: penalise if median prediction variance is
            # below actual variance (variance ratio VR < 1).
            median_std = median_pred.std() + self.sharpe_eps
            vr = median_std / actual_std
            amplitude_loss = (
                torch.relu(1.0 - vr)              # under-variance: VR < 1 -> strong penalty
                + 0.25 * torch.relu(vr - 1.5)     # over-variance:  VR > 1.5 -> gentle penalty
            )

            # Quantile (pinball) loss via parent -- covers all quantile bands.
            q_loss = super().loss(y_pred, target)

            # Blend: calibration bundle weighted by lambda_quantile, Sharpe
            # term weighted by the remainder.
            w_sharpe = 1.0 - self.lambda_quantile
            calibration = q_loss + self.lambda_vol * (vol_loss + amplitude_loss)
            return self.lambda_quantile * calibration + w_sharpe * sharpe_loss

except ImportError:
    # pytorch_forecasting is unavailable; create_tft_model() checks for None
    # and falls back to the plain QuantileLoss.
    ASROPFLoss = None  # type: ignore[assignment,misc]


def create_tft_model(
    training_dataset,
    cfg: Optional[TFTASROConfig] = None,
    use_asro: bool = True,
):
    """
    Build a TemporalFusionTransformer from a training dataset and config.

    Args:
        training_dataset: pytorch_forecasting.TimeSeriesDataSet
        cfg: TFT-ASRO configuration; defaults to ``get_tft_config()``.
        use_asro: if True, use ASRO loss; otherwise standard QuantileLoss.

    Returns:
        TemporalFusionTransformer instance
    """
    from pytorch_forecasting import TemporalFusionTransformer
    from pytorch_forecasting.metrics import QuantileLoss

    if cfg is None:
        cfg = get_tft_config()

    quantile_levels = list(cfg.model.quantiles)

    # ASRO requires the module-level ASROPFLoss class, which is only defined
    # when pytorch_forecasting imported successfully at module load.
    if use_asro and ASROPFLoss is not None:
        loss_fn = ASROPFLoss(
            quantiles=quantile_levels,
            lambda_vol=cfg.asro.lambda_vol,
            lambda_quantile=cfg.asro.lambda_quantile,
            risk_free_rate=cfg.asro.risk_free_rate,
        )
        logger.info(
            "Using ASRO loss | w_quantile=%.2f w_sharpe=%.2f lambda_vol=%.2f",
            cfg.asro.lambda_quantile,
            1.0 - cfg.asro.lambda_quantile,
            cfg.asro.lambda_vol,
        )
    else:
        loss_fn = QuantileLoss(quantiles=quantile_levels)
        logger.info("Using standard QuantileLoss with %d quantiles", len(quantile_levels))

    model = TemporalFusionTransformer.from_dataset(
        training_dataset,
        learning_rate=cfg.model.learning_rate,
        hidden_size=cfg.model.hidden_size,
        attention_head_size=cfg.model.attention_head_size,
        dropout=cfg.model.dropout,
        hidden_continuous_size=cfg.model.hidden_continuous_size,
        output_size=len(quantile_levels),
        loss=loss_fn,
        reduce_on_plateau_patience=cfg.model.reduce_on_plateau_patience,
        log_interval=10,
        log_val_interval=1,
    )

    # Loss and metrics objects are not reliably picklable; keep them out of
    # the checkpointed hyperparameters.
    model.save_hyperparameters(ignore=['loss', 'logging_metrics'])

    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    logger.info("TFT model created: %d total params, %d trainable", total_params, trainable_params)

    return model


def load_tft_model(
    checkpoint_path: str,
    map_location: str = "cpu",
):
    """Load a trained TFT model from a Lightning checkpoint.

    Args:
        checkpoint_path: path to a ``.ckpt`` file written by Lightning.
        map_location: device to map tensors onto (default ``"cpu"``).

    Returns:
        TemporalFusionTransformer switched to eval mode.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    # Validate the path before importing pytorch_forecasting: the import is
    # heavy, and a missing checkpoint should fail fast with a clear error
    # regardless of the dependency's state.
    path = Path(checkpoint_path)
    if not path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {path}")

    from pytorch_forecasting import TemporalFusionTransformer

    model = TemporalFusionTransformer.load_from_checkpoint(str(path), map_location=map_location)
    model.eval()  # inference mode: disables dropout / batch-norm updates
    logger.info("Loaded TFT model from %s", path)
    return model


# ---------------------------------------------------------------------------
# Interpretation helpers
# ---------------------------------------------------------------------------

def get_variable_importance(model, val_dataloader=None) -> Dict[str, float]:
    """
    Read learned feature importances out of the TFT's Variable Selection Networks.

    Returns a dict of feature name -> normalised importance, sorted from most
    to least important. ``val_dataloader`` must be passed explicitly:
    ``model.val_dataloader()`` only works inside a Lightning Trainer context
    and raises an error otherwise, so there is no implicit fallback.
    """
    if val_dataloader is None:
        return {}
    try:
        predictions = model.predict(val_dataloader, return_x=True)
        interpretation = model.interpret_output(predictions, reduction="sum")
        encoder_importance = interpretation.get("encoder_variables", {})
        if not encoder_importance:
            return {}

        total = sum(encoder_importance.values())
        if total == 0:
            # Nothing to normalise by; return the raw scores as-is.
            return encoder_importance

        ranked = sorted(encoder_importance.items(), key=lambda kv: -kv[1])
        return {name: score / total for name, score in ranked}
    except Exception as exc:
        # Best-effort helper: interpretation failures must never crash callers.
        logger.warning("Could not extract variable importance: %s", exc)
        return {}


def get_attention_weights(model, dataloader) -> Optional[np.ndarray]:
    """
    Pull temporal self-attention weights out of a raw-mode prediction.

    Returns an array of shape (n_samples, n_heads, encoder_length, encoder_length),
    or None when the weights cannot be extracted.
    """
    try:
        raw_output = model.predict(dataloader, return_x=True, mode="raw")
        attention = raw_output.get("attention")
        if attention is None:
            return None
        return attention.cpu().numpy()
    except Exception as exc:
        # Best-effort helper: interpretation failures must never crash callers.
        logger.warning("Could not extract attention weights: %s", exc)
        return None


# ---------------------------------------------------------------------------
# Prediction formatting
# ---------------------------------------------------------------------------

def format_prediction(
    raw_prediction: torch.Tensor,
    quantiles: Sequence[float] = (0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98),
    baseline_price: float = 1.0,
) -> Dict[str, Any]:
    """
    Convert raw TFT quantile output to a structured prediction dict.

    Args:
        raw_prediction: tensor/array of shape (prediction_length, n_quantiles)
        quantiles: quantile levels matching the model's output columns
        baseline_price: current price for return-to-price conversion

    Returns:
        Dict with per-day forecasts, confidence bands, and volatility estimate.
        Top-level ``predicted_*`` fields describe T+1 (the most reliable day);
        ``weekly_*`` fields use the final day of the horizon.
    """
    import math as _math

    pred = raw_prediction.cpu().numpy() if isinstance(raw_prediction, torch.Tensor) else raw_prediction
    n_days = pred.shape[0]
    median_idx = len(quantiles) // 2

    # Guard: an invalid baseline cannot yield meaningful prices.  Previously a
    # baseline of exactly 0 raised ZeroDivisionError at the cumulative-return
    # division below; forcing the baseline to NaN instead makes every price
    # field propagate NaN, which the API layer's _sanitize_floats turns into
    # null, keeping the chart clean as the warning promises.
    if _math.isnan(baseline_price) or _math.isinf(baseline_price) or baseline_price <= 0:
        logger.warning(
            "format_prediction: invalid baseline_price=%s β€” price fields will be null",
            baseline_price,
        )
        baseline_price = float("nan")

    # Hard clamp: prevents overconfident models (VR >> 1) from producing
    # absurd compound prices.  Copper's actual daily sigma is about 0.024;
    # capping at ~1.25 sigma keeps the 5-day compound under roughly 16%.
    # The clamp is inactive once the model is retrained with a healthy
    # variance ratio (0.5-1.5).
    _MAX_DAILY_RET = 0.03

    # T+1 quantile spreads (return-space distance from the median).
    # Used as the base width for confidence bands; scaled by sqrt(day) for
    # later days so uncertainty grows realistically instead of compounding
    # tail quantiles exponentially (which would produce absurd bands).
    _raw_med_0 = float(pred[0, median_idx])
    spread_q10 = np.clip(float(pred[0, 1]) - _raw_med_0, -_MAX_DAILY_RET, 0) if len(quantiles) > 2 else 0.0
    spread_q90 = np.clip(float(pred[0, -2]) - _raw_med_0, 0, _MAX_DAILY_RET) if len(quantiles) > 2 else 0.0
    # Outer (96%) band gets 1.5x the clamp headroom.
    spread_q02 = np.clip(float(pred[0, 0]) - _raw_med_0, -_MAX_DAILY_RET * 1.5, 0)
    spread_q98 = np.clip(float(pred[0, -1]) - _raw_med_0, 0, _MAX_DAILY_RET * 1.5)

    daily_forecasts = []
    cum_price_med = baseline_price

    for d in range(n_days):
        # Clamped daily median return, compounded into a price path.
        med = float(np.clip(pred[d, median_idx], -_MAX_DAILY_RET, _MAX_DAILY_RET))
        cum_price_med *= (1 + med)
        cum_return = (cum_price_med / baseline_price) - 1.0

        # Band width grows as sqrt(horizon), like a random walk's std.
        scale = (d + 1) ** 0.5

        daily_forecasts.append({
            "day": d + 1,
            "daily_return": med,
            "cumulative_return": cum_return,
            "price_median": cum_price_med,
            "price_q10": cum_price_med * (1 + spread_q10 * scale),
            "price_q90": cum_price_med * (1 + spread_q90 * scale),
            "price_q02": cum_price_med * (1 + spread_q02 * scale),
            "price_q98": cum_price_med * (1 + spread_q98 * scale),
        })

    # T+1 is the primary signal (most reliable, highest signal-to-noise).
    first = daily_forecasts[0]
    last = daily_forecasts[-1]
    # Half the Q90-Q10 price band, expressed as a fraction of baseline.
    vol_estimate = (first["price_q90"] - first["price_q10"]) / (2.0 * baseline_price)

    return {
        "predicted_return_median": first["daily_return"],
        "predicted_return_q10": float(np.clip(pred[0, 1], -_MAX_DAILY_RET * 2, _MAX_DAILY_RET * 2)) if len(quantiles) > 2 else first["daily_return"],
        "predicted_return_q90": float(np.clip(pred[0, -2], -_MAX_DAILY_RET * 2, _MAX_DAILY_RET * 2)) if len(quantiles) > 2 else first["daily_return"],
        "predicted_price_median": first["price_median"],
        "predicted_price_q10": first["price_q10"],
        "predicted_price_q90": first["price_q90"],
        "confidence_band_96": (first["price_q02"], first["price_q98"]),
        "volatility_estimate": vol_estimate,
        "quantiles": {f"q{q:.2f}": float(pred[0, i]) for i, q in enumerate(quantiles)},
        "weekly_return": last["cumulative_return"],
        "weekly_price": last["price_median"],
        "prediction_horizon_days": n_days,
        "daily_forecasts": daily_forecasts,
    }