"""
Custom Loss Functions for TFT-ASRO.

Implements:
- AdaptiveSharpeRatioLoss (ASRO): jointly optimises risk-adjusted return,
  volatility calibration, and quantile coverage.
- CombinedQuantileLoss: standard multi-quantile pinball loss used as a
  component of ASRO and as a standalone baseline.
"""

from __future__ import annotations

import math
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
import numpy as np

from deep_learning.config import ASROConfig


def debug_asro_loss_direction() -> dict:
    """
    Sanity-check the direction and gradients of the ASRO loss.

    Three synthetic scenarios are built from the same random "actual" returns:

    1. correct_direction : median prediction has the same sign as the actual
    2. anti_direction    : median prediction has the opposite sign
    3. zero_predictions  : median prediction is identically zero

    For every scenario the loss, the gradient norm w.r.t. the predictions and
    a strategy Sharpe ratio (tanh signal x actual return) are recorded.

    The pass/fail verdict is based on exactly these checks:
    - loss(correct_direction) < loss(anti_direction)
    - Sharpe(correct_direction) > 0 and Sharpe(anti_direction) < 0
    - gradient norm > 1e-6 for correct_direction and anti_direction

    The zero_predictions scenario is reported in "results" but does not take
    part in the verdict.

    Returns:
        {
            "passed": bool,
            "results": {scenario: {"loss", "grad_norm", "strategy_sharpe"}},
            "diagnostics": str   # "ALL CHECKS PASSED" or the failed checks
        }
    """
    torch.manual_seed(42)

    B, T, Q = 64, 5, 7
    quantiles = [0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98]
    actual_std = 0.024
    actual = torch.randn(B, T) * actual_std

    def _make_preds(median: torch.Tensor) -> torch.Tensor:
        """Build a quantile tensor from a given median, spread ≈ 2*actual_std."""
        out = torch.zeros(B, T, Q)
        for i, q in enumerate(quantiles):
            out[..., i] = median + (q - 0.5) * actual_std * 2
        return out

    scenarios = {
        "correct_direction": _make_preds(actual * 0.5),
        "anti_direction": _make_preds(-actual * 0.5),
        "zero_predictions": _make_preds(torch.zeros(B, T)),
    }

    fn = AdaptiveSharpeRatioLoss(quantiles=quantiles)
    results: dict = {}

    for name, preds in scenarios.items():
        # Fresh leaf tensor per scenario so .grad belongs to this run only.
        p = preds.detach().requires_grad_(True)
        loss_val = fn(p, actual.detach())
        loss_val.backward()
        grad_norm = float(p.grad.norm().item()) if p.grad is not None else 0.0

        with torch.no_grad():
            # Read scale / median index / eps from the loss instance so the
            # diagnostic Sharpe always uses the same settings as the loss.
            med = p.detach()[..., fn.median_idx]
            signal = torch.tanh(med * fn.tanh_scale)
            strat = signal * actual
            sr = float(strat.mean() / (strat.std() + fn.sharpe_eps))

        results[name] = {
            "loss": round(float(loss_val.item()), 6),
            "grad_norm": round(grad_norm, 6),
            "strategy_sharpe": round(sr, 4),
        }

    checks = {
        "correct < anti loss": results["correct_direction"]["loss"] < results["anti_direction"]["loss"],
        "correct Sharpe > 0": results["correct_direction"]["strategy_sharpe"] > 0,
        "anti Sharpe < 0": results["anti_direction"]["strategy_sharpe"] < 0,
        "gradients non-zero (correct)": results["correct_direction"]["grad_norm"] > 1e-6,
        "gradients non-zero (anti)": results["anti_direction"]["grad_norm"] > 1e-6,
    }

    passed = all(checks.values())
    failed = [k for k, v in checks.items() if not v]
    diagnostics = "ALL CHECKS PASSED" if passed else f"FAILED: {failed}"

    return {"passed": passed, "results": results, "diagnostics": diagnostics}


