| """ | |
| Custom Loss Functions for TFT-ASRO. | |
| Implements: | |
| - AdaptiveSharpeRatioLoss (ASRO): jointly optimises risk-adjusted return, | |
| volatility calibration, and quantile coverage. | |
| - CombinedQuantileLoss: standard multi-quantile pinball loss used as a | |
| component of ASRO and as a standalone baseline. | |
| """ | |
| from __future__ import annotations | |
| from typing import Optional, Sequence, Union | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from deep_learning.config import ASROConfig | |
def debug_asro_loss_direction() -> dict:
    """
    Mathematical sanity check for the ASRO loss function.

    Three test scenarios:
        1. correct_direction : tanh(pred) shares actual's sign -> minimal loss, positive Sharpe
        2. anti_direction    : tanh(pred) opposes actual's sign -> maximal loss, negative Sharpe
        3. zero_predictions  : model predicts zero -> zero Sharpe (the low-variance trap)

    Gradient checks:
        - grad_norm > 0 in every scenario (tanh is differentiable; no sign())
        - loss ordering must hold: correct direction < zero predictions < anti direction

    Returns:
        {
            "passed": bool,
            "results": {scenario: {"loss", "grad_norm", "strategy_sharpe"}},
            "diagnostics": str  # pass/fail explanation
        }
    """
    # NOTE: the redundant function-local `import torch` was removed; torch is
    # already imported at module level.
    torch.manual_seed(42)
    B, T, Q = 64, 5, 7
    quantiles = [0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98]
    actual_std = 0.024  # matches the copper daily-return std cited in the loss comments
    actual = torch.randn(B, T) * actual_std

    def _make_preds(median: torch.Tensor) -> torch.Tensor:
        """Build a quantile tensor from a given median, spread ≈ 2*actual_std."""
        out = torch.zeros(B, T, Q)
        for i, q in enumerate(quantiles):
            out[..., i] = median + (q - 0.5) * actual_std * 2
        return out

    scenarios = {
        "correct_direction": _make_preds(actual * 0.5),
        "anti_direction": _make_preds(-actual * 0.5),
        "zero_predictions": _make_preds(torch.zeros(B, T)),
    }
    fn = AdaptiveSharpeRatioLoss(quantiles=quantiles)
    results: dict = {}
    for name, preds in scenarios.items():
        p = preds.detach().requires_grad_(True)
        loss_val = fn(p, actual.detach())
        loss_val.backward()
        grad_norm = float(p.grad.norm().item()) if p.grad is not None else 0.0
        with torch.no_grad():
            med = p.detach()[..., len(quantiles) // 2]
            signal = torch.tanh(med * 20.0)  # same scale as training loss
            sr = float(
                (signal * actual).mean() / ((signal * actual).std() + 1e-6)
            )
        results[name] = {
            "loss": round(float(loss_val.item()), 6),
            "grad_norm": round(grad_norm, 6),
            "strategy_sharpe": round(sr, 4),
        }
    checks = {
        "correct < anti loss":
            results["correct_direction"]["loss"] < results["anti_direction"]["loss"],
        "correct Sharpe > 0":
            results["correct_direction"]["strategy_sharpe"] > 0,
        "anti Sharpe < 0":
            results["anti_direction"]["strategy_sharpe"] < 0,
        "gradients non-zero (correct)":
            results["correct_direction"]["grad_norm"] > 1e-6,
        "gradients non-zero (anti)":
            results["anti_direction"]["grad_norm"] > 1e-6,
    }
    passed = all(checks.values())
    failed = [k for k, v in checks.items() if not v]
    diagnostics = "ALL CHECKS PASSED" if passed else f"FAILED: {failed}"
    return {"passed": passed, "results": results, "diagnostics": diagnostics}
class CombinedQuantileLoss(nn.Module):
    """
    Standard multi-quantile pinball loss.

    Computes the per-sample pinball loss for each of the K target quantiles
    and averages everything into a single scalar.
    """

    def __init__(self, quantiles: Sequence[float] = (0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98)):
        super().__init__()
        # Registered as a buffer so it follows the module across devices/dtypes.
        q_tensor = torch.tensor(quantiles, dtype=torch.float32)
        self.register_buffer("quantiles", q_tensor)

    def forward(
        self,
        y_pred: torch.Tensor,
        y_actual: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            y_pred: (batch, prediction_length, n_quantiles)
            y_actual: (batch, prediction_length)
        """
        # Broadcast the target against the quantile axis when it is 2-D.
        target = y_actual.unsqueeze(-1) if y_actual.dim() == 2 else y_actual
        diff = target - y_pred
        q = self.quantiles.view(1, 1, -1)
        # Pinball loss: q*e when under-predicting, (q-1)*e when over-predicting.
        pinball = torch.maximum(q * diff, (q - 1.0) * diff)
        return pinball.mean()
class AdaptiveSharpeRatioLoss(nn.Module):
    """
    TFT-ASRO loss: combines three objectives to break the low-variance trap.

        L = -Sharpe_component
            + lambda_vol * volatility_calibration_loss
            + lambda_quantile * quantile_coverage_loss

    The Sharpe component incentivises the model to produce directionally correct
    predictions (not just low MSE), while the volatility term penalises
    under-estimation of realised variance, and the quantile term ensures
    proper tail coverage.
    """

    def __init__(
        self,
        quantiles: Sequence[float] = (0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98),
        lambda_vol: float = 0.3,
        lambda_quantile: float = 0.2,
        risk_free_rate: float = 0.0,
        sharpe_eps: float = 1e-6,
        median_idx: Optional[int] = None,
    ):
        """
        Args:
            quantiles: monotonically increasing quantile levels predicted by the model.
            lambda_vol: weight of the volatility-calibration + amplitude terms.
            lambda_quantile: weight of the calibration side; (1 - lambda_quantile)
                weights the Sharpe side.
            risk_free_rate: subtracted from strategy returns before the Sharpe ratio.
            sharpe_eps: numerical floor added to std denominators.
            median_idx: index of the median quantile; defaults to the middle entry.
        """
        super().__init__()
        self.lambda_vol = lambda_vol
        self.lambda_quantile = lambda_quantile
        self.rf = risk_free_rate
        self.sharpe_eps = sharpe_eps
        self.median_idx = median_idx if median_idx is not None else len(quantiles) // 2
        self.quantile_loss = CombinedQuantileLoss(quantiles)
        q = list(quantiles)
        # Fall back to the 2nd / 2nd-to-last quantiles when 0.10 / 0.90 are absent.
        self._q10_idx = q.index(0.10) if 0.10 in q else 1
        self._q90_idx = q.index(0.90) if 0.90 in q else len(q) - 2

    def forward(
        self,
        y_pred: torch.Tensor,
        y_actual: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            y_pred: (batch, prediction_length, n_quantiles)
            y_actual: (batch, prediction_length)

        Returns:
            Scalar loss tensor.
        """
        median_pred = y_pred[:, :, self.median_idx]
        y_actual_f = y_actual.float()

        # --- Sharpe component: tanh soft-sign ---
        # Scale chosen so gradients stay alive through the full return distribution.
        # Copper daily return std ≈ 0.024; we need sech²(pred*scale) > 0.5 for
        # predictions up to ~actual_std, which requires scale*0.024 < 0.66 → scale < 28.
        #
        # scale=20 gradient profile:
        #   pred = 0.003 → tanh(0.06) = 0.06, grad = 0.996 (tiny signal → model must grow)
        #   pred = 0.010 → tanh(0.20) = 0.20, grad = 0.961 (still linear zone)
        #   pred = 0.024 → tanh(0.48) = 0.45, grad = 0.800 (actual_std level: strong grad)
        #   pred = 0.050 → tanh(1.00) = 0.76, grad = 0.420 (only saturates at 2× actual)
        #
        # Previous scale=100 killed gradients above pred=0.015 (sech²=0.18), letting
        # the model earn full Sharpe reward from predictions 7× smaller than actual vol.
        _TANH_SCALE = 20.0
        signal = torch.tanh(median_pred * _TANH_SCALE)
        strategy_returns = signal * y_actual_f - self.rf
        sharpe_loss = -(strategy_returns.mean() / (strategy_returns.std() + self.sharpe_eps))

        # --- Volatility calibration ---
        # Match Q90-Q10 spread to 2× actual σ so the prediction interval tracks
        # realised volatility rather than collapsing to a constant.
        pred_spread = (y_pred[:, :, self._q90_idx] - y_pred[:, :, self._q10_idx]).mean()
        actual_std = y_actual_f.std() + self.sharpe_eps
        vol_loss = torch.abs(pred_spread - 2.0 * actual_std)

        # --- Median amplitude penalty ---
        # vol_loss only targets the Q10-Q90 band width; the model can widen bands
        # while keeping median predictions flat. This term directly penalises the
        # median for having lower variance than actual returns.
        # relu(1 - VR) fires when pred_std < actual_std; zero otherwise.
        median_std = median_pred.std() + self.sharpe_eps
        vr = median_std / actual_std
        amplitude_loss = (
            torch.relu(1.0 - vr)            # under-variance: VR < 1 → strong penalty
            + 0.25 * torch.relu(vr - 1.5)   # over-variance: VR > 1.5 → gentle penalty
        )

        # --- Quantile (pinball) loss ---
        q_loss = self.quantile_loss(y_pred, y_actual)

        # --- Weighted combination ---
        # calibration = quantile bands + band width + median amplitude
        # w_quantile + w_sharpe = 1.0
        w_sharpe = 1.0 - self.lambda_quantile
        calibration = q_loss + self.lambda_vol * (vol_loss + amplitude_loss)
        total = self.lambda_quantile * calibration + w_sharpe * sharpe_loss
        return total

    @classmethod
    def from_config(cls, cfg: ASROConfig, quantiles: Sequence[float]) -> "AdaptiveSharpeRatioLoss":
        """
        Alternate constructor from an ASROConfig.

        FIX: the original definition took ``cls`` but lacked the ``@classmethod``
        decorator, so ``AdaptiveSharpeRatioLoss.from_config(cfg, qs)`` would have
        bound ``cfg`` to ``cls`` and crashed.
        """
        return cls(
            quantiles=quantiles,
            lambda_vol=cfg.lambda_vol,
            lambda_quantile=cfg.lambda_quantile,
            risk_free_rate=cfg.risk_free_rate,
        )