ifieryarrows committed on
Commit
916d006
·
verified ·
1 Parent(s): 3fd2ce8

Sync from GitHub (tests passed)

Browse files
Files changed (1) hide show
  1. deep_learning/training/hyperopt.py +56 -10
deep_learning/training/hyperopt.py CHANGED
@@ -50,16 +50,17 @@ def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
50
  dropout=trial.suggest_float("dropout", 0.1, 0.5, step=0.05),
51
  hidden_continuous_size=trial.suggest_int("hidden_continuous_size", 8, 32, step=8),
52
  quantiles=base_cfg.model.quantiles,
53
- # Cap at 1e-3: two consecutive Optuna runs both selected ~3-4e-3 which
54
- # caused the model to converge in 1 epoch then diverge. 1e-3 is the
55
- # practical upper bound for stable TFT training on ~300 samples.
56
- learning_rate=trial.suggest_float("learning_rate", 5e-5, 1e-3, log=True),
57
  reduce_on_plateau_patience=4,
58
  gradient_clip_val=trial.suggest_float("gradient_clip_val", 0.5, 2.0, step=0.5),
59
  )
60
 
61
  asro_cfg = ASROConfig(
62
- lambda_vol=trial.suggest_float("lambda_vol", 0.1, 0.4, step=0.05),
 
 
63
  # lambda_quantile is the explicit w_quantile weight (w_sharpe = 1 - w_q)
64
  lambda_quantile=trial.suggest_float("lambda_quantile", 0.2, 0.6, step=0.05),
65
  risk_free_rate=0.0,
@@ -94,10 +95,15 @@ def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
94
 
95
  def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
96
  """
97
- Single Optuna trial: train a TFT variant and return the validation metric.
98
 
99
- Optimises for a composite score:
100
- score = -val_loss + 0.5 * directional_accuracy
 
 
 
 
 
101
  """
102
  try:
103
  import lightning.pytorch as pl
@@ -110,9 +116,10 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
110
  except ImportError:
111
  from optuna.integration import PyTorchLightningPruningCallback # type: ignore[no-redef]
112
 
 
 
113
  from deep_learning.data.dataset import build_datasets, create_dataloaders
114
  from deep_learning.models.tft_copper import create_tft_model
115
- from deep_learning.training.metrics import compute_all_metrics
116
 
117
  trial_cfg = create_trial_config(trial, base_cfg)
118
  master_df, tv_unknown, tv_known, target_cols = master_data
@@ -155,7 +162,46 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
155
  if val_loss is None:
156
  return float("inf")
157
 
158
- return float(val_loss)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
 
161
  def run_hyperopt(
 
50
  dropout=trial.suggest_float("dropout", 0.1, 0.5, step=0.05),
51
  hidden_continuous_size=trial.suggest_int("hidden_continuous_size", 8, 32, step=8),
52
  quantiles=base_cfg.model.quantiles,
53
+ # Range [1e-4, 1e-3]: LR < 1e-4 produces near-zero pred_std (VR=0.14);
54
+ # LR > 1e-3 causes 1-epoch divergence. This band is the stable zone.
55
+ learning_rate=trial.suggest_float("learning_rate", 1e-4, 1e-3, log=True),
 
56
  reduce_on_plateau_patience=4,
57
  gradient_clip_val=trial.suggest_float("gradient_clip_val", 0.5, 2.0, step=0.5),
58
  )
59
 
60
  asro_cfg = ASROConfig(
61
+ # Floor at 0.25: three Optuna runs consistently selected 0.30-0.35.
62
+ # Lower values let the model collapse to near-zero pred_std.
63
+ lambda_vol=trial.suggest_float("lambda_vol", 0.25, 0.45, step=0.05),
64
  # lambda_quantile is the explicit w_quantile weight (w_sharpe = 1 - w_q)
65
  lambda_quantile=trial.suggest_float("lambda_quantile", 0.2, 0.6, step=0.05),
66
  risk_free_rate=0.0,
 
95
 
96
  def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
97
  """
98
+ Single Optuna trial: train a TFT variant and return a composite score.
99
 
100
+ Composite objective (lower is better):
101
+ score = val_loss - 0.5 * variance_penalty
102
+
103
+ The variance penalty pushes Optuna away from "flat but directionally correct"
104
+ configs that game the Sharpe component with near-zero pred_std. The penalty
105
+ fires when variance_ratio < 0.5 (i.e. predictions capture less than half the
106
+ actual volatility).
107
  """
108
  try:
109
  import lightning.pytorch as pl
 
116
  except ImportError:
117
  from optuna.integration import PyTorchLightningPruningCallback # type: ignore[no-redef]
118
 
119
+ import numpy as np
120
+ import torch
121
  from deep_learning.data.dataset import build_datasets, create_dataloaders
122
  from deep_learning.models.tft_copper import create_tft_model
 
123
 
124
  trial_cfg = create_trial_config(trial, base_cfg)
125
  master_df, tv_unknown, tv_known, target_cols = master_data
 
162
  if val_loss is None:
163
  return float("inf")
164
 
165
+ # --- Variance-ratio penalty on validation set ---
166
+ # Prevents Optuna from selecting configs that produce near-zero pred_std
167
+ # (which games Sharpe by being "flat but directionally correct").
168
+ variance_penalty = 0.0
169
+ try:
170
+ pred_tensor = model.predict(val_dl, mode="quantiles")
171
+ if hasattr(pred_tensor, "cpu"):
172
+ pred_np = pred_tensor.cpu().numpy()
173
+ else:
174
+ pred_np = np.array(pred_tensor)
175
+
176
+ median_idx = len(trial_cfg.model.quantiles) // 2
177
+ y_pred = pred_np[:, 0, median_idx] if pred_np.ndim == 3 else pred_np.flatten()
178
+
179
+ y_actual_parts = []
180
+ for batch in val_dl:
181
+ y_actual_parts.append(batch[1][0] if isinstance(batch[1], (list, tuple)) else batch[1])
182
+ y_actual = torch.cat(y_actual_parts).cpu().numpy().flatten()
183
+
184
+ n = min(len(y_actual), len(y_pred))
185
+ pred_std = float(y_pred[:n].std())
186
+ actual_std = float(y_actual[:n].std())
187
+ vr = pred_std / actual_std if actual_std > 1e-9 else 0.0
188
+
189
+ # Penalty activates when VR < 0.5 (predictions cover less than half
190
+ # the real volatility). Scaled so VR=0 → penalty=0.5, VR=0.5 → 0.
191
+ if vr < 0.5:
192
+ variance_penalty = 0.5 * (1.0 - vr / 0.5)
193
+
194
+ trial.set_user_attr("variance_ratio", round(vr, 4))
195
+ trial.set_user_attr("pred_std", round(pred_std, 6))
196
+ except Exception as exc:
197
+ logger.debug("Trial %d variance check failed: %s", trial.number, exc)
198
+
199
+ score = float(val_loss) + variance_penalty
200
+ logger.info(
201
+ "Trial %d: val_loss=%.4f vr_penalty=%.4f → score=%.4f",
202
+ trial.number, float(val_loss), variance_penalty, score,
203
+ )
204
+ return score
205
 
206
 
207
  def run_hyperopt(