class CombinedQuantileLoss(nn.Module):
    """
    Multi-quantile pinball loss.

    Given K quantile predictions and actual values, the loss is the average
    pinball loss across all quantiles and samples.
    """

    def __init__(self, quantiles: Sequence[float] = (0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98)):
        super().__init__()
        # Registered as a buffer so the levels follow the module across
        # .to(device) calls without being treated as trainable parameters.
        self.register_buffer(
            "quantiles",
            torch.tensor(quantiles, dtype=torch.float32),
        )

    def forward(
        self,
        y_pred: torch.Tensor,
        y_actual: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            y_pred:   (batch, prediction_length, n_quantiles)
            y_actual: (batch, prediction_length)

        Returns:
            Scalar tensor: mean pinball loss over all elements and quantiles.
        """
        if y_actual.dim() == 2:
            # Add a trailing axis so the target broadcasts over quantiles.
            y_actual = y_actual.unsqueeze(-1)

        errors = y_actual - y_pred
        quantiles = self.quantiles.view(1, 1, -1)
        # Pinball loss: q * e for under-prediction, (q - 1) * e for over-prediction.
        loss = torch.max(quantiles * errors, (quantiles - 1) * errors)
        return loss.mean()


class AdaptiveSharpeRatioLoss(nn.Module):
    """
    TFT-ASRO loss: combines several objectives to break the low-variance trap.

        calibration = pinball + lambda_vol * (vol_loss + amplitude_loss)
        L = lambda_quantile * calibration + (1 - lambda_quantile) * sharpe_loss

    where ``sharpe_loss`` is the negative Sharpe ratio of the strategy
    ``tanh(tanh_scale * median) * actual``.

    The Sharpe component incentivises the model to produce directionally
    correct predictions (not just low MSE), the volatility term ties the
    Q90-Q10 band width to realised volatility, the amplitude term penalises a
    median whose variance differs from that of the actual returns, and the
    pinball term ensures proper quantile coverage.

    Args:
        quantiles: quantile levels, in the order of the last axis of ``y_pred``.
        lambda_vol: weight of (vol_loss + amplitude_loss) inside the
            calibration term.
        lambda_quantile: weight of the calibration term; the Sharpe term gets
            ``1 - lambda_quantile``.
        risk_free_rate: subtracted from every strategy return.
        sharpe_eps: added to standard deviations to avoid division by zero.
        median_idx: index of the median quantile; defaults to the middle level.
        tanh_scale: multiplier applied to the median before ``tanh``.

    Raises:
        ValueError: if ``quantiles`` is empty.
    """

    def __init__(
        self,
        quantiles: Sequence[float] = (0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98),
        lambda_vol: float = 0.3,
        lambda_quantile: float = 0.2,
        risk_free_rate: float = 0.0,
        sharpe_eps: float = 1e-6,
        median_idx: Optional[int] = None,
        tanh_scale: float = 20.0,
    ):
        super().__init__()
        q = [float(level) for level in quantiles]
        if not q:
            raise ValueError("quantiles must contain at least one level")

        self.lambda_vol = lambda_vol
        self.lambda_quantile = lambda_quantile
        self.rf = risk_free_rate
        self.sharpe_eps = sharpe_eps
        self.median_idx = median_idx if median_idx is not None else len(q) // 2
        # Gradient profile of tanh(pred * scale) at the default scale of 20
        # (copper daily return std ≈ 0.024):
        #   pred = 0.003 → tanh(0.06) = 0.06, grad = 0.996
        #   pred = 0.010 → tanh(0.20) = 0.20, grad = 0.961
        #   pred = 0.024 → tanh(0.48) = 0.45, grad = 0.800
        #   pred = 0.050 → tanh(1.00) = 0.76, grad = 0.420
        # A previous scale of 100 killed gradients above pred = 0.015
        # (sech² = 0.18), letting the model earn full Sharpe reward from
        # predictions far smaller than the actual volatility.
        self.tanh_scale = tanh_scale
        self.quantile_loss = CombinedQuantileLoss(q)
        self._q10_idx, self._q90_idx = self._resolve_band_indices(q)

    @staticmethod
    def _match_level(levels: Sequence[float], target: float) -> Optional[int]:
        """Return the index of the first level equal to *target* within 1e-6, else None."""
        for idx, level in enumerate(levels):
            # Tolerant comparison: levels that went through float32 do not
            # compare equal to the Python literals 0.10 / 0.90.
            if math.isclose(level, target, rel_tol=0.0, abs_tol=1e-6):
                return idx
        return None

    @classmethod
    def _resolve_band_indices(cls, levels: Sequence[float]) -> "tuple[int, int]":
        """Pick the (lower, upper) quantile indices used for the volatility band."""
        n = len(levels)
        lo = cls._match_level(levels, 0.10)
        hi = cls._match_level(levels, 0.90)
        if lo is not None and hi is not None:
            return lo, hi

        # Historical fallback: second and second-to-last level.
        lo = 1 if lo is None else lo
        hi = n - 2 if hi is None else hi
        if not 0 <= lo < hi < n:
            # With very few levels the fallback collapses to one index (zero
            # spread, no gradient) or leaves the valid range; use the
            # outermost pair instead.
            lo, hi = 0, n - 1
        return lo, hi

    @staticmethod
    def _safe_std(values: torch.Tensor) -> torch.Tensor:
        """Sample std of *values*; zero when fewer than two elements (std would be NaN)."""
        if values.numel() < 2:
            return values.new_zeros(())
        return values.std()

    def forward(
        self,
        y_pred: torch.Tensor,
        y_actual: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            y_pred:   (batch, prediction_length, n_quantiles)
            y_actual: (batch, prediction_length)

        Returns:
            Scalar loss tensor.
        """
        median_pred = y_pred[:, :, self.median_idx]
        y_actual_f = y_actual.float()

        # --- Sharpe component: tanh soft-sign ---
        # tanh keeps the position signal differentiable (unlike sign()).
        signal = torch.tanh(median_pred * self.tanh_scale)
        strategy_returns = signal * y_actual_f - self.rf
        if strategy_returns.numel() < 2:
            # A Sharpe ratio needs at least two observations; contribute zero
            # (still attached to the graph) instead of NaN.
            sharpe_loss = strategy_returns.sum() * 0.0
        else:
            sharpe_loss = -(
                strategy_returns.mean() / (strategy_returns.std() + self.sharpe_eps)
            )

        # --- Volatility calibration ---
        # Match Q90-Q10 spread to 2× actual σ so the prediction interval tracks
        # realised volatility rather than collapsing to a constant.
        pred_spread = (y_pred[:, :, self._q90_idx] - y_pred[:, :, self._q10_idx]).mean()
        actual_std = self._safe_std(y_actual_f) + self.sharpe_eps
        vol_loss = torch.abs(pred_spread - 2.0 * actual_std)

        # --- Median amplitude penalty ---
        # vol_loss only targets the Q10-Q90 band width; the model can widen bands
        # while keeping median predictions flat. This term directly penalises the
        # median for having lower variance than actual returns.
        median_std = self._safe_std(median_pred) + self.sharpe_eps
        vr = median_std / actual_std
        amplitude_loss = (
            torch.relu(1.0 - vr)            # under-variance: VR < 1 → strong penalty
            + 0.25 * torch.relu(vr - 1.5)   # over-variance: VR > 1.5 → gentle penalty
        )

        # --- Quantile (pinball) loss ---
        # Use the float32 target here too, so every term sees the same dtype.
        q_loss = self.quantile_loss(y_pred, y_actual_f)

        # --- Weighted combination ---
        # calibration = quantile bands + band width + median amplitude
        # w_quantile + w_sharpe = 1.0
        w_sharpe = 1.0 - self.lambda_quantile
        calibration = q_loss + self.lambda_vol * (vol_loss + amplitude_loss)
        total = self.lambda_quantile * calibration + w_sharpe * sharpe_loss
        return total

    @classmethod
    def from_config(cls, cfg: ASROConfig, quantiles: Sequence[float]) -> "AdaptiveSharpeRatioLoss":
        """Build the loss from *cfg* (lambda_vol, lambda_quantile, risk_free_rate) and *quantiles*."""
        return cls(
            quantiles=quantiles,
            lambda_vol=cfg.lambda_vol,
            lambda_quantile=cfg.lambda_quantile,
            risk_free_rate=cfg.risk_free_rate,
        )