ifieryarrows committed on
Commit
4d13aee
·
verified ·
1 Parent(s): 2e5fce9

Sync from GitHub (tests passed)

Browse files
app/quality_gate.py CHANGED
@@ -62,7 +62,9 @@ def evaluate_quality_gate(
62
  elif weekly_pi80_coverage < 0.74 or weekly_pi80_coverage > 0.86:
63
  reasons.append(f"WeeklyPI80={weekly_pi80_coverage:.4f} outside [0.74, 0.86]")
64
 
65
- if weekly_quantile_crossing_rate is not None and weekly_quantile_crossing_rate > 0.10:
 
 
66
  reasons.append(f"WeeklyQuantileCrossing={weekly_quantile_crossing_rate:.4f} > 0.10")
67
 
68
  if weekly_median_sort_gap_max is not None and weekly_median_sort_gap_max > 0.005:
@@ -76,7 +78,9 @@ def evaluate_quality_gate(
76
  reasons.append(f"VR={vr:.4f} outside [0.2, 2.5]")
77
  if tail_capture is not None and tail_capture < 0.35:
78
  reasons.append(f"TailCapture={tail_capture:.4f} < 0.35")
79
- if quantile_crossing_rate is not None and quantile_crossing_rate > 0.20:
 
 
80
  reasons.append(f"QuantileCrossing={quantile_crossing_rate:.4f} > 0.20")
81
  if median_sort_gap_max is not None and median_sort_gap_max > 0.01:
82
  reasons.append(f"MedianSortGapMax={median_sort_gap_max:.4f} > 0.01")
 
62
  elif weekly_pi80_coverage < 0.74 or weekly_pi80_coverage > 0.86:
63
  reasons.append(f"WeeklyPI80={weekly_pi80_coverage:.4f} outside [0.74, 0.86]")
64
 
65
+ if weekly_quantile_crossing_rate is None:
66
+ reasons.append("Missing weekly_quantile_crossing_rate")
67
+ elif weekly_quantile_crossing_rate > 0.10:
68
  reasons.append(f"WeeklyQuantileCrossing={weekly_quantile_crossing_rate:.4f} > 0.10")
69
 
70
  if weekly_median_sort_gap_max is not None and weekly_median_sort_gap_max > 0.005:
 
78
  reasons.append(f"VR={vr:.4f} outside [0.2, 2.5]")
79
  if tail_capture is not None and tail_capture < 0.35:
80
  reasons.append(f"TailCapture={tail_capture:.4f} < 0.35")
81
+ if quantile_crossing_rate is None:
82
+ reasons.append("Missing quantile_crossing_rate")
83
+ elif quantile_crossing_rate > 0.20:
84
  reasons.append(f"QuantileCrossing={quantile_crossing_rate:.4f} > 0.20")
85
  if median_sort_gap_max is not None and median_sort_gap_max > 0.01:
86
  reasons.append(f"MedianSortGapMax={median_sort_gap_max:.4f} > 0.01")
deep_learning/training/hyperopt.py CHANGED
@@ -93,9 +93,14 @@ def _finite_completed_trial_count(study) -> int:
93
  )
94
 
95
 
96
- def _weekly_pinball_loss(actual_path: np.ndarray, pred_path: np.ndarray, quantiles: tuple[float, ...]) -> float:
97
- actual = np.asarray(actual_path, dtype=np.float64)[:, :5].sum(axis=1)
98
- pred = np.asarray(pred_path, dtype=np.float64)[:, :5, :].sum(axis=1)
 
 
 
 
 
99
  q = np.asarray(quantiles, dtype=np.float64).reshape(1, -1)
100
  err = actual.reshape(-1, 1) - pred
101
  return float(np.maximum(q * err, (q - 1.0) * err).mean())
@@ -447,11 +452,13 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
447
  y_actual_path[:n_path],
448
  pred_np[:n_path],
449
  quantiles=trial_cfg.model.quantiles,
 
450
  )
451
  weekly_pinball = _weekly_pinball_loss(
452
  y_actual_path[:n_path],
453
  pred_np[:n_path],
454
  tuple(trial_cfg.model.quantiles),
 
455
  )
456
  fold_weekly_mr = float(weekly.get("weekly_magnitude_ratio", 1.0))
457
  fold_weekly_objective = (
@@ -648,7 +655,7 @@ def run_hyperopt(
648
  )
649
  else:
650
  logger.info(
651
- "Optuna best trial #%d: val_loss=%.6f",
652
  result["best_trial"],
653
  result["best_value"],
654
  )
@@ -685,6 +692,6 @@ if __name__ == "__main__":
685
  print(f"Trial states: {counts}")
686
  else:
687
  print(f"Best trial: #{result['best_trial']}")
688
- print(f"Best val_loss: {result['best_value']:.6f}")
689
  for k, v in result["best_params"].items():
690
  print(f" {k}: {v}")
 
93
  )
94
 
95
 
96
def _weekly_pinball_loss(
    actual_path: np.ndarray,
    pred_path: np.ndarray,
    quantiles: tuple[float, ...],
    horizon: int = 5,
) -> float:
    """Return the mean pinball loss of the summed first ``horizon`` steps.

    Both inputs are cast to float64, truncated to the first ``horizon``
    entries along axis 1 and summed over that axis. The loss
    ``max(q * err, (q - 1) * err)`` with ``err = actual - predicted`` is
    then averaged over every (sample, quantile) pair.

    Args:
        actual_path: Realized values; indexed as ``[:, :horizon]``.
        pred_path: Quantile predictions; indexed as ``[:, :horizon, :]``,
            with the last axis lining up with ``quantiles``.
        quantiles: Quantile levels, one per entry of the last prediction axis.
        horizon: Number of leading steps to aggregate (default 5).

    Returns:
        The average pinball loss as a Python float.
    """
    # A path shorter than ``horizon`` is not rejected: slicing simply keeps
    # whatever steps are available.
    window = slice(None, horizon)
    realized_sum = np.asarray(actual_path, dtype=np.float64)[:, window].sum(axis=1)
    predicted_sum = np.asarray(pred_path, dtype=np.float64)[:, window, :].sum(axis=1)

    levels = np.asarray(quantiles, dtype=np.float64).reshape(1, -1)
    residual = realized_sum.reshape(-1, 1) - predicted_sum

    under_penalty = levels * residual
    over_penalty = (levels - 1.0) * residual
    return float(np.maximum(under_penalty, over_penalty).mean())
 
452
  y_actual_path[:n_path],
453
  pred_np[:n_path],
454
  quantiles=trial_cfg.model.quantiles,
455
+ horizon=trial_cfg.forecast.primary_horizon_days,
456
  )
457
  weekly_pinball = _weekly_pinball_loss(
458
  y_actual_path[:n_path],
459
  pred_np[:n_path],
460
  tuple(trial_cfg.model.quantiles),
461
+ horizon=trial_cfg.forecast.primary_horizon_days,
462
  )
463
  fold_weekly_mr = float(weekly.get("weekly_magnitude_ratio", 1.0))
464
  fold_weekly_objective = (
 
655
  )
656
  else:
657
  logger.info(
658
+ "Optuna best trial #%d: weekly_objective=%.6f",
659
  result["best_trial"],
660
  result["best_value"],
661
  )
 
692
  print(f"Trial states: {counts}")
693
  else:
694
  print(f"Best trial: #{result['best_trial']}")
695
+ print(f"Best weekly objective: {result['best_value']:.6f}")
696
  for k, v in result["best_params"].items():
697
  print(f" {k}: {v}")
deep_learning/training/metrics.py CHANGED
@@ -238,15 +238,16 @@ def compute_weekly_metrics(
238
  y_actual_path: np.ndarray,
239
  y_pred_quantiles_path: np.ndarray,
240
  quantiles: tuple[float, ...] = (0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98),
 
241
  ) -> dict[str, float]:
242
  """
243
- Compute weekly-first metrics from 5-step daily log-return paths.
244
 
245
  Internal evaluation remains in log-return space. Public API/UI conversion
246
  to simple returns happens only during inference formatting.
247
  """
248
- weekly_actual = cumulative_horizon(y_actual_path, horizon=5)
249
- weekly_quantiles = cumulative_quantiles(y_pred_quantiles_path, horizon=5)
250
 
251
  median_idx = len(quantiles) // 2
252
  q10_idx = quantiles.index(0.10)
 
238
  y_actual_path: np.ndarray,
239
  y_pred_quantiles_path: np.ndarray,
240
  quantiles: tuple[float, ...] = (0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98),
241
+ horizon: int = 5,
242
  ) -> dict[str, float]:
243
  """
244
+ Compute weekly-first metrics from a daily log-return path.
245
 
246
  Internal evaluation remains in log-return space. Public API/UI conversion
247
  to simple returns happens only during inference formatting.
248
  """
