ifieryarrows committed on
Commit
0b39593
·
verified ·
1 Parent(s): d066295

Sync from GitHub (tests passed)

Browse files
deep_learning/training/hyperopt.py CHANGED
@@ -43,26 +43,31 @@ def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
43
  model_cfg = TFTModelConfig(
44
  max_encoder_length=trial.suggest_int("max_encoder_length", 30, 90, step=10),
45
  max_prediction_length=base_cfg.model.max_prediction_length,
46
- hidden_size=trial.suggest_int("hidden_size", 32, 128, step=16),
47
- attention_head_size=trial.suggest_int("attention_head_size", 1, 8),
48
- dropout=trial.suggest_float("dropout", 0.05, 0.3, step=0.05),
49
- hidden_continuous_size=trial.suggest_int("hidden_continuous_size", 16, 64, step=8),
 
 
50
  quantiles=base_cfg.model.quantiles,
51
- learning_rate=trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True),
52
  reduce_on_plateau_patience=4,
53
- gradient_clip_val=trial.suggest_float("gradient_clip_val", 0.1, 1.0, step=0.1),
54
  )
55
 
56
  asro_cfg = ASROConfig(
57
- lambda_vol=trial.suggest_float("lambda_vol", 0.1, 0.5, step=0.05),
58
- lambda_quantile=trial.suggest_float("lambda_quantile", 0.1, 0.5, step=0.05),
 
59
  risk_free_rate=0.0,
60
  )
61
 
62
  training_cfg = TrainingConfig(
63
  max_epochs=50,
64
  early_stopping_patience=8,
65
- batch_size=trial.suggest_categorical("batch_size", [32, 64, 128]),
 
 
66
  val_ratio=base_cfg.training.val_ratio,
67
  test_ratio=base_cfg.training.test_ratio,
68
  lookback_days=base_cfg.training.lookback_days,
 
43
  model_cfg = TFTModelConfig(
44
  max_encoder_length=trial.suggest_int("max_encoder_length", 30, 90, step=10),
45
  max_prediction_length=base_cfg.model.max_prediction_length,
46
+ # Cap at 64: beyond that the VSN encoder explodes to 3M+ params for our
47
+ # 313-sample dataset, causing the same overfitting we already saw at 64.
48
+ hidden_size=trial.suggest_int("hidden_size", 16, 64, step=16),
49
+ attention_head_size=trial.suggest_int("attention_head_size", 1, 4),
50
+ dropout=trial.suggest_float("dropout", 0.1, 0.5, step=0.05),
51
+ hidden_continuous_size=trial.suggest_int("hidden_continuous_size", 8, 32, step=8),
52
  quantiles=base_cfg.model.quantiles,
53
+ learning_rate=trial.suggest_float("learning_rate", 5e-5, 5e-3, log=True),
54
  reduce_on_plateau_patience=4,
55
+ gradient_clip_val=trial.suggest_float("gradient_clip_val", 0.5, 2.0, step=0.5),
56
  )
57
 
58
  asro_cfg = ASROConfig(
59
+ lambda_vol=trial.suggest_float("lambda_vol", 0.1, 0.4, step=0.05),
60
+ # lambda_quantile is the explicit w_quantile weight (w_sharpe = 1 - w_q)
61
+ lambda_quantile=trial.suggest_float("lambda_quantile", 0.2, 0.6, step=0.05),
62
  risk_free_rate=0.0,
63
  )
64
 
65
  training_cfg = TrainingConfig(
66
  max_epochs=50,
67
  early_stopping_patience=8,
68
+ # Include 16 which gives 19 batches/epoch (vs 4 at batch_size=64)
69
+ # — more gradient steps per epoch → more stable convergence.
70
+ batch_size=trial.suggest_categorical("batch_size", [16, 32, 64]),
71
  val_ratio=base_cfg.training.val_ratio,
72
  test_ratio=base_cfg.training.test_ratio,
73
  lookback_days=base_cfg.training.lookback_days,
deep_learning/training/trainer.py CHANGED
@@ -68,7 +68,13 @@ def train_tft_model(
68
  if cfg is None:
69
  cfg = get_tft_config()
70
 
71
- # ---- 0. ASRO loss sanity check (runs before any training) ----
 
 
 
 
 
 
72
  try:
73
  from deep_learning.models.losses import debug_asro_loss_direction
74
  debug = debug_asro_loss_direction()
@@ -270,6 +276,58 @@ def train_tft_model(
270
  return result
271
 
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  def _persist_tft_metadata(symbol: str, result: dict) -> None:
274
  """Save TFT model metadata to DB."""
275
  try:
 
68
  if cfg is None:
69
  cfg = get_tft_config()
70
 
71
+ # ---- 0a. Load Optuna best params if available ----
72
+ # When the hyperopt step ran before this trainer, it writes best params to
73
+ # optuna_results.json. We apply those params over the default config so that
74
+ # the final training run actually benefits from the search.
75
+ cfg = _apply_optuna_results(cfg)
76
+
77
+ # ---- 0b. ASRO loss sanity check (runs before any training) ----
78
  try:
79
  from deep_learning.models.losses import debug_asro_loss_direction
80
  debug = debug_asro_loss_direction()
 
276
  return result
277
 
278
 
279
def _apply_optuna_results(cfg: TFTASROConfig) -> TFTASROConfig:
    """
    Overlay Optuna best hyperparameters onto *cfg*, if available.

    Looks for ``optuna_results.json`` in ``cfg.training.checkpoint_dir``
    (written by the hyperopt step) and applies its ``best_params`` to the
    matching sub-config sections, so the final training run actually benefits
    from the search.

    Args:
        cfg: The base TFT/ASRO configuration to overlay params onto.

    Returns:
        A new config with the best params applied, or *cfg* unchanged when the
        results file is missing, contains no params, or cannot be parsed or
        applied (best-effort: failures are logged, never raised).
    """
    import json
    from dataclasses import replace

    results_path = Path(cfg.training.checkpoint_dir) / "optuna_results.json"
    if not results_path.exists():
        return cfg

    try:
        data = json.loads(results_path.read_text())
        params = data.get("best_params", {})
        if not params:
            return cfg

        # Route each searched param to the dataclass section it belongs to;
        # keys absent from the results file are simply skipped.
        model_overrides = {
            k: params[k] for k in (
                "hidden_size", "attention_head_size", "dropout",
                "hidden_continuous_size", "learning_rate",
                "gradient_clip_val", "max_encoder_length",
            ) if k in params
        }
        asro_overrides = {
            k: params[k] for k in ("lambda_vol", "lambda_quantile")
            if k in params
        }
        training_overrides = {
            k: params[k] for k in ("batch_size",) if k in params
        }

        new_model = replace(cfg.model, **model_overrides) if model_overrides else cfg.model
        new_asro = replace(cfg.asro, **asro_overrides) if asro_overrides else cfg.asro
        new_training = replace(cfg.training, **training_overrides) if training_overrides else cfg.training

        logger.info(
            "Loaded Optuna best params (trial #%d, val_loss=%.4f): %s",
            data.get("best_trial", -1),
            data.get("best_value", float("nan")),
            params,
        )
        return replace(cfg, model=new_model, asro=new_asro, training=new_training)

    except Exception as exc:
        # Best-effort by design: a corrupt or partial results file must not
        # block the final training run, so fall back to the incoming config.
        logger.warning("Could not apply Optuna results: %s", exc)
        return cfg
329
+
330
+
331
  def _persist_tft_metadata(symbol: str, result: dict) -> None:
332
  """Save TFT model metadata to DB."""
333
  try: