Spaces:
Running
Running
| """ | |
| Optuna-based Hyperparameter Optimization for TFT-ASRO. | |
| Searches across model architecture, training, and ASRO loss parameters | |
| using Tree-structured Parzen Estimator (TPE) with early pruning. | |
| Usage: | |
| python -m deep_learning.training.hyperopt --n-trials 50 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import warnings | |
| from dataclasses import replace | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| warnings.filterwarnings( | |
| "ignore", | |
| message="X does not have valid feature names", | |
| category=UserWarning, | |
| module="sklearn", | |
| ) | |
| from deep_learning.config import ( | |
| ASROConfig, | |
| TFTASROConfig, | |
| TFTModelConfig, | |
| TrainingConfig, | |
| get_tft_config, | |
| ) | |
| logger = logging.getLogger(__name__) | |
def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
    """Translate an Optuna ``trial`` into a concrete TFT-ASRO configuration.

    Only the model, ASRO-loss, and training sections are re-sampled; the
    embedding, sentiment, LME, and feature-store sections are carried over
    from ``base_cfg`` unchanged. Checkpoints are isolated per trial under
    ``<checkpoint_dir>/trial_<n>/``.
    """
    # Per-trial checkpoint directory, derived once and reused below.
    trial_dir = Path(base_cfg.training.checkpoint_dir) / f"trial_{trial.number}"

    model = TFTModelConfig(
        max_encoder_length=trial.suggest_int("max_encoder_length", 30, 90, step=10),
        max_prediction_length=base_cfg.model.max_prediction_length,
        # Floor at 32: hidden=16 with dropout>0.3 leaves ~8 active neurons,
        # compressing the output distribution and preventing amplitude learning.
        hidden_size=trial.suggest_int("hidden_size", 32, 64, step=16),
        attention_head_size=trial.suggest_int("attention_head_size", 1, 4),
        # Cap at 0.35: dropout=0.5 with a small hidden_size collapses the
        # output range -- the model physically cannot produce large predictions.
        dropout=trial.suggest_float("dropout", 0.1, 0.35, step=0.05),
        hidden_continuous_size=trial.suggest_int("hidden_continuous_size", 8, 32, step=8),
        quantiles=base_cfg.model.quantiles,
        # Band [1e-4, 1e-3]: LR < 1e-4 yields near-zero pred_std (VR=0.14);
        # LR > 1e-3 diverges within one epoch. This band is the stable zone.
        learning_rate=trial.suggest_float("learning_rate", 1e-4, 1e-3, log=True),
        reduce_on_plateau_patience=4,
        gradient_clip_val=trial.suggest_float("gradient_clip_val", 0.5, 2.0, step=0.5),
    )

    asro = ASROConfig(
        # Floor at 0.25: three Optuna runs consistently selected 0.30-0.35.
        # Lower values let the model collapse to near-zero pred_std.
        lambda_vol=trial.suggest_float("lambda_vol", 0.25, 0.45, step=0.05),
        # lambda_quantile is the explicit w_quantile weight (w_sharpe = 1 - w_q).
        lambda_quantile=trial.suggest_float("lambda_quantile", 0.2, 0.6, step=0.05),
        risk_free_rate=0.0,
    )

    training = TrainingConfig(
        max_epochs=50,
        early_stopping_patience=8,
        # Include 16, which gives 19 batches/epoch (vs 4 at batch_size=64):
        # more gradient steps per epoch -> more stable convergence.
        batch_size=trial.suggest_categorical("batch_size", [16, 32, 64]),
        val_ratio=base_cfg.training.val_ratio,
        test_ratio=base_cfg.training.test_ratio,
        lookback_days=base_cfg.training.lookback_days,
        seed=base_cfg.training.seed,
        num_workers=base_cfg.training.num_workers,
        optuna_n_trials=base_cfg.training.optuna_n_trials,
        checkpoint_dir=str(trial_dir),
        best_model_path=str(trial_dir / "best.ckpt"),
    )

    return TFTASROConfig(
        embedding=base_cfg.embedding,
        sentiment=base_cfg.sentiment,
        lme=base_cfg.lme,
        model=model,
        asro=asro,
        training=training,
        feature_store=base_cfg.feature_store,
    )
def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
    """
    Single Optuna trial: train a TFT variant and return a composite score.

    Composite objective (lower is better):

        score = val_loss + variance_penalty

    A two-sided variance penalty keeps predictions in a healthy amplitude zone:

        VR < 0.5  -> strong penalty (2.0x) - flat predictions are useless
        0.5-1.5   -> no penalty            - wide healthy zone, not a narrow band
        VR > 1.5  -> gentle penalty (0.5x) - overconfident but still has signal

    Returns ``float("inf")`` when setup or training fails, so a bad
    hyperparameter combination scores as "worst possible" instead of
    aborting the whole study.
    """
    # Heavy imports are deferred to trial time; lightning >= 2.0 renamed the
    # package, hence the ImportError fallbacks.
    try:
        import lightning.pytorch as pl
        from lightning.pytorch.callbacks import EarlyStopping
    except ImportError:
        import pytorch_lightning as pl  # type: ignore[no-redef]
        from pytorch_lightning.callbacks import EarlyStopping  # type: ignore[no-redef]
    try:
        from optuna_integration.pytorch_lightning import PyTorchLightningPruningCallback
    except ImportError:
        from optuna.integration import PyTorchLightningPruningCallback  # type: ignore[no-redef]
    import numpy as np
    import torch

    from deep_learning.data.dataset import build_datasets, create_dataloaders
    from deep_learning.models.tft_copper import create_tft_model

    trial_cfg = create_trial_config(trial, base_cfg)
    # master_data is the prebuilt feature store shared by every trial.
    master_df, tv_unknown, tv_known, target_cols = master_data
    try:
        training_ds, validation_ds, test_ds = build_datasets(
            master_df, tv_unknown, tv_known, target_cols, trial_cfg,
        )
        train_dl, val_dl, _ = create_dataloaders(training_ds, validation_ds, cfg=trial_cfg)
        model = create_tft_model(training_ds, trial_cfg, use_asro=True)
    except Exception as exc:
        logger.warning("Trial %d setup failed: %s", trial.number, exc)
        return float("inf")

    callbacks = [
        EarlyStopping(monitor="val_loss", patience=trial_cfg.training.early_stopping_patience, mode="min"),
        PyTorchLightningPruningCallback(trial, monitor="val_loss"),
    ]
    ckpt_dir = Path(trial_cfg.training.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    trainer = pl.Trainer(
        max_epochs=trial_cfg.training.max_epochs,
        accelerator="auto",
        gradient_clip_val=trial_cfg.model.gradient_clip_val,
        callbacks=callbacks,
        enable_progress_bar=False,
        enable_model_summary=False,
        log_every_n_steps=20,
    )
    try:
        trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
    except Exception as exc:
        logger.warning("Trial %d training failed: %s", trial.number, exc)
        return float("inf")

    val_loss = trainer.callback_metrics.get("val_loss")
    if val_loss is None:
        return float("inf")

    # --- Variance-ratio penalty on validation set ---
    # Prevents Optuna from selecting configs that produce near-zero pred_std
    # (which games Sharpe by being "flat but directionally correct").
    variance_penalty = 0.0
    try:
        pred_tensor = model.predict(val_dl, mode="quantiles")
        if hasattr(pred_tensor, "cpu"):
            pred_np = pred_tensor.cpu().numpy()
        else:
            pred_np = np.array(pred_tensor)
        # Middle quantile column is the point forecast.
        median_idx = len(trial_cfg.model.quantiles) // 2
        y_pred = pred_np[:, 0, median_idx] if pred_np.ndim == 3 else pred_np.flatten()
        y_actual_parts = []
        for batch in val_dl:
            y_actual_parts.append(batch[1][0] if isinstance(batch[1], (list, tuple)) else batch[1])
        y_actual = torch.cat(y_actual_parts).cpu().numpy().flatten()
        # Align lengths defensively; prediction and target counts can differ.
        n = min(len(y_actual), len(y_pred))
        pred_std = float(y_pred[:n].std())
        actual_std = float(y_actual[:n].std())
        vr = pred_std / actual_std if actual_std > 1e-9 else 0.0
        # Two-sided penalty with a wide healthy zone [0.5, 1.5]:
        #   VR < 0.5  -> strong penalty (flat predictions, the original problem)
        #   0.5-1.5   -> no penalty (3x wide zone, not a narrow band)
        #   VR > 1.5  -> gentle penalty (overconfident, predictions louder than market)
        #
        # Asymmetric: too-flat is worse than too-loud (flat predictions are
        # useless; loud predictions at least carry directional signal).
        if vr < 0.5:
            variance_penalty = 2.0 * (1.0 - vr / 0.5)
        elif vr > 1.5:
            variance_penalty = 0.5 * (vr - 1.5)
        trial.set_user_attr("variance_ratio", round(vr, 4))
        trial.set_user_attr("pred_std", round(pred_std, 6))
    except Exception as exc:
        # The penalty is best-effort diagnostics; never fail a trial over it.
        logger.debug("Trial %d variance check failed: %s", trial.number, exc)

    score = float(val_loss) + variance_penalty
    # Fixed mojibake in the original log format string ("β" was a mangled "->").
    logger.info(
        "Trial %d: val_loss=%.4f vr_penalty=%.4f -> score=%.4f",
        trial.number, float(val_loss), variance_penalty, score,
    )
    return score
def run_hyperopt(
    base_cfg: Optional[TFTASROConfig] = None,
    n_trials: int = 50,
    study_name: str = "tft_asro_optuna",
    storage: Optional[str] = None,
) -> dict:
    """
    Launch the Optuna hyperparameter search for TFT-ASRO.

    Args:
        base_cfg: Baseline configuration; defaults to ``get_tft_config()``.
        n_trials: Number of trials to run in this invocation.
        study_name: Optuna study name.
        storage: Optional Optuna storage URL; combined with
            ``load_if_exists=True`` this lets an interrupted study resume.

    Returns:
        Dict with best trial number, best composite score, best params,
        and total trial count.
    """
    import optuna
    try:
        import lightning.pytorch as pl
    except ImportError:
        import pytorch_lightning as pl  # type: ignore[no-redef]
    from app.db import SessionLocal, init_db
    from deep_learning.data.feature_store import build_tft_dataframe

    if base_cfg is None:
        base_cfg = get_tft_config()
    init_db()
    pl.seed_everything(base_cfg.training.seed)

    # Build the feature store once and share it across all trials.
    logger.info("Building feature store for hyperopt ...")
    with SessionLocal() as session:
        master_data = build_tft_dataframe(session, base_cfg)

    study = optuna.create_study(
        study_name=study_name,
        direction="minimize",
        storage=storage,
        load_if_exists=True,
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=5),
    )
    study.optimize(
        lambda trial: _objective(trial, base_cfg, master_data),
        n_trials=n_trials,
        show_progress_bar=True,
    )

    best = study.best_trial
    # best.value is the composite score (val_loss + variance penalty), not raw
    # val_loss -- the original log message mislabeled it.
    logger.info("Optuna best trial #%d: score=%.6f", best.number, best.value)
    logger.info("Best params: %s", best.params)

    # Build the summary once; it is both persisted and returned (the original
    # duplicated this literal in two places).
    summary = {
        "best_trial": best.number,
        "best_value": best.value,
        "best_params": best.params,
        "n_trials": len(study.trials),
    }
    # Save alongside best_tft_asro.ckpt (tft/ root) so upload_tft_artifacts picks it up.
    results_path = Path(base_cfg.training.best_model_path).parent / "optuna_results.json"
    results_path.parent.mkdir(parents=True, exist_ok=True)
    results_path.write_text(json.dumps(summary, indent=2))
    return summary
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    parser = argparse.ArgumentParser(description="TFT-ASRO hyperparameter optimisation")
    parser.add_argument("--n-trials", type=int, default=50, help="Number of Optuna trials to run")
    parser.add_argument("--study-name", default="tft_asro_optuna", help="Optuna study name")
    args = parser.parse_args()

    result = run_hyperopt(n_trials=args.n_trials, study_name=args.study_name)

    print("\n" + "=" * 60)
    print("HYPEROPT COMPLETE")
    print("=" * 60)
    print(f"Best trial: #{result['best_trial']}")
    # best_value is the composite objective (val_loss + variance penalty),
    # so label it "score" rather than the original, misleading "val_loss".
    print(f"Best score: {result['best_value']:.6f}")
    for k, v in result["best_params"].items():
        print(f"  {k}: {v}")