ifieryarrows commited on
Commit
5f4b46f
·
verified ·
1 Parent(s): 2cd6bdb

Sync from GitHub (tests passed)

Browse files
Files changed (1) hide show
  1. deep_learning/models/tft_copper.py +28 -1
deep_learning/models/tft_copper.py CHANGED
@@ -23,6 +23,33 @@ from deep_learning.models.losses import AdaptiveSharpeRatioLoss, CombinedQuantil
23
  logger = logging.getLogger(__name__)
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def create_tft_model(
27
  training_dataset,
28
  cfg: Optional[TFTASROConfig] = None,
@@ -48,7 +75,7 @@ def create_tft_model(
48
  quantiles = list(cfg.model.quantiles)
49
 
50
  if use_asro:
51
- loss = AdaptiveSharpeRatioLoss.from_config(cfg.asro, quantiles)
52
  logger.info("Using ASRO loss (lambda_vol=%.2f, lambda_quantile=%.2f)", cfg.asro.lambda_vol, cfg.asro.lambda_quantile)
53
  else:
54
  loss = QuantileLoss(quantiles=quantiles)
 
23
  logger = logging.getLogger(__name__)
24
 
25
 
26
+ def _build_asro_pf_loss(asro_cfg, quantiles: list):
27
+ """
28
+ Build an ASRO loss that satisfies pytorch_forecasting >= 1.0 requirements.
29
+
30
+ pytorch_forecasting requires the loss to be a torchmetrics ``Metric``
31
+ subclass. We subclass ``QuantileLoss`` (which already satisfies this)
32
+ and override its ``loss()`` method with our ASRO logic.
33
+ """
34
+ from pytorch_forecasting.metrics import QuantileLoss
35
+
36
+ # Capture ASRO nn.Module so it runs inside the Metric wrapper
37
+ _asro_module = AdaptiveSharpeRatioLoss.from_config(asro_cfg, quantiles)
38
+
39
+ class _ASROPFLoss(QuantileLoss):
40
+ """pytorch_forecasting-compatible ASRO loss wrapper."""
41
+
42
+ def loss(self, y_pred: torch.Tensor, target) -> torch.Tensor: # type: ignore[override]
43
+ # pytorch_forecasting passes target as a tuple (values, weights)
44
+ if isinstance(target, (list, tuple)):
45
+ y_actual = target[0]
46
+ else:
47
+ y_actual = target
48
+ return _asro_module(y_pred, y_actual)
49
+
50
+ return _ASROPFLoss(quantiles=quantiles)
51
+
52
+
53
  def create_tft_model(
54
  training_dataset,
55
  cfg: Optional[TFTASROConfig] = None,
 
75
  quantiles = list(cfg.model.quantiles)
76
 
77
  if use_asro:
78
+ loss = _build_asro_pf_loss(cfg.asro, quantiles)
79
  logger.info("Using ASRO loss (lambda_vol=%.2f, lambda_quantile=%.2f)", cfg.asro.lambda_vol, cfg.asro.lambda_quantile)
80
  else:
81
  loss = QuantileLoss(quantiles=quantiles)