ifieryarrows committed on
Commit
e411cee
·
verified ·
1 Parent(s): 9098220

Sync from GitHub (tests passed)

Browse files
app/quality_gate.py CHANGED
@@ -21,12 +21,16 @@ def evaluate_quality_gate(
21
  tail_capture: Optional[float] = None,
22
  quantile_crossing_rate: Optional[float] = None,
23
  median_sort_gap_max: Optional[float] = None,
 
 
24
  weekly_directional_accuracy: Optional[float] = None,
25
  weekly_magnitude_ratio: Optional[float] = None,
26
  weekly_tail_capture_rate: Optional[float] = None,
27
  weekly_pi80_coverage: Optional[float] = None,
 
28
  weekly_pi80_width_ratio: Optional[float] = None,
29
  weekly_pi96_coverage: Optional[float] = None,
 
30
  weekly_pi96_width_ratio: Optional[float] = None,
31
  weekly_quantile_crossing_rate: Optional[float] = None,
32
  weekly_sorted_quantile_crossing_rate: Optional[float] = None,
@@ -74,6 +78,8 @@ def evaluate_quality_gate(
74
  reasons.append(
75
  f"WeeklyPI80Overwide={weekly_pi80_width_ratio:.4f} with coverage={weekly_pi80_coverage:.4f}"
76
  )
 
 
77
 
78
  if weekly_pi96_coverage is None:
79
  reasons.append("Missing weekly_pi96_coverage")
@@ -82,33 +88,63 @@ def evaluate_quality_gate(
82
  reasons.append("Missing weekly_pi96_width_ratio")
83
  elif weekly_pi96_width_ratio > 3.0:
84
  reasons.append(f"WeeklyPI96WidthRatio={weekly_pi96_width_ratio:.4f} > 3.0")
 
 
85
 
86
  if weekly_quantile_crossing_rate is None:
87
  reasons.append("Missing weekly_quantile_crossing_rate")
88
- elif weekly_quantile_crossing_rate > 0.05:
89
- reasons.append(f"WeeklyQuantileCrossing={weekly_quantile_crossing_rate:.4f} > 0.05")
 
 
90
 
91
  if weekly_sorted_quantile_crossing_rate is None:
92
  reasons.append("Missing weekly_sorted_quantile_crossing_rate")
93
- elif weekly_sorted_quantile_crossing_rate > 0.0:
94
- reasons.append(
95
- f"WeeklySortedQuantileCrossing={weekly_sorted_quantile_crossing_rate:.4f} > 0.0"
96
  )
97
 
98
- if weekly_median_sort_gap_max is not None and weekly_median_sort_gap_max > 0.005:
99
- reasons.append(f"WeeklyMedianSortGapMax={weekly_median_sort_gap_max:.4f} > 0.005")
 
 
100
 
101
  if sharpe < -0.30:
102
  reasons.append(f"Sharpe={sharpe:.4f} < -0.30")
103
- if vr < 0.2 or vr > 3.0:
104
- reasons.append(f"VR={vr:.4f} outside [0.2, 3.0]")
105
  if tail_capture is not None and tail_capture < 0.35:
106
  reasons.append(f"TailCapture={tail_capture:.4f} < 0.35")
107
  if quantile_crossing_rate is None:
108
  reasons.append("Missing quantile_crossing_rate")
109
- elif quantile_crossing_rate > 0.20:
110
- reasons.append(f"QuantileCrossing={quantile_crossing_rate:.4f} > 0.20")
111
- if median_sort_gap_max is not None and median_sort_gap_max > 0.01:
112
- reasons.append(f"MedianSortGapMax={median_sort_gap_max:.4f} > 0.01")
 
 
 
 
113
 
114
  return len(reasons) == 0, reasons
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  tail_capture: Optional[float] = None,
22
  quantile_crossing_rate: Optional[float] = None,
23
  median_sort_gap_max: Optional[float] = None,
24
+ pi80_width: Optional[float] = None,
25
+ pi96_width: Optional[float] = None,
26
  weekly_directional_accuracy: Optional[float] = None,
27
  weekly_magnitude_ratio: Optional[float] = None,
28
  weekly_tail_capture_rate: Optional[float] = None,
29
  weekly_pi80_coverage: Optional[float] = None,
30
+ weekly_pi80_width: Optional[float] = None,
31
  weekly_pi80_width_ratio: Optional[float] = None,
32
  weekly_pi96_coverage: Optional[float] = None,
33
+ weekly_pi96_width: Optional[float] = None,
34
  weekly_pi96_width_ratio: Optional[float] = None,
35
  weekly_quantile_crossing_rate: Optional[float] = None,
36
  weekly_sorted_quantile_crossing_rate: Optional[float] = None,
 
78
  reasons.append(
79
  f"WeeklyPI80Overwide={weekly_pi80_width_ratio:.4f} with coverage={weekly_pi80_coverage:.4f}"
80
  )
81
+ if weekly_pi80_width is not None and weekly_pi80_width < 0.0:
82
+ reasons.append(f"WeeklyPI80Width={weekly_pi80_width:.4f} < 0.0")
83
 
84
  if weekly_pi96_coverage is None:
85
  reasons.append("Missing weekly_pi96_coverage")
 
88
  reasons.append("Missing weekly_pi96_width_ratio")
89
  elif weekly_pi96_width_ratio > 3.0:
90
  reasons.append(f"WeeklyPI96WidthRatio={weekly_pi96_width_ratio:.4f} > 3.0")
91
+ if weekly_pi96_width is not None and weekly_pi96_width < 0.0:
92
+ reasons.append(f"WeeklyPI96Width={weekly_pi96_width:.4f} < 0.0")
93
 
94
  if weekly_quantile_crossing_rate is None:
95
  reasons.append("Missing weekly_quantile_crossing_rate")
96
+ elif weekly_quantile_crossing_rate > 0.001:
97
+ raise AssertionError(
98
+ f"WeeklyPublicQuantileCrossing={weekly_quantile_crossing_rate:.4f} > 0.001"
99
+ )
100
 
101
  if weekly_sorted_quantile_crossing_rate is None:
102
  reasons.append("Missing weekly_sorted_quantile_crossing_rate")
103
+ elif weekly_sorted_quantile_crossing_rate > 0.001:
104
+ raise AssertionError(
105
+ f"WeeklyOrderedQuantileCrossing={weekly_sorted_quantile_crossing_rate:.4f} > 0.001"
106
  )
107
 
108
+ if weekly_median_sort_gap_max is not None and weekly_median_sort_gap_max > 0.001:
109
+ raise AssertionError(
110
+ f"WeeklyOrderedMedianSortGapMax={weekly_median_sort_gap_max:.4f} > 0.001"
111
+ )
112
 
113
  if sharpe < -0.30:
114
  reasons.append(f"Sharpe={sharpe:.4f} < -0.30")
 
 
115
  if tail_capture is not None and tail_capture < 0.35:
116
  reasons.append(f"TailCapture={tail_capture:.4f} < 0.35")
117
  if quantile_crossing_rate is None:
118
  reasons.append("Missing quantile_crossing_rate")
119
+ elif quantile_crossing_rate > 0.001:
120
+ raise AssertionError(f"PublicQuantileCrossing={quantile_crossing_rate:.4f} > 0.001")
121
+ if median_sort_gap_max is not None and median_sort_gap_max > 0.001:
122
+ raise AssertionError(f"OrderedMedianSortGapMax={median_sort_gap_max:.4f} > 0.001")
123
+ if pi80_width is not None and pi80_width < 0.0:
124
+ reasons.append(f"PI80Width={pi80_width:.4f} < 0.0")
125
+ if pi96_width is not None and pi96_width < 0.0:
126
+ reasons.append(f"PI96Width={pi96_width:.4f} < 0.0")
127
 
128
  return len(reasons) == 0, reasons
129
+
130
+
131
def evaluate_quality_gate_warnings(
    vr: float,
    mae_vs_naive_zero: Optional[float] = None,
    weekly_mae_vs_naive_zero: Optional[float] = None,
) -> List[str]:
    """Return stabilization warnings that do not fail promotion yet.

    Args:
        vr: Variance ratio. Values above 2.5 are reported as overdispersed
            and values below 0.4 as underdispersed.
        mae_vs_naive_zero: Optional MAE ratio; reported when above 1.25.
        weekly_mae_vs_naive_zero: Optional weekly MAE ratio; reported when
            above 1.25.

    Returns:
        Human-readable warning strings, empty when nothing is flagged.
        A NaN metric is reported explicitly: NaN fails every ordering
        comparison, so it would otherwise pass through without any warning.
    """
    # Function-scope import keeps this block self-contained.
    import math

    messages: List[str] = []

    if math.isnan(vr):
        messages.append(f"VR={vr:.4f} is not a number - dispersion check skipped")
    else:
        if vr > 2.5:
            messages.append(f"VR={vr:.4f} > 2.5 - model overdispersed")
        if vr < 0.4:
            messages.append(f"VR={vr:.4f} < 0.4 - model underdispersed")

    if mae_vs_naive_zero is not None:
        if math.isnan(mae_vs_naive_zero):
            messages.append(
                f"MAEvsNaiveZero={mae_vs_naive_zero:.4f} is not a number - baseline check skipped"
            )
        elif mae_vs_naive_zero > 1.25:
            messages.append(
                f"MAEvsNaiveZero={mae_vs_naive_zero:.4f} > 1.25 - worse than warning baseline"
            )

    if weekly_mae_vs_naive_zero is not None:
        if math.isnan(weekly_mae_vs_naive_zero):
            messages.append(
                f"WeeklyMAEvsNaiveZero={weekly_mae_vs_naive_zero:.4f} is not a number - baseline check skipped"
            )
        elif weekly_mae_vs_naive_zero > 1.25:
            messages.append(
                f"WeeklyMAEvsNaiveZero={weekly_mae_vs_naive_zero:.4f} > 1.25 - worse than warning baseline"
            )

    return messages
deep_learning/config.py CHANGED
@@ -136,15 +136,10 @@ class ASROConfig:
136
 
137
  @dataclass(frozen=True)
138
  class WeeklyLossConfig:
139
- lambda_weekly_quantile: float = 0.60
140
- lambda_t1_quantile: float = 0.10
 
141
  lambda_directional: float = 0.10
142
- lambda_magnitude: float = 0.55
143
- lambda_vol: float = 0.35
144
- lambda_crossing: float = 7.0
145
- lambda_sanity: float = 0.20
146
- lambda_width: float = 0.50
147
- lambda_tail_width: float = 0.30
148
 
149
 
150
  @dataclass(frozen=True)
 
136
 
137
  @dataclass(frozen=True)
138
  class WeeklyLossConfig:
139
+ lambda_weekly_quantile: float = 0.55
140
+ lambda_t1_quantile: float = 0.15
141
+ lambda_dispersion: float = 0.20
142
  lambda_directional: float = 0.10
 
 
 
 
 
 
143
 
144
 
145
  @dataclass(frozen=True)
deep_learning/models/hub.py CHANGED
@@ -48,6 +48,108 @@ def _sha256_file(path: Path) -> str:
48
  return digest.hexdigest()
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def build_artifact_manifest(local_dir: str | Path) -> dict:
52
  """Build a SHA256 manifest for every present TFT artifact except itself."""
53
  local_dir = Path(local_dir)
@@ -68,6 +170,7 @@ def build_artifact_manifest(local_dir: str | Path) -> dict:
68
  "manifest_version": 1,
69
  "generated_at": datetime.now(timezone.utc).isoformat(),
70
  "artifacts": artifacts,
 
71
  }
72
 
73
 
@@ -180,6 +283,12 @@ def validate_tft_artifact_set(local_dir: str | Path) -> bool:
180
  return True
181
 
182
 
 
 
 
 
 
 
183
  def upload_tft_artifacts(
184
  local_dir: str | Path,
185
  repo_id: str,
@@ -208,6 +317,9 @@ def upload_tft_artifacts(
208
  if not validate_tft_artifact_set(local_dir):
209
  logger.warning("TFT artifact manifest validation failed before upload")
210
  return False
 
 
 
211
 
212
  files_to_upload = [
213
  local_dir / name for name in _ARTIFACTS if (local_dir / name).exists()
 
48
  return digest.hexdigest()
49
 
50
 
51
+ def _load_json(path: Path) -> dict:
52
+ if not path.exists():
53
+ return {}
54
+ try:
55
+ return json.loads(path.read_text(encoding="utf-8"))
56
+ except Exception as exc:
57
+ logger.warning("Could not read JSON artifact %s: %s", path, exc)
58
+ return {}
59
+
60
+
61
def build_artifact_health(local_dir: str | Path) -> dict:
    """Build promotion/inference health metadata for the TFT artifact set.

    Reads ``tft_metadata.json`` and ``optuna_results.json`` from
    ``local_dir``, checks whether ``best_tft_asro.ckpt`` and
    ``conformal_calibration.json`` exist, and runs ``evaluate_quality_gate``
    on the ``test_metrics`` stored in the metadata.

    Args:
        local_dir: Directory holding the artifact files.

    Returns:
        A dict of health flags and selected metrics. ``safe_to_upload_to_hub``
        and ``safe_for_inference`` are both True only when the quality gate
        passed and both the checkpoint and the metadata file are present.
        ``next_required_action`` names what is blocking promotion.
    """
    local_dir = Path(local_dir)
    metadata_path = local_dir / "tft_metadata.json"
    checkpoint_present = (local_dir / "best_tft_asro.ckpt").exists()
    metadata_present = metadata_path.exists()
    conformal_present = (local_dir / "conformal_calibration.json").exists()

    metadata = _load_json(metadata_path)
    config = metadata.get("config") or {}
    metrics = metadata.get("test_metrics") or {}
    optuna = _load_json(local_dir / "optuna_results.json")
    structural_report = optuna.get("structural_invalidity_report") or {}
    best_preflight = optuna.get("best_trial_preflight") or {}

    quality_gate_passed = False
    gate_error = None
    if metrics:
        try:
            from app.quality_gate import evaluate_quality_gate

            quality_gate_passed, reasons = evaluate_quality_gate(
                da=float(metrics.get("directional_accuracy", 0.5)),
                sharpe=float(metrics.get("sharpe_ratio", 0.0)),
                vr=float(metrics.get("variance_ratio", 1.0)),
                tail_capture=metrics.get("tail_capture_rate"),
                quantile_crossing_rate=metrics.get("quantile_crossing_rate"),
                median_sort_gap_max=metrics.get("median_sort_gap_max"),
                pi80_width=metrics.get("pi80_width"),
                pi96_width=metrics.get("pi96_width"),
                weekly_directional_accuracy=metrics.get("weekly_directional_accuracy"),
                weekly_magnitude_ratio=metrics.get("weekly_magnitude_ratio"),
                weekly_tail_capture_rate=metrics.get("weekly_tail_capture_rate"),
                weekly_pi80_coverage=metrics.get("weekly_pi80_coverage"),
                weekly_pi80_width=metrics.get("weekly_pi80_width"),
                weekly_pi80_width_ratio=metrics.get("weekly_pi80_width_ratio"),
                weekly_pi96_coverage=metrics.get("weekly_pi96_coverage"),
                weekly_pi96_width=metrics.get("weekly_pi96_width"),
                weekly_pi96_width_ratio=metrics.get("weekly_pi96_width_ratio"),
                weekly_quantile_crossing_rate=metrics.get("weekly_quantile_crossing_rate"),
                weekly_sorted_quantile_crossing_rate=metrics.get(
                    "weekly_sorted_quantile_crossing_rate"
                ),
                weekly_median_sort_gap_max=metrics.get("weekly_median_sort_gap_max"),
                weekly_sample_count=metrics.get("weekly_sample_count"),
            )
            if not quality_gate_passed:
                gate_error = "; ".join(reasons)
        except Exception as exc:
            # Any failure while evaluating the gate (including exceptions the
            # gate itself raises) is recorded as a failed gate, not propagated.
            gate_error = str(exc)
            quality_gate_passed = False
    else:
        gate_error = "missing test_metrics"

    safe = bool(quality_gate_passed and checkpoint_present and metadata_present)
    if safe:
        next_required_action = "No action required; artifact is promotable."
    elif gate_error:
        next_required_action = gate_error
    elif quality_gate_passed and not checkpoint_present:
        # The gate passed, so re-running validation would not help; the
        # blocker is the missing checkpoint file.
        next_required_action = (
            "Checkpoint best_tft_asro.ckpt is missing; restore or retrain it before upload."
        )
    else:
        next_required_action = (
            "Run deterministic validation and pass the weekly quality gate before upload."
        )

    return {
        "forecast_contract_version": (
            metadata.get("forecast_contract_version")
            or config.get("forecast_contract_version")
        ),
        "monotonic_quantile_transform": bool(
            config.get("monotonic_quantile_transform")
            or metadata.get("monotonic_quantile_transform")
        ),
        "checkpoint_present": checkpoint_present,
        "metadata_present": metadata_present,
        "conformal_present": conformal_present,
        "quality_gate_passed": quality_gate_passed,
        "best_trial_preflight_passed": bool(best_preflight.get("preflight_passed", False)),
        "structural_invalidity_verdict": structural_report.get("verdict", "UNKNOWN"),
        "safe_to_upload_to_hub": safe,
        "safe_for_inference": safe,
        "raw_quantile_crossing_rate": metrics.get("raw_quantile_crossing_rate"),
        "ordered_quantile_crossing_rate": metrics.get("ordered_quantile_crossing_rate"),
        "public_quantile_crossing_rate": metrics.get(
            "public_quantile_crossing_rate",
            metrics.get("quantile_crossing_rate"),
        ),
        "variance_ratio": metrics.get("variance_ratio"),
        "mae_vs_naive_zero": metrics.get("mae_vs_naive_zero"),
        "weekly_mae_vs_naive_zero": metrics.get("weekly_mae_vs_naive_zero"),
        "next_required_action": next_required_action,
    }
151
+
152
+
153
  def build_artifact_manifest(local_dir: str | Path) -> dict:
154
  """Build a SHA256 manifest for every present TFT artifact except itself."""
155
  local_dir = Path(local_dir)
 
170
  "manifest_version": 1,
171
  "generated_at": datetime.now(timezone.utc).isoformat(),
172
  "artifacts": artifacts,
173
+ "artifact_health": build_artifact_health(local_dir),
174
  }
175
 
176
 
 
283
  return True
284
 
285
 
286
def _manifest_safe_to_upload(local_dir: str | Path) -> bool:
    """Return True only if the stored manifest explicitly marks upload as safe.

    Reads ``artifact_manifest.json`` from ``local_dir`` and inspects
    ``artifact_health.safe_to_upload_to_hub``. The check fails closed: a
    missing or malformed manifest, a non-object ``artifact_health`` entry, or
    any flag value other than JSON ``true`` yields False. Plain truthiness is
    not enough here, because a string such as ``"false"`` is truthy.
    """
    manifest = _load_json(Path(local_dir) / "artifact_manifest.json")
    if not isinstance(manifest, dict):
        return False
    health = manifest.get("artifact_health")
    if not isinstance(health, dict):
        return False
    return health.get("safe_to_upload_to_hub") is True
290
+
291
+
292
  def upload_tft_artifacts(
293
  local_dir: str | Path,
294
  repo_id: str,
 
317
  if not validate_tft_artifact_set(local_dir):
318
  logger.warning("TFT artifact manifest validation failed before upload")
319
  return False
320
+ if not _manifest_safe_to_upload(local_dir):
321
+ logger.warning("TFT artifact health is not safe for Hub upload; upload skipped")
322
+ return False
323
 
324
  files_to_upload = [
325
  local_dir / name for name in _ARTIFACTS if (local_dir / name).exists()
deep_learning/models/monotonic_quantiles.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
def enforce_monotonic_quantiles(
    y_pred: torch.Tensor,
    median_idx: int = 3,
    min_gap: float = 1e-5,
    gap_scale: float = 0.01,
    init_bias: float = -3.0,
) -> torch.Tensor:
    """
    Transform unconstrained quantile outputs into structurally monotonic
    quantile outputs.

    The value at ``median_idx`` along the last dimension is copied through
    unchanged. Every other entry is rebuilt as the median minus (below it)
    or plus (above it) a cumulative sum of steps, where each step is
    ``min_gap + gap_scale * softplus(raw + init_bias)``. Steps are therefore
    non-negative, so the result is ordered along the last dimension.

    Args:
        y_pred: Tensor whose last dimension indexes quantiles.
        median_idx: Index of the median along the last dimension. Negative
            indices count from the end.
        min_gap: Constant added to every step; must be >= 0.
        gap_scale: Multiplier on the softplus term; must be >= 0.
        init_bias: Offset added to the raw values before softplus.

    Returns:
        Tensor with the same shape as ``y_pred``.

    Raises:
        IndexError: If ``median_idx`` is out of range for the last dimension.
        ValueError: If ``min_gap`` or ``gap_scale`` is negative, which would
            allow negative steps and break the ordering.
    """
    if min_gap < 0 or gap_scale < 0:
        raise ValueError(
            f"min_gap and gap_scale must be non-negative, got "
            f"min_gap={min_gap}, gap_scale={gap_scale}"
        )

    n_quantiles = y_pred.shape[-1]
    if not -n_quantiles <= median_idx < n_quantiles:
        raise IndexError(
            f"median_idx {median_idx} is out of range for {n_quantiles} quantiles"
        )
    # Normalize so the two slices below partition the last dimension
    # correctly; with a raw negative index they would overlap.
    median_idx %= n_quantiles

    base = y_pred[..., median_idx]

    lower_raw = y_pred[..., :median_idx]
    upper_raw = y_pred[..., median_idx + 1 :]

    # The lower side is flipped so its cumulative sum also grows outward from
    # the median; it is flipped back after subtraction.
    lower_steps = min_gap + gap_scale * F.softplus(
        torch.flip(lower_raw, dims=[-1]) + init_bias
    )
    upper_steps = min_gap + gap_scale * F.softplus(upper_raw + init_bias)

    lower_from_median = torch.cumsum(lower_steps, dim=-1)
    upper_from_median = torch.cumsum(upper_steps, dim=-1)

    lower = base.unsqueeze(-1) - lower_from_median
    lower = torch.flip(lower, dims=[-1])
    upper = base.unsqueeze(-1) + upper_from_median

    ordered = torch.cat([lower, base.unsqueeze(-1), upper], dim=-1)

    # Internal invariant, not input validation.
    assert ordered.shape == y_pred.shape, (
        f"Monotonic transform output shape {ordered.shape} "
        f"does not match input shape {y_pred.shape}"
    )
    return ordered
45
+
46
+
47
def validate_monotonicity(
    y_pred: torch.Tensor,
    tolerance: float = 1e-6,
) -> dict:
    """Return crossing diagnostics for an ordered quantile tensor.

    Adjacent entries along the last dimension are compared; a pair counts as
    a violation when the later entry is smaller than the earlier one by more
    than ``tolerance``.

    Returns:
        Dict with ``crossing_rate`` (fraction of adjacent pairs in violation),
        ``max_violation`` (largest decrease among violating pairs, 0.0 when
        there are none) and ``is_valid`` (True when ``crossing_rate`` is 0).
        A tensor with no adjacent pairs is reported as valid with zero rates.
    """
    diffs = y_pred[..., 1:] - y_pred[..., :-1]
    if diffs.numel() == 0:
        # No adjacent pairs to compare. The mean of an empty tensor is NaN,
        # which would wrongly make ``is_valid`` False.
        return {"crossing_rate": 0.0, "max_violation": 0.0, "is_valid": True}

    violations = diffs < -tolerance
    crossing_rate = violations.float().mean().item()
    max_violation = (
        (-diffs[violations]).max().item() if violations.any().item() else 0.0
    )

    return {
        "crossing_rate": crossing_rate,
        "max_violation": max_violation,
        "is_valid": crossing_rate == 0.0,
    }
deep_learning/models/tft_copper.py CHANGED
@@ -15,10 +15,15 @@ from pathlib import Path
15
  from typing import Any, Dict, Optional, Sequence
16
 
17
  import torch
 
18
  import numpy as np
19
 
20
  from deep_learning.contract import RETURN_SPACE, log_to_simple_return
21
  from deep_learning.config import TFTASROConfig, get_tft_config
 
 
 
 
22
  from deep_learning.models.losses import (
23
  AdaptiveSharpeRatioLoss,
24
  CombinedQuantileLoss,
@@ -131,38 +136,76 @@ try:
131
  def __init__(
132
  self,
133
  quantiles: list,
134
- lambda_weekly_quantile: float = 0.60,
135
- lambda_t1_quantile: float = 0.10,
 
136
  lambda_directional: float = 0.10,
137
- lambda_magnitude: float = 0.55,
138
- lambda_vol: float = 0.35,
139
- lambda_crossing: float = 7.0,
140
- lambda_sanity: float = 0.20,
141
- lambda_width: float = 0.50,
142
- lambda_tail_width: float = 0.30,
143
- sharpe_eps: float = 1e-6,
144
- daily_log_return_bound: float = 0.08,
145
- weekly_log_return_bound: float = 0.20,
146
  ):
147
  super().__init__(quantiles=quantiles)
148
  self.lambda_weekly_quantile = lambda_weekly_quantile
149
  self.lambda_t1_quantile = lambda_t1_quantile
 
150
  self.lambda_directional = lambda_directional
151
- self.lambda_magnitude = lambda_magnitude
152
- self.lambda_vol = lambda_vol
153
- self.lambda_crossing = lambda_crossing
154
- self.lambda_sanity = lambda_sanity
155
- self.lambda_width = lambda_width
156
- self.lambda_tail_width = lambda_tail_width
157
  self.sharpe_eps = sharpe_eps
158
- self.daily_log_return_bound = daily_log_return_bound
159
- self.weekly_log_return_bound = weekly_log_return_bound
160
  self.median_idx = len(quantiles) // 2
161
- q = list(quantiles)
162
- self._q02_idx = q.index(0.02) if 0.02 in q else 0
163
- self._q10_idx = q.index(0.10) if 0.10 in q else 1
164
- self._q90_idx = q.index(0.90) if 0.90 in q else len(q) - 2
165
- self._q98_idx = q.index(0.98) if 0.98 in q else len(q) - 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  def _pinball(self, pred: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
168
  q = torch.tensor(self.quantiles, device=pred.device, dtype=pred.dtype).view(1, -1)
@@ -178,70 +221,48 @@ try:
178
  y_actual = y_actual.float()
179
  y_pred = y_pred.float()
180
 
181
- median_path = y_pred[..., self.median_idx]
182
- pred_weekly_quantiles = y_pred.sum(dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  actual_weekly = y_actual.sum(dim=1)
184
 
185
  weekly_q_loss = self._pinball(pred_weekly_quantiles, actual_weekly)
186
- t1_q_loss = super().loss(y_pred[:, 0:1, :], y_actual[:, 0:1])
187
 
188
  pred_weekly_median = median_path.sum(dim=1)
189
- signal = torch.tanh(pred_weekly_median * 20.0)
190
- weekly_directional = -(signal * actual_weekly).mean() / (
191
- (signal * actual_weekly).std() + self.sharpe_eps
192
- )
193
 
194
- abs_actual = actual_weekly.abs()
195
- material_mask = abs_actual > (abs_actual.median() + self.sharpe_eps)
196
- global_magnitude_loss = torch.abs(
197
- torch.log(
198
- (pred_weekly_median.abs() + self.sharpe_eps)
199
- / (actual_weekly.abs() + self.sharpe_eps)
200
- )
201
- ).mean()
202
- if material_mask.any():
203
- pred_abs = pred_weekly_median[material_mask].abs()
204
- true_abs = actual_weekly[material_mask].abs()
205
- material_magnitude_loss = torch.abs(
206
- torch.log((pred_abs + self.sharpe_eps) / (true_abs + self.sharpe_eps))
207
- ).mean()
208
- else:
209
- material_magnitude_loss = y_pred.new_tensor(0.0)
210
- magnitude_loss = 0.5 * global_magnitude_loss + 0.5 * material_magnitude_loss
211
 
212
- weekly_spread = (
213
- pred_weekly_quantiles[:, self._q90_idx]
214
- - pred_weekly_quantiles[:, self._q10_idx]
215
- )
216
- actual_weekly_std = actual_weekly.std() + self.sharpe_eps
217
- target_spread = 2.56 * actual_weekly_std
218
- mean_weekly_spread = weekly_spread.mean()
219
- vol_loss = torch.abs(mean_weekly_spread - target_spread)
220
- width_ratio = mean_weekly_spread / (target_spread + self.sharpe_eps)
221
- safe_width_ratio = torch.clamp(width_ratio + self.sharpe_eps, min=1e-6)
222
- width_loss = torch.abs(torch.log(safe_width_ratio))
223
- width_loss = width_loss + torch.relu(width_ratio - 2.0).pow(2)
224
-
225
- weekly_tail_spread = (
226
- pred_weekly_quantiles[:, self._q98_idx]
227
- - pred_weekly_quantiles[:, self._q02_idx]
228
- )
229
- target_tail_spread = 4.10 * actual_weekly_std
230
- tail_width_ratio = weekly_tail_spread.mean() / (target_tail_spread + self.sharpe_eps)
231
- safe_tail_width_ratio = torch.clamp(tail_width_ratio + self.sharpe_eps, min=1e-6)
232
- tail_width_loss = torch.abs(torch.log(safe_tail_width_ratio))
233
- tail_width_loss = tail_width_loss + torch.relu(tail_width_ratio - 3.0).pow(2)
234
- daily_crossing_loss = quantile_crossing_penalty(y_pred)
235
- weekly_crossing_loss = quantile_crossing_penalty(pred_weekly_quantiles.unsqueeze(1))
236
- crossing_loss = daily_crossing_loss + weekly_crossing_loss
237
-
238
- daily_bound_loss = torch.relu(
239
- median_path.abs() - self.daily_log_return_bound
240
- ).pow(2).mean()
241
- weekly_bound_loss = torch.relu(
242
- pred_weekly_median.abs() - self.weekly_log_return_bound
243
- ).pow(2).mean()
244
- sanity_loss = daily_bound_loss + weekly_bound_loss
245
 
246
  def _to_scalar(x: torch.Tensor) -> torch.Tensor:
247
  # pytorch_forecasting metrics can return per-sample tensors;
@@ -249,17 +270,24 @@ try:
249
  # boolean comparisons in tests and stable optimizer behaviour.
250
  return x.mean() if x.ndim > 0 else x
251
 
252
- return (
 
 
 
 
253
  self.lambda_weekly_quantile * _to_scalar(weekly_q_loss)
254
  + self.lambda_t1_quantile * _to_scalar(t1_q_loss)
255
- + self.lambda_directional * _to_scalar(weekly_directional)
256
- + self.lambda_magnitude * _to_scalar(magnitude_loss)
257
- + self.lambda_vol * _to_scalar(vol_loss)
258
- + self.lambda_width * _to_scalar(width_loss)
259
- + self.lambda_tail_width * _to_scalar(tail_width_loss)
260
- + self.lambda_crossing * _to_scalar(crossing_loss)
261
- + self.lambda_sanity * _to_scalar(sanity_loss)
262
  )
 
 
 
 
 
 
 
 
263
 
264
  except ImportError:
265
  ASROPFLoss = None # type: ignore[assignment,misc]
@@ -295,25 +323,15 @@ def create_tft_model(
295
  quantiles=quantiles,
296
  lambda_weekly_quantile=cfg.weekly_loss.lambda_weekly_quantile,
297
  lambda_t1_quantile=cfg.weekly_loss.lambda_t1_quantile,
 
298
  lambda_directional=cfg.weekly_loss.lambda_directional,
299
- lambda_magnitude=cfg.weekly_loss.lambda_magnitude,
300
- lambda_vol=cfg.weekly_loss.lambda_vol,
301
- lambda_crossing=cfg.weekly_loss.lambda_crossing,
302
- lambda_sanity=cfg.weekly_loss.lambda_sanity,
303
- lambda_width=cfg.weekly_loss.lambda_width,
304
- lambda_tail_width=cfg.weekly_loss.lambda_tail_width,
305
  )
306
  logger.info(
307
- "Using weekly ASRO loss | weekly_q=%.2f t1_q=%.2f dir=%.2f mag=%.2f vol=%.2f width=%.2f tail_width=%.2f crossing=%.2f sanity=%.2f",
308
  cfg.weekly_loss.lambda_weekly_quantile,
309
  cfg.weekly_loss.lambda_t1_quantile,
 
310
  cfg.weekly_loss.lambda_directional,
311
- cfg.weekly_loss.lambda_magnitude,
312
- cfg.weekly_loss.lambda_vol,
313
- cfg.weekly_loss.lambda_width,
314
- cfg.weekly_loss.lambda_tail_width,
315
- cfg.weekly_loss.lambda_crossing,
316
- cfg.weekly_loss.lambda_sanity,
317
  )
318
  elif use_asro and ASROPFLoss is not None:
319
  loss = ASROPFLoss(
@@ -490,20 +508,28 @@ def _format_prediction_legacy_simple_return(
490
  quantile_diffs = np.diff(raw_pred, axis=-1) if raw_pred.shape[-1] > 1 else np.array([])
491
  crossing_mask = quantile_diffs < -1e-12 if quantile_diffs.size else np.array([], dtype=bool)
492
  quantile_crossing_detected = bool(crossing_mask.any())
493
- quantile_crossing_rate = float(crossing_mask.mean()) if crossing_mask.size else 0.0
494
- sorted_pred = np.sort(raw_pred, axis=-1)
 
 
 
 
 
 
 
 
 
495
  median_sort_gap = float(
496
  np.max(np.abs(raw_pred[..., median_idx] - sorted_pred[..., median_idx]))
497
  )
498
  if quantile_crossing_detected:
499
  logger.error(
500
  "format_prediction: non-monotonic quantiles detected "
501
- "(crossing_rate=%.3f, max_median_sort_gap=%.4f); public output "
502
- "will use monotonic sorted quantiles and expose raw_quantiles for audit.",
503
- quantile_crossing_rate,
504
  median_sort_gap,
505
  )
506
- pred = sorted_pred
507
 
508
  if _math.isnan(baseline_price) or _math.isinf(baseline_price) or baseline_price <= 0:
509
  logger.warning(
@@ -643,18 +669,42 @@ def format_prediction(
643
  quantile_diffs = np.diff(raw_pred, axis=-1) if raw_pred.shape[-1] > 1 else np.array([])
644
  crossing_mask = quantile_diffs < -1e-12 if quantile_diffs.size else np.array([], dtype=bool)
645
  quantile_crossing_detected = bool(crossing_mask.any())
646
- quantile_crossing_rate = float(crossing_mask.mean()) if crossing_mask.size else 0.0
647
- sorted_pred = np.sort(raw_pred, axis=-1)
648
- median_sort_gap = float(np.max(np.abs(raw_pred[..., median_idx] - sorted_pred[..., median_idx])))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
649
  if quantile_crossing_detected:
650
  logger.error(
651
  "format_prediction: non-monotonic quantiles detected "
652
- "(crossing_rate=%.3f, max_median_sort_gap=%.4f); public output "
653
- "will use monotonic sorted quantiles and expose raw_quantiles for audit.",
654
- quantile_crossing_rate,
655
- median_sort_gap,
656
  )
657
- pred = sorted_pred
658
 
659
  if _math.isnan(baseline_price) or _math.isinf(baseline_price) or baseline_price <= 0:
660
  logger.warning(
@@ -750,8 +800,13 @@ def format_prediction(
750
  "quantiles_log": {f"q{q:.2f}": float(bounded_pred[0, i]) for i, q in enumerate(q_list)},
751
  "raw_quantiles": {f"q{q:.2f}": float(raw_pred[0, i]) for i, q in enumerate(q_list)},
752
  "quantile_crossing_detected": quantile_crossing_detected,
753
- "quantile_crossing_rate": quantile_crossing_rate,
754
- "median_sort_gap": median_sort_gap,
 
 
 
 
 
755
  "weekly_return": log_to_simple_return(weekly_log_return),
756
  "weekly_log_return": weekly_log_return,
757
  "weekly_price": _price(weekly_log_return),
 
15
  from typing import Any, Dict, Optional, Sequence
16
 
17
  import torch
18
+ import torch.nn.functional as F
19
  import numpy as np
20
 
21
  from deep_learning.contract import RETURN_SPACE, log_to_simple_return
22
  from deep_learning.config import TFTASROConfig, get_tft_config
23
+ from deep_learning.models.monotonic_quantiles import (
24
+ enforce_monotonic_quantiles,
25
+ validate_monotonicity,
26
+ )
27
  from deep_learning.models.losses import (
28
  AdaptiveSharpeRatioLoss,
29
  CombinedQuantileLoss,
 
136
  def __init__(
137
  self,
138
  quantiles: list,
139
+ lambda_weekly_quantile: float = 0.55,
140
+ lambda_t1_quantile: float = 0.15,
141
+ lambda_dispersion: float = 0.20,
142
  lambda_directional: float = 0.10,
143
+ sharpe_eps: float = 1e-8,
144
+ debug_mode: bool = False,
 
 
 
 
 
 
 
145
  ):
146
  super().__init__(quantiles=quantiles)
147
  self.lambda_weekly_quantile = lambda_weekly_quantile
148
  self.lambda_t1_quantile = lambda_t1_quantile
149
+ self.lambda_dispersion = lambda_dispersion
150
  self.lambda_directional = lambda_directional
 
 
 
 
 
 
151
  self.sharpe_eps = sharpe_eps
152
+ self.debug_mode = debug_mode
 
153
  self.median_idx = len(quantiles) // 2
154
+ self.reset_component_accumulators()
155
+
156
def reset_component_accumulators(self) -> None:
    """Zero the running per-component loss sums and the batch counter."""
    tracked = ("weekly_q", "t1_q", "dispersion", "directional", "total")
    self._component_sums = dict.fromkeys(tracked, 0.0)
    self._component_batches = 0
165
+
166
+ def _record_components(
167
+ self,
168
+ weekly_q_loss: torch.Tensor,
169
+ t1_q_loss: torch.Tensor,
170
+ dispersion_loss: torch.Tensor,
171
+ directional_loss: torch.Tensor,
172
+ total_loss: torch.Tensor,
173
+ ) -> None:
174
+ self._component_sums["weekly_q"] += float(weekly_q_loss.detach().mean().cpu())
175
+ self._component_sums["t1_q"] += float(t1_q_loss.detach().mean().cpu())
176
+ self._component_sums["dispersion"] += float(dispersion_loss.detach().mean().cpu())
177
+ self._component_sums["directional"] += float(directional_loss.detach().mean().cpu())
178
+ self._component_sums["total"] += float(total_loss.detach().mean().cpu())
179
+ self._component_batches += 1
180
+
181
def component_means(self) -> dict:
    """Summarize the accumulated loss components as per-batch means.

    Returns a dict with ``n_batches``, one ``*_loss_mean`` entry per tracked
    component plus the total, and ``dominant_component``: the name of the
    component (total excluded) with the largest accumulated sum. When no
    batches have been recorded, every mean is 0.0 and the dominant component
    is None.
    """
    batches = self._component_batches
    candidates = ("weekly_q", "t1_q", "dispersion", "directional")
    reported = (*candidates, "total")

    if batches <= 0:
        summary = {"n_batches": 0}
        for label in reported:
            summary[f"{label}_loss_mean"] = 0.0
        summary["dominant_component"] = None
        return summary

    sums = self._component_sums
    summary = {"n_batches": batches}
    for label in reported:
        summary[f"{label}_loss_mean"] = sums[label] / batches
    # On ties the earliest name in ``candidates`` wins.
    summary["dominant_component"] = max(candidates, key=lambda label: sums[label])
    return summary
209
 
210
  def _pinball(self, pred: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
211
  q = torch.tensor(self.quantiles, device=pred.device, dtype=pred.dtype).view(1, -1)
 
221
  y_actual = y_actual.float()
222
  y_pred = y_pred.float()
223
 
224
+ ordered_pred = enforce_monotonic_quantiles(
225
+ y_pred,
226
+ median_idx=self.median_idx,
227
+ min_gap=1e-5,
228
+ gap_scale=0.01,
229
+ init_bias=-3.0,
230
+ )
231
+ if self.debug_mode:
232
+ ordered_diagnostics = validate_monotonicity(ordered_pred)
233
+ assert ordered_diagnostics["is_valid"], (
234
+ f"Monotonic transform produced crossings: "
235
+ f"rate={ordered_diagnostics['crossing_rate']}, "
236
+ f"max_violation={ordered_diagnostics['max_violation']}"
237
+ )
238
+ assert torch.allclose(
239
+ ordered_pred[..., self.median_idx],
240
+ y_pred[..., self.median_idx],
241
+ rtol=1e-6,
242
+ atol=1e-7,
243
+ ), "Monotonic transform must preserve the median quantile exactly"
244
+
245
+ median_path = ordered_pred[..., self.median_idx]
246
+ pred_weekly_quantiles = ordered_pred.sum(dim=1)
247
  actual_weekly = y_actual.sum(dim=1)
248
 
249
  weekly_q_loss = self._pinball(pred_weekly_quantiles, actual_weekly)
250
+ t1_q_loss = super().loss(ordered_pred[:, 0:1, :], y_actual[:, 0:1])
251
 
252
  pred_weekly_median = median_path.sum(dim=1)
253
+ eps = self.sharpe_eps
254
+ pred_std = pred_weekly_median.std() + eps
255
+ actual_std = actual_weekly.std() + eps
256
+ dispersion_loss = torch.abs(torch.log(pred_std / actual_std))
257
 
258
+ pred_abs_med = pred_weekly_median.abs().median() + eps
259
+ actual_abs_med = actual_weekly.abs().median() + eps
260
+ magnitude_loss = torch.abs(torch.log(pred_abs_med / actual_abs_med))
261
+ combined_calibration_loss = 0.5 * dispersion_loss + 0.5 * magnitude_loss
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
+ pred_direction = torch.tanh(median_path * 10.0)
264
+ actual_direction = torch.sign(y_actual)
265
+ directional_loss = F.mse_loss(pred_direction, actual_direction.float())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
  def _to_scalar(x: torch.Tensor) -> torch.Tensor:
268
  # pytorch_forecasting metrics can return per-sample tensors;
 
270
  # boolean comparisons in tests and stable optimizer behaviour.
271
  return x.mean() if x.ndim > 0 else x
272
 
273
+ weekly_q_loss = _to_scalar(weekly_q_loss)
274
+ t1_q_loss = _to_scalar(t1_q_loss)
275
+ combined_calibration_loss = _to_scalar(combined_calibration_loss)
276
+ directional_loss = _to_scalar(directional_loss)
277
+ total_loss = (
278
  self.lambda_weekly_quantile * _to_scalar(weekly_q_loss)
279
  + self.lambda_t1_quantile * _to_scalar(t1_q_loss)
280
+ + self.lambda_dispersion * _to_scalar(combined_calibration_loss)
281
+ + self.lambda_directional * _to_scalar(directional_loss)
 
 
 
 
 
282
  )
283
+ self._record_components(
284
+ weekly_q_loss,
285
+ t1_q_loss,
286
+ combined_calibration_loss,
287
+ directional_loss,
288
+ total_loss,
289
+ )
290
+ return total_loss
291
 
292
  except ImportError:
293
  ASROPFLoss = None # type: ignore[assignment,misc]
 
323
  quantiles=quantiles,
324
  lambda_weekly_quantile=cfg.weekly_loss.lambda_weekly_quantile,
325
  lambda_t1_quantile=cfg.weekly_loss.lambda_t1_quantile,
326
+ lambda_dispersion=cfg.weekly_loss.lambda_dispersion,
327
  lambda_directional=cfg.weekly_loss.lambda_directional,
 
 
 
 
 
 
328
  )
329
  logger.info(
330
+ "Using weekly ASRO loss | weekly_q=%.2f t1_q=%.2f dispersion=%.2f dir=%.2f monotonic_transform=true",
331
  cfg.weekly_loss.lambda_weekly_quantile,
332
  cfg.weekly_loss.lambda_t1_quantile,
333
+ cfg.weekly_loss.lambda_dispersion,
334
  cfg.weekly_loss.lambda_directional,
 
 
 
 
 
 
335
  )
336
  elif use_asro and ASROPFLoss is not None:
337
  loss = ASROPFLoss(
 
508
  quantile_diffs = np.diff(raw_pred, axis=-1) if raw_pred.shape[-1] > 1 else np.array([])
509
  crossing_mask = quantile_diffs < -1e-12 if quantile_diffs.size else np.array([], dtype=bool)
510
  quantile_crossing_detected = bool(crossing_mask.any())
511
+ raw_quantile_crossing_rate = float(crossing_mask.mean()) if crossing_mask.size else 0.0
512
+ ordered_tensor = enforce_monotonic_quantiles(
513
+ torch.as_tensor(raw_pred, dtype=torch.float64),
514
+ median_idx=median_idx,
515
+ min_gap=1e-5,
516
+ gap_scale=0.01,
517
+ init_bias=-3.0,
518
+ )
519
+ pred = ordered_tensor.detach().cpu().numpy()
520
+ quantile_crossing_rate = 0.0
521
+ sorted_pred = pred
522
  median_sort_gap = float(
523
  np.max(np.abs(raw_pred[..., median_idx] - sorted_pred[..., median_idx]))
524
  )
525
  if quantile_crossing_detected:
526
  logger.error(
527
  "format_prediction: non-monotonic quantiles detected "
528
+ "(raw_crossing_rate=%.3f, max_median_sort_gap=%.4f); public output "
529
+ "uses the structural monotonic transform and exposes raw_quantiles for audit.",
530
+ raw_quantile_crossing_rate,
531
  median_sort_gap,
532
  )
 
533
 
534
  if _math.isnan(baseline_price) or _math.isinf(baseline_price) or baseline_price <= 0:
535
  logger.warning(
 
669
  quantile_diffs = np.diff(raw_pred, axis=-1) if raw_pred.shape[-1] > 1 else np.array([])
670
  crossing_mask = quantile_diffs < -1e-12 if quantile_diffs.size else np.array([], dtype=bool)
671
  quantile_crossing_detected = bool(crossing_mask.any())
672
+ raw_quantile_crossing_rate = float(crossing_mask.mean()) if crossing_mask.size else 0.0
673
+ ordered_tensor = enforce_monotonic_quantiles(
674
+ torch.as_tensor(raw_pred, dtype=torch.float64),
675
+ median_idx=median_idx,
676
+ min_gap=1e-5,
677
+ gap_scale=0.01,
678
+ init_bias=-3.0,
679
+ )
680
+ pred = ordered_tensor.detach().cpu().numpy()
681
+ ordered_diffs = np.diff(pred, axis=-1) if pred.shape[-1] > 1 else np.array([])
682
+ ordered_crossing_mask = (
683
+ ordered_diffs < -1e-12 if ordered_diffs.size else np.array([], dtype=bool)
684
+ )
685
+ ordered_quantile_crossing_rate = (
686
+ float(ordered_crossing_mask.mean()) if ordered_crossing_mask.size else 0.0
687
+ )
688
+ if ordered_quantile_crossing_rate > 0.0:
689
+ raise AssertionError(
690
+ "Monotonic quantile transform produced public crossings: "
691
+ f"{ordered_quantile_crossing_rate:.6f}"
692
+ )
693
+ sorted_raw = np.sort(raw_pred, axis=-1)
694
+ raw_median_sort_gap = float(
695
+ np.max(np.abs(raw_pred[..., median_idx] - sorted_raw[..., median_idx]))
696
+ )
697
+ ordered_median_sort_gap = float(
698
+ np.max(np.abs(pred[..., median_idx] - np.sort(pred, axis=-1)[..., median_idx]))
699
+ )
700
  if quantile_crossing_detected:
701
  logger.error(
702
  "format_prediction: non-monotonic quantiles detected "
703
+ "(raw_crossing_rate=%.3f, raw_max_median_sort_gap=%.4f); public output "
704
+ "uses the structural monotonic transform and exposes raw_quantiles for audit.",
705
+ raw_quantile_crossing_rate,
706
+ raw_median_sort_gap,
707
  )
 
708
 
709
  if _math.isnan(baseline_price) or _math.isinf(baseline_price) or baseline_price <= 0:
710
  logger.warning(
 
800
  "quantiles_log": {f"q{q:.2f}": float(bounded_pred[0, i]) for i, q in enumerate(q_list)},
801
  "raw_quantiles": {f"q{q:.2f}": float(raw_pred[0, i]) for i, q in enumerate(q_list)},
802
  "quantile_crossing_detected": quantile_crossing_detected,
803
+ "quantile_crossing_rate": ordered_quantile_crossing_rate,
804
+ "raw_quantile_crossing_rate": raw_quantile_crossing_rate,
805
+ "ordered_quantile_crossing_rate": ordered_quantile_crossing_rate,
806
+ "public_quantile_crossing_rate": ordered_quantile_crossing_rate,
807
+ "median_sort_gap": ordered_median_sort_gap,
808
+ "raw_median_sort_gap": raw_median_sort_gap,
809
+ "ordered_median_sort_gap": ordered_median_sort_gap,
810
  "weekly_return": log_to_simple_return(weekly_log_return),
811
  "weekly_log_return": weekly_log_return,
812
  "weekly_price": _price(weekly_log_return),
deep_learning/training/callbacks.py CHANGED
@@ -81,6 +81,47 @@ class CurriculumLossScheduler(pl.Callback):
81
  )
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  class SWACallback(pl.Callback):
85
  """
86
  Stochastic Weight Averaging over the last ``swa_pct`` of training.
 
81
  )
82
 
83
 
84
class WeeklyLossComponentLogger(pl.Callback):
    """Log weekly loss component scales at validation epoch boundaries.

    The callback is duck-typed against ``pl_module.loss``: it only acts when
    the loss object exposes ``reset_component_accumulators()`` and
    ``component_means()``, so it is a no-op for losses that do not track
    per-component statistics.
    """

    # Warn when the mean dispersion loss exceeds this multiple of the mean
    # weekly quantile loss.
    DISPERSION_DOMINANCE_RATIO = 3.0
    # Warn when the mean directional loss is below this share of the mean
    # total loss.
    DIRECTIONAL_MIN_SHARE = 0.05
    # Floor applied to denominators so a zero reference loss cannot make the
    # comparisons degenerate.
    _EPS = 1e-12

    def on_validation_epoch_start(self, trainer, pl_module):
        """Reset the loss accumulators (if supported) before validation."""
        loss = getattr(pl_module, "loss", None)
        if hasattr(loss, "reset_component_accumulators"):
            loss.reset_component_accumulators()

    def on_validation_epoch_end(self, trainer, pl_module):
        """Log the per-component means and warn on imbalanced components."""
        loss = getattr(pl_module, "loss", None)
        if not hasattr(loss, "component_means"):
            return

        stats = loss.component_means()
        # Nothing was accumulated during this validation pass.
        if not stats.get("n_batches"):
            return

        epoch = getattr(trainer, "current_epoch", 0)
        logger.info(
            "Weekly loss components | epoch=%s weekly_q=%.6f t1_q=%.6f "
            "dispersion=%.6f directional=%.6f total=%.6f dominant=%s",
            epoch,
            stats["weekly_q_loss_mean"],
            stats["t1_q_loss_mean"],
            stats["dispersion_loss_mean"],
            stats["directional_loss_mean"],
            stats["total_loss_mean"],
            stats["dominant_component"],
        )

        weekly_q_ref = max(stats["weekly_q_loss_mean"], self._EPS)
        if stats["dispersion_loss_mean"] > self.DISPERSION_DOMINANCE_RATIO * weekly_q_ref:
            logger.warning(
                "Weekly dispersion loss is dominating weekly quantile loss; "
                "lambda_dispersion may need to be reduced."
            )

        total_ref = max(stats["total_loss_mean"], self._EPS)
        if stats["directional_loss_mean"] < self.DIRECTIONAL_MIN_SHARE * total_ref:
            # The percentage is passed as a format argument: logging applies
            # %-formatting only when arguments are supplied, so an argument-less
            # message containing "%%" would be emitted with both percent signs.
            logger.warning(
                "Weekly directional loss is below %.0f%% of total loss; "
                "lambda_directional may need to increase.",
                100.0 * self.DIRECTIONAL_MIN_SHARE,
            )
123
+
124
+
125
  class SWACallback(pl.Callback):
126
  """
127
  Stochastic Weight Averaging over the last ``swa_pct`` of training.
deep_learning/training/hyperopt.py CHANGED
@@ -13,6 +13,7 @@ from __future__ import annotations
13
  import argparse
14
  import json
15
  import logging
 
16
  import warnings
17
  from dataclasses import replace
18
  from pathlib import Path
@@ -37,6 +38,16 @@ from deep_learning.config import (
37
  )
38
  from deep_learning.logging_utils import configure_cli_logging, suppress_lightning_noise
39
 
 
 
 
 
 
 
 
 
 
 
40
  logger = logging.getLogger(__name__)
41
 
42
  MIN_COMPLETED_TRIALS = 10
@@ -55,15 +66,10 @@ KNOWN_GOOD_TRIAL_PARAMS = {
55
  "lambda_vol": 0.30,
56
  "lambda_quantile": 0.25,
57
  "lambda_madl": 0.40,
58
- "lambda_weekly_quantile": 0.60,
59
- "lambda_t1_quantile": 0.10,
 
60
  "lambda_directional": 0.10,
61
- "lambda_magnitude": 0.55,
62
- "weekly_lambda_vol": 0.35,
63
- "lambda_width": 0.50,
64
- "lambda_tail_width": 0.30,
65
- "lambda_sanity": 0.20,
66
- "lambda_crossing": 7.0,
67
  "batch_size": 32,
68
  }
69
 
@@ -144,7 +150,9 @@ def _build_prune_diagnostics(study) -> tuple[dict[str, int], list[dict]]:
144
  "avg_variance_ratio",
145
  "avg_directional_accuracy",
146
  "avg_val_sharpe",
 
147
  "avg_quantile_crossing_rate",
 
148
  "avg_median_sort_gap",
149
  "avg_weekly_magnitude_ratio",
150
  "avg_weekly_pi80_coverage",
@@ -180,6 +188,8 @@ def _build_result_payload(study) -> dict:
180
  trial_state_counts = _trial_state_counts(study)
181
  best = _best_finite_completed_trial(study)
182
  prune_reasons, fold_diagnostics = _build_prune_diagnostics(study)
 
 
183
 
184
  if best is None:
185
  return {
@@ -191,6 +201,9 @@ def _build_result_payload(study) -> dict:
191
  "trial_state_counts": trial_state_counts,
192
  "prune_reasons": prune_reasons,
193
  "fold_diagnostics": fold_diagnostics,
 
 
 
194
  "message": (
195
  "No Optuna trials completed with a finite objective value; "
196
  "final training will use the known-good fallback config "
@@ -198,6 +211,11 @@ def _build_result_payload(study) -> dict:
198
  ),
199
  }
200
 
 
 
 
 
 
201
  return {
202
  "status": "completed",
203
  "best_trial": best.number,
@@ -207,6 +225,9 @@ def _build_result_payload(study) -> dict:
207
  "trial_state_counts": trial_state_counts,
208
  "prune_reasons": prune_reasons,
209
  "fold_diagnostics": fold_diagnostics,
 
 
 
210
  }
211
 
212
 
@@ -266,15 +287,10 @@ def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
266
  )
267
 
268
  weekly_loss_cfg = WeeklyLossConfig(
269
- lambda_weekly_quantile=trial.suggest_float("lambda_weekly_quantile", 0.60, 0.75, step=0.05),
270
- lambda_t1_quantile=trial.suggest_float("lambda_t1_quantile", 0.05, 0.15, step=0.05),
 
271
  lambda_directional=trial.suggest_float("lambda_directional", 0.05, 0.12, step=0.01),
272
- lambda_magnitude=trial.suggest_float("lambda_magnitude", 0.50, 0.80, step=0.05),
273
- lambda_vol=trial.suggest_float("weekly_lambda_vol", 0.25, 0.45, step=0.05),
274
- lambda_crossing=trial.suggest_float("lambda_crossing", 5.0, 10.0, step=1.0),
275
- lambda_sanity=trial.suggest_float("lambda_sanity", 0.10, 0.30, step=0.05),
276
- lambda_width=trial.suggest_float("lambda_width", 0.40, 0.90, step=0.05),
277
- lambda_tail_width=trial.suggest_float("lambda_tail_width", 0.25, 0.75, step=0.05),
278
  )
279
 
280
  training_cfg = TrainingConfig(
@@ -343,6 +359,7 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
343
  from deep_learning.training.callbacks import CurriculumLossScheduler
344
  from deep_learning.training.metrics import (
345
  compute_weekly_metrics,
 
346
  quantile_crossing_rate,
347
  quantile_median_sort_gap,
348
  select_prediction_horizon,
@@ -367,7 +384,9 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
367
  fold_sharpe_list: list[float] = []
368
  fold_vr_list: list[float] = []
369
  fold_crossing_list: list[float] = []
 
370
  fold_median_gap_list: list[float] = []
 
371
  fold_weekly_objectives: list[float] = []
372
  fold_weekly_mr_list: list[float] = []
373
  fold_weekly_pi80_coverage_list: list[float] = []
@@ -469,10 +488,14 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
469
  raise ValueError(
470
  f"Prediction horizon too short: {pred_np.shape[1]} < {trial_cfg.forecast.primary_horizon_days}"
471
  )
472
- pred_t1 = pred_np[:, 0, :]
 
 
473
  y_pred = pred_t1[:, median_idx]
474
  fold_crossing_rate = quantile_crossing_rate(pred_t1)
 
475
  _, fold_median_gap = quantile_median_sort_gap(pred_t1, median_idx)
 
476
 
477
  y_actual_parts = []
478
  for batch in fold_val_dl:
@@ -510,7 +533,7 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
510
  )
511
  weekly_pinball = _weekly_pinball_loss(
512
  y_actual_path[:n_path],
513
- pred_np[:n_path],
514
  tuple(trial_cfg.model.quantiles),
515
  horizon=trial_cfg.forecast.primary_horizon_days,
516
  )
@@ -518,9 +541,9 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
518
  fold_weekly_pi80_coverage = float(weekly.get("weekly_pi80_coverage", 0.0))
519
  fold_weekly_pi80_width_ratio = float(weekly.get("weekly_pi80_width_ratio", 1.0))
520
  fold_weekly_pi96_width_ratio = float(weekly.get("weekly_pi96_width_ratio", 1.0))
521
- fold_weekly_raw_crossing = float(weekly.get("weekly_quantile_crossing_rate", 0.0))
522
  fold_weekly_sorted_crossing = float(
523
- weekly.get("weekly_sorted_quantile_crossing_rate", 0.0)
524
  )
525
  fold_weekly_interval_score_80 = float(weekly.get("weekly_interval_score_80", 0.0))
526
  fold_weekly_interval_score_96 = float(weekly.get("weekly_interval_score_96", 0.0))
@@ -530,7 +553,6 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
530
  coverage_penalty = abs(fold_weekly_pi80_coverage - 0.80)
531
  width_penalty = max(0.0, fold_weekly_pi80_width_ratio - 1.5)
532
  tail_width_penalty = max(0.0, fold_weekly_pi96_width_ratio - 3.0)
533
- raw_crossing_penalty = max(0.0, fold_weekly_raw_crossing - 0.05)
534
  fold_weekly_objective = (
535
  0.35 * weekly_pinball
536
  + 0.15 * (1.0 - float(weekly.get("weekly_directional_accuracy", 0.5)))
@@ -540,7 +562,6 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
540
  + 0.35 * tail_width_penalty
541
  + 0.10 * interval_score_penalty
542
  + 0.05 * interval_score_96_penalty
543
- + 0.50 * raw_crossing_penalty
544
  + 0.25 * fold_weekly_sorted_crossing
545
  )
546
  except Exception as exc:
@@ -553,7 +574,9 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
553
  fold_da_list.append(fold_da)
554
  fold_sharpe_list.append(fold_sharpe)
555
  fold_crossing_list.append(fold_crossing_rate)
 
556
  fold_median_gap_list.append(fold_median_gap)
 
557
  fold_weekly_objectives.append(fold_weekly_objective)
558
  fold_weekly_mr_list.append(fold_weekly_mr)
559
  fold_weekly_pi80_coverage_list.append(fold_weekly_pi80_coverage)
@@ -671,7 +694,13 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
671
  avg_sharpe = float(np.mean(fold_sharpe_list)) if fold_sharpe_list else 0.0
672
  avg_vr = float(np.mean(fold_vr_list)) if fold_vr_list else 0.0
673
  avg_crossing = float(np.mean(fold_crossing_list)) if fold_crossing_list else 0.0
 
 
 
674
  avg_median_gap = float(np.mean(fold_median_gap_list)) if fold_median_gap_list else 0.0
 
 
 
675
  avg_weekly_mr = float(np.mean(fold_weekly_mr_list)) if fold_weekly_mr_list else 1.0
676
  avg_weekly_pi80_coverage = (
677
  float(np.mean(fold_weekly_pi80_coverage_list))
@@ -717,7 +746,9 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
717
  trial.set_user_attr("avg_variance_ratio", round(avg_vr, 4))
718
  trial.set_user_attr("avg_directional_accuracy", round(avg_da, 4))
719
  trial.set_user_attr("avg_val_sharpe", round(avg_sharpe, 4))
 
720
  trial.set_user_attr("avg_quantile_crossing_rate", round(avg_crossing, 4))
 
721
  trial.set_user_attr("avg_median_sort_gap", round(avg_median_gap, 4))
722
  trial.set_user_attr("avg_weekly_magnitude_ratio", round(avg_weekly_mr, 4))
723
  trial.set_user_attr("avg_weekly_pi80_coverage", round(avg_weekly_pi80_coverage, 4))
@@ -741,21 +772,11 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
741
  trial.set_user_attr("prune_reason", "sharpe_prune")
742
  raise optuna.exceptions.TrialPruned()
743
 
744
- if (avg_crossing > 0.20 or avg_median_gap > 0.01) and not protect_trial:
745
- logger.warning(
746
- "Trial %d PRUNED: quantile incoherence crossing=%.3f median_gap=%.4f",
747
- trial.number, avg_crossing, avg_median_gap,
748
- )
749
- trial.set_user_attr("prune_reason", "crossing_prune")
750
- raise optuna.exceptions.TrialPruned()
751
-
752
- if (avg_weekly_raw_crossing > 0.05 or avg_weekly_sorted_crossing > 0.0) and not protect_trial:
753
- logger.warning(
754
- "Trial %d PRUNED: weekly quantile incoherence raw=%.3f sorted=%.3f",
755
- trial.number, avg_weekly_raw_crossing, avg_weekly_sorted_crossing,
756
  )
757
- trial.set_user_attr("prune_reason", "weekly_raw_crossing_prune")
758
- raise optuna.exceptions.TrialPruned()
759
 
760
  # Soft penalty: avg DA below coin-flip
761
  da_penalty = 2.0 * max(0.0, 0.50 - avg_da) if avg_da < 0.50 else 0.0
@@ -825,6 +846,15 @@ def run_hyperopt(
825
  results_path.parent.mkdir(parents=True, exist_ok=True)
826
  result = _build_result_payload(study)
827
  results_path.write_text(json.dumps(result, indent=2, allow_nan=False))
 
 
 
 
 
 
 
 
 
828
 
829
  if result["best_trial"] is None:
830
  logger.warning(
@@ -841,6 +871,10 @@ def run_hyperopt(
841
  )
842
  logger.info("Best params: %s", result["best_params"])
843
 
 
 
 
 
844
  return result
845
 
846
 
 
13
  import argparse
14
  import json
15
  import logging
16
+ import sys
17
  import warnings
18
  from dataclasses import replace
19
  from pathlib import Path
 
38
  )
39
  from deep_learning.logging_utils import configure_cli_logging, suppress_lightning_noise
40
 
41
+ PROJECT_ROOT = Path(__file__).resolve().parents[3]
42
+ if str(PROJECT_ROOT) not in sys.path:
43
+ sys.path.insert(0, str(PROJECT_ROOT))
44
+
45
+ from scripts.hyperopt_diagnostics import (
46
+ best_trial_preflight_check,
47
+ compute_structural_invalidity_report,
48
+ compute_trial_distribution_summary,
49
+ )
50
+
51
  logger = logging.getLogger(__name__)
52
 
53
  MIN_COMPLETED_TRIALS = 10
 
66
  "lambda_vol": 0.30,
67
  "lambda_quantile": 0.25,
68
  "lambda_madl": 0.40,
69
+ "lambda_weekly_quantile": 0.55,
70
+ "lambda_t1_quantile": 0.15,
71
+ "lambda_dispersion": 0.20,
72
  "lambda_directional": 0.10,
 
 
 
 
 
 
73
  "batch_size": 32,
74
  }
75
 
 
150
  "avg_variance_ratio",
151
  "avg_directional_accuracy",
152
  "avg_val_sharpe",
153
+ "avg_raw_quantile_crossing_rate",
154
  "avg_quantile_crossing_rate",
155
+ "avg_raw_median_sort_gap",
156
  "avg_median_sort_gap",
157
  "avg_weekly_magnitude_ratio",
158
  "avg_weekly_pi80_coverage",
 
188
  trial_state_counts = _trial_state_counts(study)
189
  best = _best_finite_completed_trial(study)
190
  prune_reasons, fold_diagnostics = _build_prune_diagnostics(study)
191
+ structural_report = compute_structural_invalidity_report(fold_diagnostics)
192
+ distribution_summary = compute_trial_distribution_summary(fold_diagnostics)
193
 
194
  if best is None:
195
  return {
 
201
  "trial_state_counts": trial_state_counts,
202
  "prune_reasons": prune_reasons,
203
  "fold_diagnostics": fold_diagnostics,
204
+ "structural_invalidity_report": structural_report,
205
+ "trial_distribution_summary": distribution_summary,
206
+ "best_trial_preflight": None,
207
  "message": (
208
  "No Optuna trials completed with a finite objective value; "
209
  "final training will use the known-good fallback config "
 
211
  ),
212
  }
213
 
214
+ best_diagnostics = next(
215
+ (d for d in fold_diagnostics if d.get("trial") == best.number),
216
+ {},
217
+ )
218
+ preflight = best_trial_preflight_check(best_diagnostics)
219
  return {
220
  "status": "completed",
221
  "best_trial": best.number,
 
225
  "trial_state_counts": trial_state_counts,
226
  "prune_reasons": prune_reasons,
227
  "fold_diagnostics": fold_diagnostics,
228
+ "structural_invalidity_report": structural_report,
229
+ "trial_distribution_summary": distribution_summary,
230
+ "best_trial_preflight": preflight,
231
  }
232
 
233
 
 
287
  )
288
 
289
  weekly_loss_cfg = WeeklyLossConfig(
290
+ lambda_weekly_quantile=trial.suggest_float("lambda_weekly_quantile", 0.45, 0.65, step=0.05),
291
+ lambda_t1_quantile=trial.suggest_float("lambda_t1_quantile", 0.05, 0.20, step=0.05),
292
+ lambda_dispersion=trial.suggest_float("lambda_dispersion", 0.15, 0.35, step=0.05),
293
  lambda_directional=trial.suggest_float("lambda_directional", 0.05, 0.12, step=0.01),
 
 
 
 
 
 
294
  )
295
 
296
  training_cfg = TrainingConfig(
 
359
  from deep_learning.training.callbacks import CurriculumLossScheduler
360
  from deep_learning.training.metrics import (
361
  compute_weekly_metrics,
362
+ monotonic_quantiles_np,
363
  quantile_crossing_rate,
364
  quantile_median_sort_gap,
365
  select_prediction_horizon,
 
384
  fold_sharpe_list: list[float] = []
385
  fold_vr_list: list[float] = []
386
  fold_crossing_list: list[float] = []
387
+ fold_raw_crossing_list: list[float] = []
388
  fold_median_gap_list: list[float] = []
389
+ fold_raw_median_gap_list: list[float] = []
390
  fold_weekly_objectives: list[float] = []
391
  fold_weekly_mr_list: list[float] = []
392
  fold_weekly_pi80_coverage_list: list[float] = []
 
488
  raise ValueError(
489
  f"Prediction horizon too short: {pred_np.shape[1]} < {trial_cfg.forecast.primary_horizon_days}"
490
  )
491
+ ordered_pred_np = monotonic_quantiles_np(pred_np, median_idx=median_idx)
492
+ raw_pred_t1 = pred_np[:, 0, :]
493
+ pred_t1 = ordered_pred_np[:, 0, :]
494
  y_pred = pred_t1[:, median_idx]
495
  fold_crossing_rate = quantile_crossing_rate(pred_t1)
496
+ fold_raw_crossing_rate = quantile_crossing_rate(raw_pred_t1)
497
  _, fold_median_gap = quantile_median_sort_gap(pred_t1, median_idx)
498
+ _, fold_raw_median_gap = quantile_median_sort_gap(raw_pred_t1, median_idx)
499
 
500
  y_actual_parts = []
501
  for batch in fold_val_dl:
 
533
  )
534
  weekly_pinball = _weekly_pinball_loss(
535
  y_actual_path[:n_path],
536
+ ordered_pred_np[:n_path],
537
  tuple(trial_cfg.model.quantiles),
538
  horizon=trial_cfg.forecast.primary_horizon_days,
539
  )
 
541
  fold_weekly_pi80_coverage = float(weekly.get("weekly_pi80_coverage", 0.0))
542
  fold_weekly_pi80_width_ratio = float(weekly.get("weekly_pi80_width_ratio", 1.0))
543
  fold_weekly_pi96_width_ratio = float(weekly.get("weekly_pi96_width_ratio", 1.0))
544
+ fold_weekly_raw_crossing = float(weekly.get("weekly_raw_quantile_crossing_rate", 0.0))
545
  fold_weekly_sorted_crossing = float(
546
+ weekly.get("weekly_ordered_quantile_crossing_rate", 0.0)
547
  )
548
  fold_weekly_interval_score_80 = float(weekly.get("weekly_interval_score_80", 0.0))
549
  fold_weekly_interval_score_96 = float(weekly.get("weekly_interval_score_96", 0.0))
 
553
  coverage_penalty = abs(fold_weekly_pi80_coverage - 0.80)
554
  width_penalty = max(0.0, fold_weekly_pi80_width_ratio - 1.5)
555
  tail_width_penalty = max(0.0, fold_weekly_pi96_width_ratio - 3.0)
 
556
  fold_weekly_objective = (
557
  0.35 * weekly_pinball
558
  + 0.15 * (1.0 - float(weekly.get("weekly_directional_accuracy", 0.5)))
 
562
  + 0.35 * tail_width_penalty
563
  + 0.10 * interval_score_penalty
564
  + 0.05 * interval_score_96_penalty
 
565
  + 0.25 * fold_weekly_sorted_crossing
566
  )
567
  except Exception as exc:
 
574
  fold_da_list.append(fold_da)
575
  fold_sharpe_list.append(fold_sharpe)
576
  fold_crossing_list.append(fold_crossing_rate)
577
+ fold_raw_crossing_list.append(fold_raw_crossing_rate)
578
  fold_median_gap_list.append(fold_median_gap)
579
+ fold_raw_median_gap_list.append(fold_raw_median_gap)
580
  fold_weekly_objectives.append(fold_weekly_objective)
581
  fold_weekly_mr_list.append(fold_weekly_mr)
582
  fold_weekly_pi80_coverage_list.append(fold_weekly_pi80_coverage)
 
694
  avg_sharpe = float(np.mean(fold_sharpe_list)) if fold_sharpe_list else 0.0
695
  avg_vr = float(np.mean(fold_vr_list)) if fold_vr_list else 0.0
696
  avg_crossing = float(np.mean(fold_crossing_list)) if fold_crossing_list else 0.0
697
+ avg_raw_crossing = (
698
+ float(np.mean(fold_raw_crossing_list)) if fold_raw_crossing_list else 0.0
699
+ )
700
  avg_median_gap = float(np.mean(fold_median_gap_list)) if fold_median_gap_list else 0.0
701
+ avg_raw_median_gap = (
702
+ float(np.mean(fold_raw_median_gap_list)) if fold_raw_median_gap_list else 0.0
703
+ )
704
  avg_weekly_mr = float(np.mean(fold_weekly_mr_list)) if fold_weekly_mr_list else 1.0
705
  avg_weekly_pi80_coverage = (
706
  float(np.mean(fold_weekly_pi80_coverage_list))
 
746
  trial.set_user_attr("avg_variance_ratio", round(avg_vr, 4))
747
  trial.set_user_attr("avg_directional_accuracy", round(avg_da, 4))
748
  trial.set_user_attr("avg_val_sharpe", round(avg_sharpe, 4))
749
+ trial.set_user_attr("avg_raw_quantile_crossing_rate", round(avg_raw_crossing, 4))
750
  trial.set_user_attr("avg_quantile_crossing_rate", round(avg_crossing, 4))
751
+ trial.set_user_attr("avg_raw_median_sort_gap", round(avg_raw_median_gap, 4))
752
  trial.set_user_attr("avg_median_sort_gap", round(avg_median_gap, 4))
753
  trial.set_user_attr("avg_weekly_magnitude_ratio", round(avg_weekly_mr, 4))
754
  trial.set_user_attr("avg_weekly_pi80_coverage", round(avg_weekly_pi80_coverage, 4))
 
772
  trial.set_user_attr("prune_reason", "sharpe_prune")
773
  raise optuna.exceptions.TrialPruned()
774
 
775
+ if avg_crossing > 0.001 or avg_weekly_sorted_crossing > 0.001:
776
+ raise RuntimeError(
777
+ "Monotonic quantile transform produced public crossings: "
778
+ f"daily={avg_crossing:.6f}, weekly={avg_weekly_sorted_crossing:.6f}"
 
 
 
 
 
 
 
 
779
  )
 
 
780
 
781
  # Soft penalty: avg DA below coin-flip
782
  da_penalty = 2.0 * max(0.0, 0.50 - avg_da) if avg_da < 0.50 else 0.0
 
846
  results_path.parent.mkdir(parents=True, exist_ok=True)
847
  result = _build_result_payload(study)
848
  results_path.write_text(json.dumps(result, indent=2, allow_nan=False))
849
+ logger.info(
850
+ "Optuna structural invalidity report: %s",
851
+ result.get("structural_invalidity_report"),
852
+ )
853
+ logger.info(
854
+ "Optuna trial distribution summary: %s",
855
+ result.get("trial_distribution_summary"),
856
+ )
857
+ logger.info("Optuna best trial preflight: %s", result.get("best_trial_preflight"))
858
 
859
  if result["best_trial"] is None:
860
  logger.warning(
 
871
  )
872
  logger.info("Best params: %s", result["best_params"])
873
 
874
+ structural_report = result.get("structural_invalidity_report") or {}
875
+ if structural_report.get("verdict") == "STRUCTURAL_FAILURE":
876
+ raise RuntimeError(structural_report.get("next_action", "Structural failure in hyperopt."))
877
+
878
  return result
879
 
880
 
deep_learning/training/metrics.py CHANGED
@@ -13,6 +13,9 @@ from __future__ import annotations
13
 
14
  import numpy as np
15
  import pandas as pd
 
 
 
16
 
17
 
18
  def select_prediction_horizon(values: np.ndarray, horizon_idx: int = 0) -> np.ndarray:
@@ -53,6 +56,27 @@ def cumulative_quantiles(pred: np.ndarray, horizon: int = 5) -> np.ndarray:
53
  return arr[:, :horizon, :].sum(axis=1)
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def magnitude_ratio(y_actual: np.ndarray, y_pred: np.ndarray) -> float:
57
  """Median predicted absolute move divided by median actual absolute move."""
58
  denom = np.median(np.abs(np.asarray(y_actual, dtype=np.float64)))
@@ -288,18 +312,24 @@ def compute_all_metrics(
288
 
289
  if y_pred_quantiles is not None:
290
  q_arr = np.asarray(y_pred_quantiles, dtype=np.float64)
291
- sorted_q = np.sort(q_arr, axis=-1)
292
  raw_crossing = quantile_crossing_rate(q_arr)
293
- sorted_crossing = quantile_crossing_rate(sorted_q)
294
- metrics["quantile_crossing_rate"] = raw_crossing
295
  metrics["raw_quantile_crossing_rate"] = raw_crossing
296
- metrics["sorted_quantile_crossing_rate"] = sorted_crossing
 
 
297
  gap_mean, gap_max = quantile_median_sort_gap(q_arr)
298
- metrics["median_sort_gap_mean"] = gap_mean
299
- metrics["median_sort_gap_max"] = gap_max
300
- sorted_gap_mean, sorted_gap_max = quantile_median_sort_gap(sorted_q)
301
- metrics["sorted_median_sort_gap_mean"] = sorted_gap_mean
302
- metrics["sorted_median_sort_gap_max"] = sorted_gap_max
 
 
 
 
303
 
304
  return metrics
305
 
@@ -317,8 +347,10 @@ def compute_weekly_metrics(
317
  to simple returns happens only during inference formatting.
318
  """
319
  weekly_actual = cumulative_horizon(y_actual_path, horizon=horizon)
320
- approx_weekly_quantiles = cumulative_quantiles(y_pred_quantiles_path, horizon=horizon)
321
- weekly_quantiles = np.sort(approx_weekly_quantiles, axis=-1)
 
 
322
 
323
  median_idx = len(quantiles) // 2
324
  q10_idx = quantiles.index(0.10)
@@ -340,18 +372,27 @@ def compute_weekly_metrics(
340
  y_pred_q90=weekly_quantiles[:, q90_idx],
341
  y_pred_q02=weekly_quantiles[:, q02_idx],
342
  y_pred_q98=weekly_quantiles[:, q98_idx],
343
- y_pred_quantiles=approx_weekly_quantiles,
344
  tail_threshold=tail_threshold,
345
  )
346
 
347
  weekly_metrics = {f"weekly_{k}": v for k, v in metrics.items()}
348
  weekly_metrics["weekly_interval_quantile_source"] = 1.0
349
  weekly_metrics["weekly_approx_quantile_crossing_rate"] = quantile_crossing_rate(
350
- approx_weekly_quantiles
351
  )
352
- approx_gap_mean, approx_gap_max = quantile_median_sort_gap(approx_weekly_quantiles)
353
  weekly_metrics["weekly_approx_median_sort_gap_mean"] = approx_gap_mean
354
  weekly_metrics["weekly_approx_median_sort_gap_max"] = approx_gap_max
 
 
 
 
 
 
 
 
 
355
  weekly_metrics["weekly_magnitude_ratio"] = magnitude_ratio(weekly_actual, weekly_pred)
356
  weekly_metrics["weekly_mean_actual_abs"] = float(np.mean(np.abs(weekly_actual)))
357
  weekly_metrics["weekly_mean_pred_abs"] = float(np.mean(np.abs(weekly_pred)))
 
13
 
14
  import numpy as np
15
  import pandas as pd
16
+ import torch
17
+
18
+ from deep_learning.models.monotonic_quantiles import enforce_monotonic_quantiles
19
 
20
 
21
  def select_prediction_horizon(values: np.ndarray, horizon_idx: int = 0) -> np.ndarray:
 
56
  return arr[:, :horizon, :].sum(axis=1)
57
 
58
 
59
def monotonic_quantiles_np(
    pred: np.ndarray,
    median_idx: int | None = None,
) -> np.ndarray:
    """Apply the production monotonic quantile transform to a numpy tensor.

    Args:
        pred: Array-like of quantile predictions; the last axis is treated as
            the quantile axis. It is converted to ``float64`` before the
            transform is applied.
        median_idx: Index of the median quantile along the last axis.
            Defaults to the middle index (``n_quantiles // 2``).

    Returns:
        A ``float64`` numpy array produced by ``enforce_monotonic_quantiles``.
        Empty inputs (no samples or no quantiles) are returned as an empty
        ``float64`` copy without invoking the torch transform.

    Raises:
        IndexError: If ``pred`` is 0-dimensional and therefore has no
            quantile axis.
    """
    arr = np.asarray(pred, dtype=np.float64)
    # Evaluated up front so a 0-d input still fails with IndexError.
    n_quantiles = arr.shape[-1]
    # Covers both a zero-length quantile axis and an empty batch such as
    # shape (0, Q); there is nothing to order in either case.
    if arr.size == 0:
        return arr.copy()
    if median_idx is None:
        median_idx = n_quantiles // 2
    tensor = torch.as_tensor(arr, dtype=torch.float64)
    ordered = enforce_monotonic_quantiles(
        tensor,
        median_idx=median_idx,
        min_gap=1e-5,
        gap_scale=0.01,
        init_bias=-3.0,
    )
    return ordered.detach().cpu().numpy()
78
+
79
+
80
  def magnitude_ratio(y_actual: np.ndarray, y_pred: np.ndarray) -> float:
81
  """Median predicted absolute move divided by median actual absolute move."""
82
  denom = np.median(np.abs(np.asarray(y_actual, dtype=np.float64)))
 
312
 
313
  if y_pred_quantiles is not None:
314
  q_arr = np.asarray(y_pred_quantiles, dtype=np.float64)
315
+ ordered_q = monotonic_quantiles_np(q_arr)
316
  raw_crossing = quantile_crossing_rate(q_arr)
317
+ ordered_crossing = quantile_crossing_rate(ordered_q)
318
+ metrics["quantile_crossing_rate"] = ordered_crossing
319
  metrics["raw_quantile_crossing_rate"] = raw_crossing
320
+ metrics["ordered_quantile_crossing_rate"] = ordered_crossing
321
+ metrics["public_quantile_crossing_rate"] = ordered_crossing
322
+ metrics["sorted_quantile_crossing_rate"] = ordered_crossing
323
  gap_mean, gap_max = quantile_median_sort_gap(q_arr)
324
+ metrics["raw_median_sort_gap_mean"] = gap_mean
325
+ metrics["raw_median_sort_gap_max"] = gap_max
326
+ ordered_gap_mean, ordered_gap_max = quantile_median_sort_gap(ordered_q)
327
+ metrics["median_sort_gap_mean"] = ordered_gap_mean
328
+ metrics["median_sort_gap_max"] = ordered_gap_max
329
+ metrics["ordered_median_sort_gap_mean"] = ordered_gap_mean
330
+ metrics["ordered_median_sort_gap_max"] = ordered_gap_max
331
+ metrics["sorted_median_sort_gap_mean"] = ordered_gap_mean
332
+ metrics["sorted_median_sort_gap_max"] = ordered_gap_max
333
 
334
  return metrics
335
 
 
347
  to simple returns happens only during inference formatting.
348
  """
349
  weekly_actual = cumulative_horizon(y_actual_path, horizon=horizon)
350
+ raw_path = np.asarray(y_pred_quantiles_path, dtype=np.float64)
351
+ ordered_path = monotonic_quantiles_np(raw_path, median_idx=len(quantiles) // 2)
352
+ raw_weekly_quantiles = cumulative_quantiles(raw_path, horizon=horizon)
353
+ weekly_quantiles = cumulative_quantiles(ordered_path, horizon=horizon)
354
 
355
  median_idx = len(quantiles) // 2
356
  q10_idx = quantiles.index(0.10)
 
372
  y_pred_q90=weekly_quantiles[:, q90_idx],
373
  y_pred_q02=weekly_quantiles[:, q02_idx],
374
  y_pred_q98=weekly_quantiles[:, q98_idx],
375
+ y_pred_quantiles=weekly_quantiles,
376
  tail_threshold=tail_threshold,
377
  )
378
 
379
  weekly_metrics = {f"weekly_{k}": v for k, v in metrics.items()}
380
  weekly_metrics["weekly_interval_quantile_source"] = 1.0
381
  weekly_metrics["weekly_approx_quantile_crossing_rate"] = quantile_crossing_rate(
382
+ raw_weekly_quantiles
383
  )
384
+ approx_gap_mean, approx_gap_max = quantile_median_sort_gap(raw_weekly_quantiles)
385
  weekly_metrics["weekly_approx_median_sort_gap_mean"] = approx_gap_mean
386
  weekly_metrics["weekly_approx_median_sort_gap_max"] = approx_gap_max
387
+ weekly_metrics["weekly_raw_quantile_crossing_rate"] = quantile_crossing_rate(
388
+ raw_weekly_quantiles
389
+ )
390
+ weekly_metrics["weekly_ordered_quantile_crossing_rate"] = quantile_crossing_rate(
391
+ weekly_quantiles
392
+ )
393
+ weekly_metrics["weekly_public_quantile_crossing_rate"] = weekly_metrics[
394
+ "weekly_ordered_quantile_crossing_rate"
395
+ ]
396
  weekly_metrics["weekly_magnitude_ratio"] = magnitude_ratio(weekly_actual, weekly_pred)
397
  weekly_metrics["weekly_mean_actual_abs"] = float(np.mean(np.abs(weekly_actual)))
398
  weekly_metrics["weekly_mean_pred_abs"] = float(np.mean(np.abs(weekly_pred)))
deep_learning/training/trainer.py CHANGED
@@ -47,7 +47,7 @@ warnings.filterwarnings(
47
  logger = logging.getLogger(__name__)
48
 
49
  KNOWN_GOOD_CONFIG = {
50
- "max_encoder_length": 60,
51
  "hidden_size": 48,
52
  "attention_head_size": 2,
53
  "dropout": 0.30,
@@ -57,18 +57,15 @@ KNOWN_GOOD_CONFIG = {
57
  "lambda_vol": 0.30,
58
  "lambda_quantile": 0.25,
59
  "lambda_madl": 0.40,
60
- "lambda_weekly_quantile": 0.60,
61
- "lambda_t1_quantile": 0.10,
 
62
  "lambda_directional": 0.10,
63
- "lambda_magnitude": 0.55,
64
- "weekly_lambda_vol": 0.35,
65
- "lambda_width": 0.50,
66
- "lambda_tail_width": 0.30,
67
- "lambda_sanity": 0.20,
68
- "lambda_crossing": 7.0,
69
  "batch_size": 32,
70
  }
71
 
 
 
72
  REQUIRED_PROMOTABLE_METRICS = (
73
  "weekly_directional_accuracy",
74
  "weekly_magnitude_ratio",
@@ -126,13 +123,22 @@ def _compute_test_metrics_from_quantiles(
126
  pred_np: np.ndarray,
127
  cfg: TFTASROConfig,
128
  ) -> dict[str, float]:
129
- from deep_learning.training.metrics import compute_all_metrics, compute_weekly_metrics, select_prediction_horizon
 
 
 
 
 
 
 
130
 
131
  pred_np = np.asarray(pred_np)
132
  _validate_quantile_prediction_shape(pred_np, cfg)
133
 
134
  median_idx = len(cfg.model.quantiles) // 2
135
- pred_t1 = pred_np[:, 0, :]
 
 
136
  y_pred_median = pred_t1[:, median_idx]
137
  y_pred_q10 = pred_t1[:, 1]
138
  y_pred_q90 = pred_t1[:, -2]
@@ -150,6 +156,10 @@ def _compute_test_metrics_from_quantiles(
150
  y_pred_q98=y_pred_q98[:n],
151
  y_pred_quantiles=pred_t1[:n],
152
  )
 
 
 
 
153
 
154
  n_path = min(len(y_actual_path), len(pred_np))
155
  weekly_metrics = compute_weekly_metrics(
@@ -167,6 +177,7 @@ def train_tft_model(
167
  cfg: Optional[TFTASROConfig] = None,
168
  use_asro: bool = True,
169
  upload_to_hub: bool = False,
 
170
  ) -> dict:
171
  """
172
  End-to-end TFT-ASRO training.
@@ -189,16 +200,23 @@ def train_tft_model(
189
  from deep_learning.data.feature_store import build_tft_dataframe
190
  from deep_learning.data.dataset import build_datasets, create_dataloaders
191
  from deep_learning.models.tft_copper import create_tft_model, get_variable_importance, format_prediction
192
- from deep_learning.training.callbacks import CurriculumLossScheduler, SWACallback
 
 
 
 
193
 
194
  if cfg is None:
195
  cfg = get_tft_config()
196
 
197
- # ---- 0a. Load Optuna best params if available ----
198
- # When the hyperopt step ran before this trainer, it writes best params to
199
- # optuna_results.json. We apply those params over the default config so that
200
- # the final training run actually benefits from the search.
201
- cfg = _apply_optuna_results(cfg)
 
 
 
202
 
203
  # ---- 0b. ASRO loss sanity check (runs before any training) ----
204
  try:
@@ -260,19 +278,11 @@ def train_tft_model(
260
  cfg.training.early_stopping_patience,
261
  )
262
  logger.info(
263
- "Weekly loss | weekly_q=%.2f t1_q=%.2f directional=%.2f magnitude=%.2f vol=%.2f",
264
  cfg.weekly_loss.lambda_weekly_quantile,
265
  cfg.weekly_loss.lambda_t1_quantile,
 
266
  cfg.weekly_loss.lambda_directional,
267
- cfg.weekly_loss.lambda_magnitude,
268
- cfg.weekly_loss.lambda_vol,
269
- )
270
- logger.info(
271
- "Weekly guards | width=%.2f tail_width=%.2f crossing=%.2f sanity=%.2f",
272
- cfg.weekly_loss.lambda_width,
273
- cfg.weekly_loss.lambda_tail_width,
274
- cfg.weekly_loss.lambda_crossing,
275
- cfg.weekly_loss.lambda_sanity,
276
  )
277
  else:
278
  logger.info(
@@ -308,6 +318,7 @@ def train_tft_model(
308
  save_top_k=3,
309
  save_last=True,
310
  ),
 
311
  ]
312
 
313
  if use_asro and cfg.forecast.primary_horizon_days != 5:
@@ -430,12 +441,8 @@ def train_tft_model(
430
  "lambda_weekly_quantile": cfg.weekly_loss.lambda_weekly_quantile,
431
  "lambda_t1_quantile": cfg.weekly_loss.lambda_t1_quantile,
432
  "lambda_directional": cfg.weekly_loss.lambda_directional,
433
- "lambda_magnitude": cfg.weekly_loss.lambda_magnitude,
434
- "weekly_lambda_vol": cfg.weekly_loss.lambda_vol,
435
- "weekly_lambda_crossing": cfg.weekly_loss.lambda_crossing,
436
- "lambda_sanity": cfg.weekly_loss.lambda_sanity,
437
- "lambda_width": cfg.weekly_loss.lambda_width,
438
- "lambda_tail_width": cfg.weekly_loss.lambda_tail_width,
439
  "max_encoder_length": cfg.model.max_encoder_length,
440
  "max_prediction_length": cfg.model.max_prediction_length,
441
  "forecast_contract_version": FORECAST_CONTRACT_VERSION,
@@ -518,7 +525,11 @@ def _write_conformal_calibration_artifact(
518
  import torch
519
 
520
  from deep_learning.calibration.conformal import rolling_conformal_adjustment
521
- from deep_learning.training.metrics import cumulative_horizon, cumulative_quantiles
 
 
 
 
522
 
523
  y_parts = []
524
  for batch in val_dl:
@@ -534,9 +545,13 @@ def _write_conformal_calibration_artifact(
534
  return None
535
 
536
  weekly_actual = cumulative_horizon(y_actual_path[:n], horizon=cfg.forecast.primary_horizon_days)
537
- weekly_quantiles = np.sort(
538
- cumulative_quantiles(pred_np[:n], horizon=cfg.forecast.primary_horizon_days),
539
- axis=-1,
 
 
 
 
540
  )
541
  q = tuple(cfg.model.quantiles)
542
  q10_idx = q.index(0.10)
@@ -638,18 +653,10 @@ def _apply_optuna_results(cfg: TFTASROConfig) -> TFTASROConfig:
638
  params["learning_rate"] = min(float(params["learning_rate"]), 6e-4)
639
  if "weight_decay" in params:
640
  params["weight_decay"] = min(float(params["weight_decay"]), 5e-4)
641
- if "lambda_magnitude" in params:
642
- params["lambda_magnitude"] = max(float(params["lambda_magnitude"]), 0.50)
643
  if "lambda_directional" in params:
644
  params["lambda_directional"] = min(float(params["lambda_directional"]), 0.12)
645
- if "lambda_width" in params:
646
- params["lambda_width"] = max(float(params["lambda_width"]), 0.40)
647
- if "lambda_tail_width" in params:
648
- params["lambda_tail_width"] = max(float(params["lambda_tail_width"]), 0.25)
649
- if "lambda_sanity" in params:
650
- params["lambda_sanity"] = max(float(params["lambda_sanity"]), 0.10)
651
- if "lambda_crossing" in params:
652
- params["lambda_crossing"] = max(float(params["lambda_crossing"]), 5.0)
653
 
654
  logger.info(
655
  "Loaded Optuna best params (trial #%d, weekly_objective=%.4f): %s",
@@ -685,12 +692,9 @@ def _overlay_training_config(cfg: TFTASROConfig, params: dict) -> TFTASROConfig:
685
  weekly_loss_overrides = {
686
  k: params[k] for k in (
687
  "lambda_weekly_quantile", "lambda_t1_quantile", "lambda_directional",
688
- "lambda_magnitude", "lambda_crossing", "lambda_sanity",
689
- "lambda_width", "lambda_tail_width",
690
  ) if k in params
691
  }
692
- if "weekly_lambda_vol" in params:
693
- weekly_loss_overrides["lambda_vol"] = params["weekly_lambda_vol"]
694
 
695
  new_model = replace(cfg.model, **model_overrides) if model_overrides else cfg.model
696
  new_asro = replace(cfg.asro, **asro_overrides) if asro_overrides else cfg.asro
@@ -744,10 +748,20 @@ if __name__ == "__main__":
744
  parser.add_argument("--symbol", default="HG=F")
745
  parser.add_argument("--no-asro", action="store_true", help="Use standard QuantileLoss instead of ASRO")
746
  parser.add_argument("--upload-hub", action="store_true", help="Upload artifacts to HF Hub after training")
 
 
 
 
 
747
  args = parser.parse_args()
748
 
749
  cfg = get_tft_config()
750
- result = train_tft_model(cfg, use_asro=not args.no_asro, upload_to_hub=args.upload_hub)
 
 
 
 
 
751
 
752
  print("\n" + "=" * 60)
753
  print("TFT-ASRO TRAINING COMPLETE")
 
47
  logger = logging.getLogger(__name__)
48
 
49
  KNOWN_GOOD_CONFIG = {
50
+ "max_encoder_length": 50,
51
  "hidden_size": 48,
52
  "attention_head_size": 2,
53
  "dropout": 0.30,
 
57
  "lambda_vol": 0.30,
58
  "lambda_quantile": 0.25,
59
  "lambda_madl": 0.40,
60
+ "lambda_weekly_quantile": 0.55,
61
+ "lambda_t1_quantile": 0.15,
62
+ "lambda_dispersion": 0.20,
63
  "lambda_directional": 0.10,
 
 
 
 
 
 
64
  "batch_size": 32,
65
  }
66
 
67
+ DETERMINISTIC_WEEKLY_CONFIG = dict(KNOWN_GOOD_CONFIG)
68
+
69
  REQUIRED_PROMOTABLE_METRICS = (
70
  "weekly_directional_accuracy",
71
  "weekly_magnitude_ratio",
 
123
  pred_np: np.ndarray,
124
  cfg: TFTASROConfig,
125
  ) -> dict[str, float]:
126
+ from deep_learning.training.metrics import (
127
+ compute_all_metrics,
128
+ compute_weekly_metrics,
129
+ monotonic_quantiles_np,
130
+ quantile_crossing_rate,
131
+ quantile_median_sort_gap,
132
+ select_prediction_horizon,
133
+ )
134
 
135
  pred_np = np.asarray(pred_np)
136
  _validate_quantile_prediction_shape(pred_np, cfg)
137
 
138
  median_idx = len(cfg.model.quantiles) // 2
139
+ ordered_pred_np = monotonic_quantiles_np(pred_np, median_idx=median_idx)
140
+ raw_pred_t1 = pred_np[:, 0, :]
141
+ pred_t1 = ordered_pred_np[:, 0, :]
142
  y_pred_median = pred_t1[:, median_idx]
143
  y_pred_q10 = pred_t1[:, 1]
144
  y_pred_q90 = pred_t1[:, -2]
 
156
  y_pred_q98=y_pred_q98[:n],
157
  y_pred_quantiles=pred_t1[:n],
158
  )
159
+ raw_gap_mean, raw_gap_max = quantile_median_sort_gap(raw_pred_t1[:n], median_idx)
160
+ test_metrics["raw_quantile_crossing_rate"] = quantile_crossing_rate(raw_pred_t1[:n])
161
+ test_metrics["raw_median_sort_gap_mean"] = raw_gap_mean
162
+ test_metrics["raw_median_sort_gap_max"] = raw_gap_max
163
 
164
  n_path = min(len(y_actual_path), len(pred_np))
165
  weekly_metrics = compute_weekly_metrics(
 
177
  cfg: Optional[TFTASROConfig] = None,
178
  use_asro: bool = True,
179
  upload_to_hub: bool = False,
180
+ deterministic_weekly_validation: bool = False,
181
  ) -> dict:
182
  """
183
  End-to-end TFT-ASRO training.
 
200
  from deep_learning.data.feature_store import build_tft_dataframe
201
  from deep_learning.data.dataset import build_datasets, create_dataloaders
202
  from deep_learning.models.tft_copper import create_tft_model, get_variable_importance, format_prediction
203
+ from deep_learning.training.callbacks import (
204
+ CurriculumLossScheduler,
205
+ SWACallback,
206
+ WeeklyLossComponentLogger,
207
+ )
208
 
209
  if cfg is None:
210
  cfg = get_tft_config()
211
 
212
+ # ---- 0a. Load training params ----
213
+ # Deterministic validation bypasses Optuna so structural changes can be
214
+ # measured before investing in search.
215
+ if deterministic_weekly_validation:
216
+ cfg = _overlay_training_config(cfg, DETERMINISTIC_WEEKLY_CONFIG)
217
+ logger.info("Using deterministic weekly validation config: %s", DETERMINISTIC_WEEKLY_CONFIG)
218
+ else:
219
+ cfg = _apply_optuna_results(cfg)
220
 
221
  # ---- 0b. ASRO loss sanity check (runs before any training) ----
222
  try:
 
278
  cfg.training.early_stopping_patience,
279
  )
280
  logger.info(
281
+ "Weekly loss | weekly_q=%.2f t1_q=%.2f dispersion=%.2f directional=%.2f monotonic_transform=true",
282
  cfg.weekly_loss.lambda_weekly_quantile,
283
  cfg.weekly_loss.lambda_t1_quantile,
284
+ cfg.weekly_loss.lambda_dispersion,
285
  cfg.weekly_loss.lambda_directional,
 
 
 
 
 
 
 
 
 
286
  )
287
  else:
288
  logger.info(
 
318
  save_top_k=3,
319
  save_last=True,
320
  ),
321
+ WeeklyLossComponentLogger(),
322
  ]
323
 
324
  if use_asro and cfg.forecast.primary_horizon_days != 5:
 
441
  "lambda_weekly_quantile": cfg.weekly_loss.lambda_weekly_quantile,
442
  "lambda_t1_quantile": cfg.weekly_loss.lambda_t1_quantile,
443
  "lambda_directional": cfg.weekly_loss.lambda_directional,
444
+ "lambda_dispersion": cfg.weekly_loss.lambda_dispersion,
445
+ "monotonic_quantile_transform": True,
 
 
 
 
446
  "max_encoder_length": cfg.model.max_encoder_length,
447
  "max_prediction_length": cfg.model.max_prediction_length,
448
  "forecast_contract_version": FORECAST_CONTRACT_VERSION,
 
525
  import torch
526
 
527
  from deep_learning.calibration.conformal import rolling_conformal_adjustment
528
+ from deep_learning.training.metrics import (
529
+ cumulative_horizon,
530
+ cumulative_quantiles,
531
+ monotonic_quantiles_np,
532
+ )
533
 
534
  y_parts = []
535
  for batch in val_dl:
 
545
  return None
546
 
547
  weekly_actual = cumulative_horizon(y_actual_path[:n], horizon=cfg.forecast.primary_horizon_days)
548
+ ordered_pred_np = monotonic_quantiles_np(
549
+ pred_np[:n],
550
+ median_idx=len(cfg.model.quantiles) // 2,
551
+ )
552
+ weekly_quantiles = cumulative_quantiles(
553
+ ordered_pred_np,
554
+ horizon=cfg.forecast.primary_horizon_days,
555
  )
556
  q = tuple(cfg.model.quantiles)
557
  q10_idx = q.index(0.10)
 
653
  params["learning_rate"] = min(float(params["learning_rate"]), 6e-4)
654
  if "weight_decay" in params:
655
  params["weight_decay"] = min(float(params["weight_decay"]), 5e-4)
 
 
656
  if "lambda_directional" in params:
657
  params["lambda_directional"] = min(float(params["lambda_directional"]), 0.12)
658
+ if "lambda_dispersion" in params:
659
+ params["lambda_dispersion"] = max(float(params["lambda_dispersion"]), 0.20)
 
 
 
 
 
 
660
 
661
  logger.info(
662
  "Loaded Optuna best params (trial #%d, weekly_objective=%.4f): %s",
 
692
  weekly_loss_overrides = {
693
  k: params[k] for k in (
694
  "lambda_weekly_quantile", "lambda_t1_quantile", "lambda_directional",
695
+ "lambda_dispersion",
 
696
  ) if k in params
697
  }
 
 
698
 
699
  new_model = replace(cfg.model, **model_overrides) if model_overrides else cfg.model
700
  new_asro = replace(cfg.asro, **asro_overrides) if asro_overrides else cfg.asro
 
748
  parser.add_argument("--symbol", default="HG=F")
749
  parser.add_argument("--no-asro", action="store_true", help="Use standard QuantileLoss instead of ASRO")
750
  parser.add_argument("--upload-hub", action="store_true", help="Upload artifacts to HF Hub after training")
751
+ parser.add_argument(
752
+ "--deterministic-weekly-validation",
753
+ action="store_true",
754
+ help="Bypass Optuna overlays and run the fixed monotonic weekly validation config",
755
+ )
756
  args = parser.parse_args()
757
 
758
  cfg = get_tft_config()
759
+ result = train_tft_model(
760
+ cfg,
761
+ use_asro=not args.no_asro,
762
+ upload_to_hub=args.upload_hub,
763
+ deterministic_weekly_validation=args.deterministic_weekly_validation,
764
+ )
765
 
766
  print("\n" + "=" * 60)
767
  print("TFT-ASRO TRAINING COMPLETE")
scripts/tft_quality_gate.py CHANGED
@@ -19,7 +19,7 @@ BACKEND_ROOT = pathlib.Path(__file__).resolve().parents[1]
19
  if str(BACKEND_ROOT) not in sys.path:
20
  sys.path.insert(0, str(BACKEND_ROOT))
21
 
22
- from app.quality_gate import evaluate_quality_gate
23
 
24
  META_PATH = pathlib.Path(os.environ.get("TFT_METADATA_PATH", "/tmp/models/tft/tft_metadata.json"))
25
 
@@ -37,17 +37,23 @@ def main() -> int:
37
  tail_capture = metrics.get("tail_capture_rate")
38
  quantile_crossing = metrics.get("quantile_crossing_rate")
39
  median_gap_max = metrics.get("median_sort_gap_max")
 
 
 
40
  weekly_da = metrics.get("weekly_directional_accuracy")
41
  weekly_mr = metrics.get("weekly_magnitude_ratio")
42
  weekly_tail = metrics.get("weekly_tail_capture_rate")
43
  weekly_pi80 = metrics.get("weekly_pi80_coverage")
 
44
  weekly_pi80_width_ratio = metrics.get("weekly_pi80_width_ratio")
45
  weekly_pi96 = metrics.get("weekly_pi96_coverage")
 
46
  weekly_pi96_width_ratio = metrics.get("weekly_pi96_width_ratio")
47
  weekly_qcross = metrics.get("weekly_quantile_crossing_rate")
48
  weekly_sorted_qcross = metrics.get("weekly_sorted_quantile_crossing_rate")
49
  weekly_gap = metrics.get("weekly_median_sort_gap_max")
50
  weekly_samples = metrics.get("weekly_sample_count")
 
51
 
52
  print(
53
  "Quality gate metrics: "
@@ -71,18 +77,29 @@ def main() -> int:
71
  tail_capture=tail_capture,
72
  quantile_crossing_rate=quantile_crossing,
73
  median_sort_gap_max=median_gap_max,
 
 
74
  weekly_directional_accuracy=weekly_da,
75
  weekly_magnitude_ratio=weekly_mr,
76
  weekly_tail_capture_rate=weekly_tail,
77
  weekly_pi80_coverage=weekly_pi80,
 
78
  weekly_pi80_width_ratio=weekly_pi80_width_ratio,
79
  weekly_pi96_coverage=weekly_pi96,
 
80
  weekly_pi96_width_ratio=weekly_pi96_width_ratio,
81
  weekly_quantile_crossing_rate=weekly_qcross,
82
  weekly_sorted_quantile_crossing_rate=weekly_sorted_qcross,
83
  weekly_median_sort_gap_max=weekly_gap,
84
  weekly_sample_count=weekly_samples,
85
  )
 
 
 
 
 
 
 
86
 
87
  if passed:
88
  print("QUALITY GATE: PASSED")
 
19
  if str(BACKEND_ROOT) not in sys.path:
20
  sys.path.insert(0, str(BACKEND_ROOT))
21
 
22
+ from app.quality_gate import evaluate_quality_gate, evaluate_quality_gate_warnings
23
 
24
  META_PATH = pathlib.Path(os.environ.get("TFT_METADATA_PATH", "/tmp/models/tft/tft_metadata.json"))
25
 
 
37
  tail_capture = metrics.get("tail_capture_rate")
38
  quantile_crossing = metrics.get("quantile_crossing_rate")
39
  median_gap_max = metrics.get("median_sort_gap_max")
40
+ pi80_width = metrics.get("pi80_width")
41
+ pi96_width = metrics.get("pi96_width")
42
+ mae_vs_naive_zero = metrics.get("mae_vs_naive_zero")
43
  weekly_da = metrics.get("weekly_directional_accuracy")
44
  weekly_mr = metrics.get("weekly_magnitude_ratio")
45
  weekly_tail = metrics.get("weekly_tail_capture_rate")
46
  weekly_pi80 = metrics.get("weekly_pi80_coverage")
47
+ weekly_pi80_width = metrics.get("weekly_pi80_width")
48
  weekly_pi80_width_ratio = metrics.get("weekly_pi80_width_ratio")
49
  weekly_pi96 = metrics.get("weekly_pi96_coverage")
50
+ weekly_pi96_width = metrics.get("weekly_pi96_width")
51
  weekly_pi96_width_ratio = metrics.get("weekly_pi96_width_ratio")
52
  weekly_qcross = metrics.get("weekly_quantile_crossing_rate")
53
  weekly_sorted_qcross = metrics.get("weekly_sorted_quantile_crossing_rate")
54
  weekly_gap = metrics.get("weekly_median_sort_gap_max")
55
  weekly_samples = metrics.get("weekly_sample_count")
56
+ weekly_mae_vs_naive_zero = metrics.get("weekly_mae_vs_naive_zero")
57
 
58
  print(
59
  "Quality gate metrics: "
 
77
  tail_capture=tail_capture,
78
  quantile_crossing_rate=quantile_crossing,
79
  median_sort_gap_max=median_gap_max,
80
+ pi80_width=pi80_width,
81
+ pi96_width=pi96_width,
82
  weekly_directional_accuracy=weekly_da,
83
  weekly_magnitude_ratio=weekly_mr,
84
  weekly_tail_capture_rate=weekly_tail,
85
  weekly_pi80_coverage=weekly_pi80,
86
+ weekly_pi80_width=weekly_pi80_width,
87
  weekly_pi80_width_ratio=weekly_pi80_width_ratio,
88
  weekly_pi96_coverage=weekly_pi96,
89
+ weekly_pi96_width=weekly_pi96_width,
90
  weekly_pi96_width_ratio=weekly_pi96_width_ratio,
91
  weekly_quantile_crossing_rate=weekly_qcross,
92
  weekly_sorted_quantile_crossing_rate=weekly_sorted_qcross,
93
  weekly_median_sort_gap_max=weekly_gap,
94
  weekly_sample_count=weekly_samples,
95
  )
96
+ warnings = evaluate_quality_gate_warnings(
97
+ vr=vr,
98
+ mae_vs_naive_zero=mae_vs_naive_zero,
99
+ weekly_mae_vs_naive_zero=weekly_mae_vs_naive_zero,
100
+ )
101
+ for warning in warnings:
102
+ print(f"QUALITY GATE WARNING: {warning}")
103
 
104
  if passed:
105
  print("QUALITY GATE: PASSED")