Spaces:
Running
Running
Sync from GitHub (tests passed)
Browse files
deep_learning/training/hyperopt.py
CHANGED
|
@@ -43,26 +43,31 @@ def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
|
|
| 43 |
model_cfg = TFTModelConfig(
|
| 44 |
max_encoder_length=trial.suggest_int("max_encoder_length", 30, 90, step=10),
|
| 45 |
max_prediction_length=base_cfg.model.max_prediction_length,
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
quantiles=base_cfg.model.quantiles,
|
| 51 |
-
learning_rate=trial.suggest_float("learning_rate",
|
| 52 |
reduce_on_plateau_patience=4,
|
| 53 |
-
gradient_clip_val=trial.suggest_float("gradient_clip_val", 0.
|
| 54 |
)
|
| 55 |
|
| 56 |
asro_cfg = ASROConfig(
|
| 57 |
-
lambda_vol=trial.suggest_float("lambda_vol", 0.1, 0.
|
| 58 |
-
lambda_quantile
|
|
|
|
| 59 |
risk_free_rate=0.0,
|
| 60 |
)
|
| 61 |
|
| 62 |
training_cfg = TrainingConfig(
|
| 63 |
max_epochs=50,
|
| 64 |
early_stopping_patience=8,
|
| 65 |
-
|
|
|
|
|
|
|
| 66 |
val_ratio=base_cfg.training.val_ratio,
|
| 67 |
test_ratio=base_cfg.training.test_ratio,
|
| 68 |
lookback_days=base_cfg.training.lookback_days,
|
|
|
|
| 43 |
model_cfg = TFTModelConfig(
|
| 44 |
max_encoder_length=trial.suggest_int("max_encoder_length", 30, 90, step=10),
|
| 45 |
max_prediction_length=base_cfg.model.max_prediction_length,
|
| 46 |
+
# Cap at 64: beyond that the VSN encoder explodes to 3M+ params for our
|
| 47 |
+
# 313-sample dataset, causing the same overfitting we already saw at 64.
|
| 48 |
+
hidden_size=trial.suggest_int("hidden_size", 16, 64, step=16),
|
| 49 |
+
attention_head_size=trial.suggest_int("attention_head_size", 1, 4),
|
| 50 |
+
dropout=trial.suggest_float("dropout", 0.1, 0.5, step=0.05),
|
| 51 |
+
hidden_continuous_size=trial.suggest_int("hidden_continuous_size", 8, 32, step=8),
|
| 52 |
quantiles=base_cfg.model.quantiles,
|
| 53 |
+
learning_rate=trial.suggest_float("learning_rate", 5e-5, 5e-3, log=True),
|
| 54 |
reduce_on_plateau_patience=4,
|
| 55 |
+
gradient_clip_val=trial.suggest_float("gradient_clip_val", 0.5, 2.0, step=0.5),
|
| 56 |
)
|
| 57 |
|
| 58 |
asro_cfg = ASROConfig(
|
| 59 |
+
lambda_vol=trial.suggest_float("lambda_vol", 0.1, 0.4, step=0.05),
|
| 60 |
+
# lambda_quantile is the explicit w_quantile weight (w_sharpe = 1 - w_q)
|
| 61 |
+
lambda_quantile=trial.suggest_float("lambda_quantile", 0.2, 0.6, step=0.05),
|
| 62 |
risk_free_rate=0.0,
|
| 63 |
)
|
| 64 |
|
| 65 |
training_cfg = TrainingConfig(
|
| 66 |
max_epochs=50,
|
| 67 |
early_stopping_patience=8,
|
| 68 |
+
# Include 16 which gives 19 batches/epoch (vs 4 at batch_size=64)
|
| 69 |
+
# — more gradient steps per epoch → more stable convergence.
|
| 70 |
+
batch_size=trial.suggest_categorical("batch_size", [16, 32, 64]),
|
| 71 |
val_ratio=base_cfg.training.val_ratio,
|
| 72 |
test_ratio=base_cfg.training.test_ratio,
|
| 73 |
lookback_days=base_cfg.training.lookback_days,
|
deep_learning/training/trainer.py
CHANGED
|
@@ -68,7 +68,13 @@ def train_tft_model(
|
|
| 68 |
if cfg is None:
|
| 69 |
cfg = get_tft_config()
|
| 70 |
|
| 71 |
-
# ----
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
try:
|
| 73 |
from deep_learning.models.losses import debug_asro_loss_direction
|
| 74 |
debug = debug_asro_loss_direction()
|
|
@@ -270,6 +276,58 @@ def train_tft_model(
|
|
| 270 |
return result
|
| 271 |
|
| 272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
def _persist_tft_metadata(symbol: str, result: dict) -> None:
|
| 274 |
"""Save TFT model metadata to DB."""
|
| 275 |
try:
|
|
|
|
| 68 |
if cfg is None:
|
| 69 |
cfg = get_tft_config()
|
| 70 |
|
| 71 |
+
# ---- 0a. Load Optuna best params if available ----
|
| 72 |
+
# When the hyperopt step ran before this trainer, it writes best params to
|
| 73 |
+
# optuna_results.json. We apply those params over the default config so that
|
| 74 |
+
# the final training run actually benefits from the search.
|
| 75 |
+
cfg = _apply_optuna_results(cfg)
|
| 76 |
+
|
| 77 |
+
# ---- 0b. ASRO loss sanity check (runs before any training) ----
|
| 78 |
try:
|
| 79 |
from deep_learning.models.losses import debug_asro_loss_direction
|
| 80 |
debug = debug_asro_loss_direction()
|
|
|
|
| 276 |
return result
|
| 277 |
|
| 278 |
|
| 279 |
+
def _apply_optuna_results(cfg: TFTASROConfig) -> TFTASROConfig:
|
| 280 |
+
"""
|
| 281 |
+
If an optuna_results.json exists in the checkpoint directory, overlay the
|
| 282 |
+
best hyperparameters onto cfg and return the updated config. This connects
|
| 283 |
+
the hyperopt step to the final training run so search results are not wasted.
|
| 284 |
+
"""
|
| 285 |
+
import json
|
| 286 |
+
from dataclasses import replace
|
| 287 |
+
from deep_learning.config import ASROConfig, TFTModelConfig, TrainingConfig
|
| 288 |
+
|
| 289 |
+
results_path = Path(cfg.training.checkpoint_dir) / "optuna_results.json"
|
| 290 |
+
if not results_path.exists():
|
| 291 |
+
return cfg
|
| 292 |
+
|
| 293 |
+
try:
|
| 294 |
+
data = json.loads(results_path.read_text())
|
| 295 |
+
params = data.get("best_params", {})
|
| 296 |
+
if not params:
|
| 297 |
+
return cfg
|
| 298 |
+
|
| 299 |
+
model_overrides = {
|
| 300 |
+
k: params[k] for k in (
|
| 301 |
+
"hidden_size", "attention_head_size", "dropout",
|
| 302 |
+
"hidden_continuous_size", "learning_rate",
|
| 303 |
+
"gradient_clip_val", "max_encoder_length",
|
| 304 |
+
) if k in params
|
| 305 |
+
}
|
| 306 |
+
asro_overrides = {
|
| 307 |
+
k: params[k] for k in ("lambda_vol", "lambda_quantile")
|
| 308 |
+
if k in params
|
| 309 |
+
}
|
| 310 |
+
training_overrides = {
|
| 311 |
+
k: params[k] for k in ("batch_size",) if k in params
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
new_model = replace(cfg.model, **model_overrides) if model_overrides else cfg.model
|
| 315 |
+
new_asro = replace(cfg.asro, **asro_overrides) if asro_overrides else cfg.asro
|
| 316 |
+
new_training = replace(cfg.training, **training_overrides) if training_overrides else cfg.training
|
| 317 |
+
|
| 318 |
+
logger.info(
|
| 319 |
+
"Loaded Optuna best params (trial #%d, val_loss=%.4f): %s",
|
| 320 |
+
data.get("best_trial", -1),
|
| 321 |
+
data.get("best_value", float("nan")),
|
| 322 |
+
params,
|
| 323 |
+
)
|
| 324 |
+
return replace(cfg, model=new_model, asro=new_asro, training=new_training)
|
| 325 |
+
|
| 326 |
+
except Exception as exc:
|
| 327 |
+
logger.warning("Could not apply Optuna results: %s", exc)
|
| 328 |
+
return cfg
|
| 329 |
+
|
| 330 |
+
|
| 331 |
def _persist_tft_metadata(symbol: str, result: dict) -> None:
|
| 332 |
"""Save TFT model metadata to DB."""
|
| 333 |
try:
|