249
+ weekly_actual = cumulative_horizon(y_actual_path, horizon=horizon)
250
+ weekly_quantiles = cumulative_quantiles(y_pred_quantiles_path, horizon=horizon)
251
 
252
  median_idx = len(quantiles) // 2
253
  q10_idx = quantiles.index(0.10)
deep_learning/training/trainer.py CHANGED
@@ -58,6 +58,94 @@ KNOWN_GOOD_CONFIG = {
58
  "batch_size": 32,
59
  }
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  def train_tft_model(
63
  cfg: Optional[TFTASROConfig] = None,
@@ -84,7 +172,6 @@ def train_tft_model(
84
  from deep_learning.data.feature_store import build_tft_dataframe
85
  from deep_learning.data.dataset import build_datasets, create_dataloaders
86
  from deep_learning.models.tft_copper import create_tft_model, get_variable_importance, format_prediction
87
- from deep_learning.training.metrics import compute_all_metrics, compute_weekly_metrics, select_prediction_horizon
88
  from deep_learning.training.callbacks import CurriculumLossScheduler, SWACallback
89
 
90
  if cfg is None:
@@ -149,14 +236,29 @@ def train_tft_model(
149
  cfg.model.dropout, cfg.model.attention_head_size,
150
  cfg.model.learning_rate, cfg.model.gradient_clip_val,
151
  )
152
- logger.info(
153
- "Training data | samples=%d batch_size=%d batches/epoch=%d "
154
- "patience=%d w_quantile=%.2f w_sharpe=%.2f lambda_vol=%.2f",
155
- len(training_ds), cfg.training.batch_size, n_batches,
156
- cfg.training.early_stopping_patience,
157
- cfg.asro.lambda_quantile, 1.0 - cfg.asro.lambda_quantile,
158
- cfg.asro.lambda_vol,
159
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  logger.info(
161
  "Model params | total=%s trainable=%s",
162
  f"{total_params:,}", f"{trainable_params:,}",
@@ -238,8 +340,6 @@ def train_tft_model(
238
  batch[1][0] if isinstance(batch[1], (list, tuple)) else batch[1]
239
  )
240
  y_actual_path = torch.cat(y_actual_parts).cpu().numpy()
241
- y_actual = select_prediction_horizon(y_actual_path, horizon_idx=0)
242
-
243
  # Gather top-k checkpoint paths
244
  best_k = getattr(trainer.checkpoint_callback, "best_k_models", {})
245
  ckpt_paths = sorted(best_k.keys(), key=lambda p: best_k[p]) if best_k else []
@@ -247,13 +347,8 @@ def train_tft_model(
247
  # Always include the just-trained model as a baseline
248
  all_pred_arrays = []
249
 
250
- def _predict_to_np(mdl):
251
- pred = mdl.predict(test_dl, return_x=True)
252
- pt = pred.output if hasattr(pred, "output") else pred
253
- return pt.cpu().numpy() if hasattr(pt, "cpu") else np.array(pt)
254
-
255
  # Predictions from the best model (already in memory)
256
- all_pred_arrays.append(_predict_to_np(model))
257
 
258
  # Load additional checkpoints for ensemble
259
  for cp in ckpt_paths:
@@ -261,10 +356,10 @@ def train_tft_model(
261
  continue # already have this one
262
  try:
263
  ckpt_model = load_tft_model(str(cp))
264
- all_pred_arrays.append(_predict_to_np(ckpt_model))
265
  del ckpt_model
266
  except Exception as exc:
267
- logger.debug("Skipping ensemble checkpoint %s: %s", cp, exc)
268
 
269
  ensemble_size = len(all_pred_arrays)
270
  logger.info(
@@ -277,40 +372,12 @@ def train_tft_model(
277
  else:
278
  pred_np = all_pred_arrays[0]
279
 
280
- median_idx = len(cfg.model.quantiles) // 2
281
- if pred_np.ndim == 3:
282
- pred_t1 = pred_np[:, 0, :]
283
- y_pred_median = pred_t1[:, median_idx]
284
- y_pred_q10 = pred_t1[:, 1] if pred_t1.shape[1] > 2 else None
285
- y_pred_q90 = pred_t1[:, -2] if pred_t1.shape[1] > 2 else None
286
- y_pred_q02 = pred_t1[:, 0] if pred_t1.shape[1] > 2 else None
287
- y_pred_q98 = pred_t1[:, -1] if pred_t1.shape[1] > 2 else None
288
- else:
289
- y_pred_median = pred_np.flatten()
290
- pred_t1 = None
291
- y_pred_q10 = y_pred_q90 = y_pred_q02 = y_pred_q98 = None
292
-
293
- n = min(len(y_actual), len(y_pred_median))
294
- test_metrics = compute_all_metrics(
295
- y_actual[:n],
296
- y_pred_median[:n],
297
- y_pred_q10=y_pred_q10[:n] if y_pred_q10 is not None else None,
298
- y_pred_q90=y_pred_q90[:n] if y_pred_q90 is not None else None,
299
- y_pred_q02=y_pred_q02[:n] if y_pred_q02 is not None else None,
300
- y_pred_q98=y_pred_q98[:n] if y_pred_q98 is not None else None,
301
- y_pred_quantiles=pred_t1[:n] if pred_t1 is not None else None,
302
- )
303
- if pred_np.ndim == 3:
304
- n_path = min(len(y_actual_path), len(pred_np))
305
- weekly_metrics = compute_weekly_metrics(
306
- y_actual_path[:n_path],
307
- pred_np[:n_path],
308
- quantiles=cfg.model.quantiles,
309
- )
310
- test_metrics.update(weekly_metrics)
311
  test_metrics["ensemble_size"] = ensemble_size
312
  logger.info("Test metrics: %s", {k: f"{v:.4f}" for k, v in test_metrics.items()})
313
 
 
 
314
  calibration_artifact = _write_conformal_calibration_artifact(
315
  cfg=cfg,
316
  model=model,
@@ -506,7 +573,7 @@ def _apply_optuna_results(cfg: TFTASROConfig) -> TFTASROConfig:
506
  params["lambda_madl"] = max(float(params["lambda_madl"]), 0.30)
507
 
508
  logger.info(
509
- "Loaded Optuna best params (trial #%d, val_loss=%.4f): %s",
510
  data.get("best_trial", -1),
511
  data.get("best_value", float("nan")),
512
  params,
 
58
  "batch_size": 32,
59
  }
60
 
61
# Metric keys that _require_promotable_metrics() insists on finding, with a
# non-None value, in the evaluation results.
REQUIRED_PROMOTABLE_METRICS = (
    "weekly_directional_accuracy",
    "weekly_magnitude_ratio",
    "weekly_tail_capture_rate",
    "weekly_pi80_coverage",
    "weekly_sample_count",
    "weekly_quantile_crossing_rate",
    "quantile_crossing_rate",
)
70
+
71
+
72
def _validate_quantile_prediction_shape(pred_np: np.ndarray, cfg: TFTASROConfig) -> None:
    """Raise unless ``pred_np`` is a ``[n, horizon, q]`` quantile tensor matching ``cfg``.

    Checks, in order: the array is 3-D; axis 1 holds at least
    ``cfg.forecast.primary_horizon_days`` steps; axis 2 has exactly
    ``len(cfg.model.quantiles)`` entries.

    Raises:
        RuntimeError: On the first check that fails.
    """
    observed_shape = pred_np.shape
    if pred_np.ndim != 3:
        raise RuntimeError(
            f"Expected quantile prediction tensor [n, horizon, q], got shape={observed_shape}. "
            "Weekly quality gate cannot run without full multi-horizon quantile predictions."
        )

    # The config is only consulted once the rank is known to be valid.
    _, n_steps, n_levels = observed_shape

    needed_steps = cfg.forecast.primary_horizon_days
    if n_steps < needed_steps:
        raise RuntimeError(
            f"Prediction horizon too short: got {n_steps}, "
            f"need {needed_steps}"
        )

    expected_levels = len(cfg.model.quantiles)
    if n_levels != expected_levels:
        raise RuntimeError(
            f"Quantile dim mismatch: {n_levels} != {expected_levels}"
        )
87
+
88
+
89
def _predict_quantiles_to_np(mdl, dataloader, cfg: TFTASROConfig) -> np.ndarray:
    """Run ``mdl.predict`` in quantile mode and return a shape-checked NumPy array.

    Args:
        mdl: Object exposing ``predict(dataloader, mode="quantiles")``.
        dataloader: Passed to ``mdl.predict`` unchanged.
        cfg: Config used by ``_validate_quantile_prediction_shape``.

    Returns:
        The predictions as a NumPy array.

    Raises:
        RuntimeError: If the prediction shape fails validation.
    """
    raw_output = mdl.predict(dataloader, mode="quantiles")

    # Anything exposing ``.cpu()`` is moved to host memory before conversion
    # (presumably a torch tensor -- confirm against the model's predict API).
    if hasattr(raw_output, "cpu"):
        quantile_array = raw_output.cpu().numpy()
    else:
        quantile_array = np.asarray(raw_output)

    _validate_quantile_prediction_shape(quantile_array, cfg)
    return quantile_array
94
+
95
+
96
def _require_promotable_metrics(metrics: dict) -> None:
    """Raise if any key in ``REQUIRED_PROMOTABLE_METRICS`` is absent or ``None``.

    Args:
        metrics: Mapping of metric name to value.

    Raises:
        RuntimeError: Listing every required key that is missing or ``None``.
    """
    # ``dict.get`` returns None for an absent key, so one test covers both
    # "not present" and "present but None".
    # NOTE(review): a NaN value counts as present here -- confirm whether
    # non-finite metrics should also block promotion.
    absent = []
    for name in REQUIRED_PROMOTABLE_METRICS:
        if metrics.get(name) is None:
            absent.append(name)

    if not absent:
        return

    raise RuntimeError(
        f"Required TFT promotion metrics missing after evaluation: {absent}. "
        "Refusing to write promotable TFT metadata."
    )
106
+
107
+
108
def _compute_test_metrics_from_quantiles(
    y_actual_path: np.ndarray,
    pred_np: np.ndarray,
    cfg: TFTASROConfig,
) -> dict[str, float]:
    """Compute first-step and horizon-level test metrics from quantile predictions.

    The first forecast step (index 0 on axis 1) feeds ``compute_all_metrics``;
    the full path feeds ``compute_weekly_metrics`` with
    ``cfg.forecast.primary_horizon_days`` as the horizon. Both results are
    merged into one dict, which must contain every required promotion metric.

    Args:
        y_actual_path: Realized path; row count is aligned to ``pred_np``
            by truncating both to the shorter length.
        pred_np: Quantile predictions, validated as ``[n, horizon, q]``.
        cfg: Config supplying ``model.quantiles`` and
            ``forecast.primary_horizon_days``.

    Returns:
        Merged metric dict (weekly metrics override same-named first-step keys).

    Raises:
        RuntimeError: If the prediction shape is invalid or a required
            promotion metric is missing after evaluation.
    """
    from deep_learning.training.metrics import (
        compute_all_metrics,
        compute_weekly_metrics,
        select_prediction_horizon,
    )

    pred_np = np.asarray(pred_np)
    _validate_quantile_prediction_shape(pred_np, cfg)

    quantile_levels = tuple(cfg.model.quantiles)

    def _column_for(level: float, positional_fallback: int) -> int:
        # Locate the quantile by value so a non-default quantile tuple does not
        # silently hand the wrong column to the interval metrics. When the
        # level is not configured, keep the historical positional choice.
        if level in quantile_levels:
            return quantile_levels.index(level)
        return positional_fallback

    median_idx = len(quantile_levels) // 2
    pred_t1 = pred_np[:, 0, :]
    y_pred_median = pred_t1[:, median_idx]
    y_pred_q10 = pred_t1[:, _column_for(0.10, 1)]
    y_pred_q90 = pred_t1[:, _column_for(0.90, -2)]
    y_pred_q02 = pred_t1[:, _column_for(0.02, 0)]
    y_pred_q98 = pred_t1[:, _column_for(0.98, -1)]

    y_actual = select_prediction_horizon(y_actual_path, horizon_idx=0)
    # Truncate to the shorter of actuals/predictions so rows stay aligned.
    n = min(len(y_actual), len(y_pred_median))
    test_metrics = compute_all_metrics(
        y_actual[:n],
        y_pred_median[:n],
        y_pred_q10=y_pred_q10[:n],
        y_pred_q90=y_pred_q90[:n],
        y_pred_q02=y_pred_q02[:n],
        y_pred_q98=y_pred_q98[:n],
        y_pred_quantiles=pred_t1[:n],
    )

    n_path = min(len(y_actual_path), len(pred_np))
    weekly_metrics = compute_weekly_metrics(
        y_actual_path[:n_path],
        pred_np[:n_path],
        quantiles=cfg.model.quantiles,
        horizon=cfg.forecast.primary_horizon_days,
    )
    test_metrics.update(weekly_metrics)
    _require_promotable_metrics(test_metrics)
    return test_metrics
148
+
149
 
150
  def train_tft_model(
151
  cfg: Optional[TFTASROConfig] = None,
 
172
  from deep_learning.data.feature_store import build_tft_dataframe
173
  from deep_learning.data.dataset import build_datasets, create_dataloaders
174
  from deep_learning.models.tft_copper import create_tft_model, get_variable_importance, format_prediction
 
175
  from deep_learning.training.callbacks import CurriculumLossScheduler, SWACallback
176
 
177
  if cfg is None:
 
236
  cfg.model.dropout, cfg.model.attention_head_size,
237
  cfg.model.learning_rate, cfg.model.gradient_clip_val,
238
  )
239
+ if cfg.forecast.primary_horizon_days == 5:
240
+ logger.info(
241
+ "Training data | samples=%d batch_size=%d batches/epoch=%d patience=%d",
242
+ len(training_ds), cfg.training.batch_size, n_batches,
243
+ cfg.training.early_stopping_patience,
244
+ )
245
+ logger.info(
246
+ "Weekly loss | weekly_q=%.2f t1_q=%.2f directional=%.2f magnitude=%.2f vol=%.2f",
247
+ cfg.weekly_loss.lambda_weekly_quantile,
248
+ cfg.weekly_loss.lambda_t1_quantile,
249
+ cfg.weekly_loss.lambda_directional,
250
+ cfg.weekly_loss.lambda_magnitude,
251
+ cfg.weekly_loss.lambda_vol,
252
+ )
253
+ else:
254
+ logger.info(
255
+ "Training data | samples=%d batch_size=%d batches/epoch=%d "
256
+ "patience=%d w_quantile=%.2f w_sharpe=%.2f lambda_vol=%.2f",
257
+ len(training_ds), cfg.training.batch_size, n_batches,
258
+ cfg.training.early_stopping_patience,
259
+ cfg.asro.lambda_quantile, 1.0 - cfg.asro.lambda_quantile,
260
+ cfg.asro.lambda_vol,
261
+ )
262
  logger.info(
263
  "Model params | total=%s trainable=%s",
264
  f"{total_params:,}", f"{trainable_params:,}",
 
340
  batch[1][0] if isinstance(batch[1], (list, tuple)) else batch[1]
341
  )
342
  y_actual_path = torch.cat(y_actual_parts).cpu().numpy()
 
 
343
  # Gather top-k checkpoint paths
344
  best_k = getattr(trainer.checkpoint_callback, "best_k_models", {})
345
  ckpt_paths = sorted(best_k.keys(), key=lambda p: best_k[p]) if best_k else []
 
347
  # Always include the just-trained model as a baseline
348
  all_pred_arrays = []
349
 
 
 
 
 
 
350
  # Predictions from the best model (already in memory)
351
+ all_pred_arrays.append(_predict_quantiles_to_np(model, test_dl, cfg))
352
 
353
  # Load additional checkpoints for ensemble
354
  for cp in ckpt_paths:
 
356
  continue # already have this one
357
  try:
358
  ckpt_model = load_tft_model(str(cp))
359
+ all_pred_arrays.append(_predict_quantiles_to_np(ckpt_model, test_dl, cfg))
360
  del ckpt_model
361
  except Exception as exc:
362
+ logger.warning("Skipping incompatible ensemble checkpoint %s: %s", cp, exc)
363
 
364
  ensemble_size = len(all_pred_arrays)
365
  logger.info(
 
372
  else:
373
  pred_np = all_pred_arrays[0]
374
 
375
+ test_metrics = _compute_test_metrics_from_quantiles(y_actual_path, pred_np, cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  test_metrics["ensemble_size"] = ensemble_size
377
  logger.info("Test metrics: %s", {k: f"{v:.4f}" for k, v in test_metrics.items()})
378
 
379
+ _require_promotable_metrics(test_metrics)
380
+
381
  calibration_artifact = _write_conformal_calibration_artifact(
382
  cfg=cfg,
383
  model=model,
 
573
  params["lambda_madl"] = max(float(params["lambda_madl"]), 0.30)
574
 
575
  logger.info(
576
+ "Loaded Optuna best params (trial #%d, weekly_objective=%.4f): %s",
577
  data.get("best_trial", -1),
578
  data.get("best_value", float("nan")),
579
  params,