Spaces:
Running
Running
Sync from GitHub (tests passed)
Browse files
deep_learning/models/tft_copper.py
CHANGED
|
@@ -23,6 +23,33 @@ from deep_learning.models.losses import AdaptiveSharpeRatioLoss, CombinedQuantil
|
|
| 23 |
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
def create_tft_model(
|
| 27 |
training_dataset,
|
| 28 |
cfg: Optional[TFTASROConfig] = None,
|
|
@@ -48,7 +75,7 @@ def create_tft_model(
|
|
| 48 |
quantiles = list(cfg.model.quantiles)
|
| 49 |
|
| 50 |
if use_asro:
|
| 51 |
-
loss =
|
| 52 |
logger.info("Using ASRO loss (lambda_vol=%.2f, lambda_quantile=%.2f)", cfg.asro.lambda_vol, cfg.asro.lambda_quantile)
|
| 53 |
else:
|
| 54 |
loss = QuantileLoss(quantiles=quantiles)
|
|
|
|
| 23 |
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
|
| 26 |
+
def _build_asro_pf_loss(asro_cfg, quantiles: list):
|
| 27 |
+
"""
|
| 28 |
+
Build an ASRO loss that satisfies pytorch_forecasting >= 1.0 requirements.
|
| 29 |
+
|
| 30 |
+
pytorch_forecasting requires the loss to be a torchmetrics ``Metric``
|
| 31 |
+
subclass. We subclass ``QuantileLoss`` (which already satisfies this)
|
| 32 |
+
and override its ``loss()`` method with our ASRO logic.
|
| 33 |
+
"""
|
| 34 |
+
from pytorch_forecasting.metrics import QuantileLoss
|
| 35 |
+
|
| 36 |
+
# Capture ASRO nn.Module so it runs inside the Metric wrapper
|
| 37 |
+
_asro_module = AdaptiveSharpeRatioLoss.from_config(asro_cfg, quantiles)
|
| 38 |
+
|
| 39 |
+
class _ASROPFLoss(QuantileLoss):
|
| 40 |
+
"""pytorch_forecasting-compatible ASRO loss wrapper."""
|
| 41 |
+
|
| 42 |
+
def loss(self, y_pred: torch.Tensor, target) -> torch.Tensor: # type: ignore[override]
|
| 43 |
+
# pytorch_forecasting passes target as a tuple (values, weights)
|
| 44 |
+
if isinstance(target, (list, tuple)):
|
| 45 |
+
y_actual = target[0]
|
| 46 |
+
else:
|
| 47 |
+
y_actual = target
|
| 48 |
+
return _asro_module(y_pred, y_actual)
|
| 49 |
+
|
| 50 |
+
return _ASROPFLoss(quantiles=quantiles)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
def create_tft_model(
|
| 54 |
training_dataset,
|
| 55 |
cfg: Optional[TFTASROConfig] = None,
|
|
|
|
| 75 |
quantiles = list(cfg.model.quantiles)
|
| 76 |
|
| 77 |
if use_asro:
|
| 78 |
+
loss = _build_asro_pf_loss(cfg.asro, quantiles)
|
| 79 |
logger.info("Using ASRO loss (lambda_vol=%.2f, lambda_quantile=%.2f)", cfg.asro.lambda_vol, cfg.asro.lambda_quantile)
|
| 80 |
else:
|
| 81 |
loss = QuantileLoss(quantiles=quantiles)
|