"""
Central configuration for the TFT-ASRO deep learning pipeline.
All hyperparameters, feature dimensions, and training settings live here
so every module draws from a single source of truth.
Model paths honour the MODEL_DIR environment variable so they work both
locally (``data/models``) and inside the HF Space container
(``/data/models``).
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
def _model_dir() -> str:
"""Resolve the base model directory from env (same as app.settings)."""
return os.environ.get("MODEL_DIR", "/data/models")
@dataclass(frozen=True)
class EmbeddingConfig:
    """Settings for FinBERT text embeddings and their PCA compression."""

    # Hugging Face model id used for the sentence embeddings.
    model_name: str = "ProsusAI/finbert"
    # Native FinBERT hidden size (dimension of the raw embedding).
    full_dim: int = 768
    # Target dimensionality after PCA compression.
    pca_dim: int = 32
    # Tokenizer truncation length, in tokens.
    max_token_length: int = 512
    # Batch size used when extracting embeddings.
    batch_size: int = 64
    # Path to the fitted PCA model; empty by default and filled in by
    # get_tft_config() with a MODEL_DIR-resolved location.
    pca_model_path: str = ""
@dataclass(frozen=True)
class SentimentFeatureConfig:
    """Settings for sentiment-derived features (momentum, surprise, events)."""

    # Rolling windows for sentiment momentum features.
    momentum_windows: tuple[int, ...] = (5, 10, 30)
    # Lookback used when measuring sentiment "surprise".
    surprise_lookback: int = 30
    # Threshold above which a reading counts as a surprise
    # (presumably in standard deviations — confirm against the feature code).
    surprise_threshold: float = 2.0
    # Closed vocabulary of event labels the pipeline recognises.
    event_types: tuple[str, ...] = (
        "supply_disruption",
        "supply_expansion",
        "demand_increase",
        "demand_decrease",
        "inventory_draw",
        "inventory_build",
        "policy_support",
        "policy_drag",
        "macro_usd_up",
        "macro_usd_down",
        "cost_push",
    )
@dataclass(frozen=True)
class LMEConfig:
    """Settings for LME inventory and copper futures market data."""

    # Name of the env var holding the Nasdaq Data Link API key (not the key itself).
    nasdaq_api_key_env: str = "NASDAQ_DATA_LINK_API_KEY"
    # Nasdaq Data Link (Quandl) dataset code for LME copper prices.
    quandl_dataset: str = "LME/PR_CU"
    # Windows for stock-change features (presumably trading days — confirm).
    stock_change_windows: tuple[int, ...] = (1, 5, 10, 20)
    # Window for the inventory depletion feature.
    depletion_window: int = 20
    # Futures ticker(s); "HG=F" is the COMEX copper front contract symbol.
    futures_symbols: tuple[str, ...] = ("HG=F",)
    # Forward horizons (months) for futures-curve features.
    futures_months_ahead: tuple[int, ...] = (3, 6, 12)
    # Cap on forward-filling missing values, in days.
    max_ffill_days: int = 5
@dataclass(frozen=True)
class TFTModelConfig:
    """Architecture and optimiser hyperparameters for the TFT model."""

    # Encoder lookback length (time steps fed into the model).
    max_encoder_length: int = 60
    # Forecast horizon (time steps predicted ahead).
    max_prediction_length: int = 5
    # hidden_size 64→32: VSN encoder had 3.2M params for only 313 training
    # samples (344 features × hidden_size × hidden_continuous_size).
    # Reducing halves the dominant layer while keeping expressiveness.
    hidden_size: int = 32
    # attention_head_size 4→2: fewer heads for a small, single-series dataset.
    attention_head_size: int = 2
    # dropout 0.1→0.3: 313 samples / ~900K params still demands heavy regularisation.
    dropout: float = 0.3
    hidden_continuous_size: int = 16  # was 32; paired reduction with hidden_size
    # Quantile levels for the probabilistic output head.
    quantiles: tuple[float, ...] = (0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98)
    # lr 1e-3→3e-4: smaller batches produce noisier gradients; conservative LR
    # reduces the risk of overshooting the narrow-loss landscape.
    learning_rate: float = 3e-4
    # Epochs without improvement before the LR scheduler steps down.
    reduce_on_plateau_patience: int = 4
    # clip 0.5→1.0: tanh-based Sharpe gradients are inherently bounded;
    # relaxing the clip lets the model escape flat regions more aggressively.
    gradient_clip_val: float = 1.0
@dataclass(frozen=True)
class ASROConfig:
    """Weights for the combined ASRO loss (quantile calibration + Sharpe)."""

    # Total loss = lambda_quantile * calibration + (1-lambda_quantile) * sharpe
    #
    # lambda_quantile is the EXPLICIT weight of the quantile calibration bundle:
    #   calibration = q_loss + lambda_vol * vol_loss
    #   w_sharpe = 1 - lambda_quantile (the complementary directional weight)
    #
    # This normalised (sum-to-1) formulation makes both components interpretable
    # and prevents either from silently dominating across loss-magnitude regimes.
    #
    # 0.4 / 0.6 split: 40% calibration (keeps TFT probabilistic),
    # 60% Sharpe (drives directional / amplitude learning)
    lambda_quantile: float = 0.4  # w_quantile; was 0.3 (unnormalised old formula)
    # lambda_vol is a sub-weight within the calibration bundle only.
    # It controls how much the Q90-Q10 spread tracks 2x actual sigma.
    # Two independent Optuna runs (20 trials each) both converged on 0.35 —
    # updating default to match confirmed optimal value.
    lambda_vol: float = 0.35
    # Risk-free rate used in the Sharpe term.
    risk_free_rate: float = 0.0
    # Rolling window length for the Sharpe calculation.
    sharpe_window: int = 20
@dataclass(frozen=True)
class TrainingConfig:
    """Training-loop, data-split, and checkpoint settings."""

    # Upper bound on epochs; early stopping usually ends training sooner.
    max_epochs: int = 100
    # patience 10→15: with 19 batches/epoch (vs 4 before) each epoch carries
    # more information; give the model more time to converge.
    early_stopping_patience: int = 15
    # batch_size 64→16: 313 samples / 64 = 4 batches/epoch → noisy gradients.
    # 313 / 16 ≈ 19 batches/epoch gives stable, consistent gradient estimates.
    batch_size: int = 16
    # Fractions of the data held out for validation and test.
    val_ratio: float = 0.15
    test_ratio: float = 0.10
    # History window of raw data pulled for training, in days (~2 years).
    lookback_days: int = 730
    # Global RNG seed for reproducibility.
    seed: int = 42
    # DataLoader workers; 0 keeps loading in the main process.
    num_workers: int = 0
    # Trial budget for Optuna hyperparameter search.
    optuna_n_trials: int = 50
    # Paths are empty by default; get_tft_config() fills them from MODEL_DIR.
    checkpoint_dir: str = ""
    best_model_path: str = ""
    # Hugging Face Hub repo where the trained model is stored.
    hf_model_repo: str = "ifieryarrows/copper-mind-tft"
@dataclass(frozen=True)
class FeatureStoreConfig:
    """Settings for assembling the joined feature table."""

    # Prediction target ticker ("HG=F" — COMEX copper front contract symbol).
    target_symbol: str = "HG=F"
    # Cap on consecutive forward-filled rows for missing values.
    max_ffill: int = 3
    # Toggle calendar-derived features (exact set defined in the feature code).
    calendar_features: bool = True
    # Toggle macro-event features.
    macro_event_features: bool = True
@dataclass
class TFTASROConfig:
    """Top-level config aggregating all sub-configs.

    Mutable (not frozen) so callers can swap whole sub-configs; each
    sub-config instance itself is frozen.
    """

    embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
    sentiment: SentimentFeatureConfig = field(default_factory=SentimentFeatureConfig)
    lme: LMEConfig = field(default_factory=LMEConfig)
    model: TFTModelConfig = field(default_factory=TFTModelConfig)
    asro: ASROConfig = field(default_factory=ASROConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    feature_store: FeatureStoreConfig = field(default_factory=FeatureStoreConfig)

    @property
    def model_root(self) -> Path:
        # Parent directory of the checkpoint dir.
        # NOTE(review): with the paths set by get_tft_config() this resolves to
        # "<MODEL_DIR>/tft" (the parent of ".../tft/checkpoints"), not MODEL_DIR
        # itself — confirm callers expect the tft subdirectory.
        return Path(self.training.checkpoint_dir).parent
def get_tft_config() -> TFTASROConfig:
    """
    Return the default TFT-ASRO configuration with paths resolved from
    MODEL_DIR (``/data/models`` on HF Space, configurable locally).
    """
    tft_root = Path(_model_dir()) / "tft"
    # Only the path-bearing sub-configs need overrides; everything else
    # falls back to its dataclass defaults.
    embedding_cfg = EmbeddingConfig(
        pca_model_path=str(tft_root / "pca_finbert.joblib"),
    )
    training_cfg = TrainingConfig(
        checkpoint_dir=str(tft_root / "checkpoints"),
        best_model_path=str(tft_root / "best_tft_asro.ckpt"),
    )
    return TFTASROConfig(embedding=embedding_cfg, training=training_cfg)