""" StockPro TFT Training Script — Google Colab Edition ===================================================== Uses pytorch-forecasting's production TFT implementation. Usage in Colab: # 1. Install deps !pip install pytorch-forecasting pytorch-lightning huggingface_hub requests pandas numpy # 2. Mount Drive (optional, for checkpoint persistence) from google.colab import drive drive.mount('/content/drive') # 3. Run !python train_colab.py # Or set env vars for HF Hub upload: import os os.environ["HF_TOKEN"] = "hf_xxx" os.environ["HF_MODEL_REPO"] = "username/stockpro-tft" !python train_colab.py """ import os import sys import time import requests import numpy as np import pandas as pd from datetime import datetime, timedelta from typing import Optional # ── Install check ───────────────────────────────────────────────────────────── try: import pytorch_forecasting # noqa: F401 import lightning.pytorch as pl except ImportError: print("Installing pytorch-forecasting and lightning...") os.system("pip install -q pytorch-forecasting pytorch-lightning") import pytorch_forecasting # noqa: F401 import lightning.pytorch as pl import torch from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer from pytorch_forecasting.metrics import QuantileLoss from torch.utils.data import DataLoader # ── Config ──────────────────────────────────────────────────────────────────── ENCODER_LENGTH = 60 # lookback window PREDICTION_LENGTH = 30 # max forecast horizon BATCH_SIZE = 64 MAX_EPOCHS = 50 LR = 1e-3 GRADIENT_CLIP = 0.1 HIDDEN_SIZE = 64 # TFT hidden state ATTENTION_HEAD_SIZE = 4 DROPOUT = 0.1 HIDDEN_CONTINUOUS_SIZE = 16 # Save paths — adjust for Colab/Drive as needed MODEL_DIR = os.environ.get("MODEL_DIR", "/content/models") CHECKPOINT_DIR = os.path.join(MODEL_DIR, "checkpoints") FINAL_MODEL_PATH = os.path.join(MODEL_DIR, "tft_stock.ckpt") DATASET_PARAMS_PATH = os.path.join(MODEL_DIR, "dataset_params.pt") DDG_DA_PATH = os.path.join(MODEL_DIR, "ddg_da.pt") 
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ── Full IDX ticker list (956 stocks, source: IDX Daftar Saham 2026-03-08) ────
IDX_TICKERS = [
    "AADI", "AALI", "ABBA", "ABDA", "ABMM", "ACES", "ACRO", "ACST", "ADCP",
    "ADES", "ADHI", "ADMF", "ADMG", "ADMR", "ADRO", "AEGS", "AGAR", "AGII",
    "AGRO", "AGRS", "AHAP", "AIMS", "AISA", "AKKU", "AKPI", "AKRA", "AKSI",
    "ALDO", "ALII", "ALKA", "ALMI", "ALTO", "AMAG", "AMAN", "AMAR", "AMFG",
    "AMIN", "AMMN", "AMMS", "AMOR", "AMRT", "ANDI", "ANJT", "ANTM", "APEX",
    "APIC", "APII", "APLI", "APLN", "ARCI", "AREA", "ARGO", "ARII", "ARKA",
    "ARKO", "ARMY", "ARNA", "ARTA", "ARTI", "ARTO", "ASBI", "ASDM", "ASGR",
    "ASHA", "ASII", "ASJT", "ASLC", "ASLI", "ASMI", "ASPI", "ASPR", "ASRI",
    "ASRM", "ASSA", "ATAP", "ATIC", "ATLA", "AUTO", "AVIA", "AWAN", "AXIO",
    "AYAM", "AYLS", "BABP", "BABY", "BACA", "BAIK", "BAJA", "BALI", "BANK",
    "BAPA", "BAPI", "BATA", "BATR", "BAUT", "BAYU", "BBCA", "BBHI", "BBKP",
    "BBLD", "BBMD", "BBNI", "BBRI", "BBRM", "BBSI", "BBSS", "BBTN", "BBYB",
    "BCAP", "BCIC", "BCIP", "BDKR", "BDMN", "BEBS", "BEEF", "BEER", "BEKS",
    "BELI", "BELL", "BESS", "BEST", "BFIN", "BGTG", "BHAT", "BHIT", "BIKA",
    "BIKE", "BIMA", "BINA", "BINO", "BIPI", "BIPP", "BIRD", "BISI", "BJBR",
    "BJTM", "BKDP", "BKSL", "BKSW", "BLES", "BLOG", "BLTA", "BLTZ", "BLUE",
    "BMAS", "BMBL", "BMHS", "BMRI", "BMSR", "BMTR", "BNBA", "BNBR", "BNGA",
    "BNII", "BNLI", "BOAT", "BOBA", "BOGA", "BOLA", "BOLT", "BOSS", "BPFI",
    "BPII", "BPTR", "BRAM", "BREN", "BRIS", "BRMS", "BRNA", "BRPT", "BRRC",
    "BSBK", "BSDE", "BSIM", "BSML", "BSSR", "BSWD", "BTEK", "BTEL", "BTON",
    "BTPN", "BTPS", "BUAH", "BUDI", "BUKA", "BUKK", "BULL", "BUMI", "BUVA",
    "BVIC", "BWPT", "BYAN", "CAKK", "CAMP", "CANI", "CARE", "CARS", "CASA",
    "CASH", "CASS", "CBDK", "CBMF", "CBPE", "CBRE", "CBUT", "CCSI", "CDIA",
    "CEKA", "CENT", "CFIN", "CGAS", "CHEK", "CHEM", "CHIP", "CINT", "CITA",
    "CITY", "CLAY", "CLEO", "CLPI", "CMNP", "CMNT", "CMPP", "CMRY", "CNKO",
    "CNMA", "CNTX",
    "COAL", "COCO", "COIN", "COWL", "CPIN", "CPRI", "CPRO", "CRAB", "CRSN",
    "CSAP", "CSIS", "CSMI", "CSRA", "CTBN", "CTRA", "CTTH", "CUAN", "CYBR",
    "DAAZ", "DADA", "DART", "DATA", "DAYA", "DCII", "DEAL", "DEFI", "DEPO",
    "DEWA", "DEWI", "DFAM", "DGIK", "DGNS", "DGWG", "DIGI", "DILD", "DIVA",
    "DKFT", "DKHH", "DLTA", "DMAS", "DMMX", "DMND", "DNAR", "DNET", "DOID",
    "DOOH", "DOSS", "DPNS", "DPUM", "DRMA", "DSFI", "DSNG", "DSSA", "DUCK",
    "DUTI", "DVLA", "DWGL", "DYAN", "EAST", "ECII", "EDGE", "EKAD", "ELIT",
    "ELPI", "ELSA", "ELTY", "EMAS", "EMDE", "EMTK", "ENAK", "ENRG", "ENVY",
    "ENZO", "EPAC", "EPMT", "ERAA", "ERAL", "ERTX", "ESIP", "ESSA", "ESTA",
    "ESTI", "ETWA", "EURO", "EXCL", "FAPA", "FAST", "FASW", "FILM", "FIMP",
    "FIRE", "FISH", "FITT", "FLMC", "FMII", "FOLK", "FOOD", "FORE", "FORU",
    "FPNI", "FUJI", "FUTR", "FWCT", "GAMA", "GDST", "GDYR", "GEMA", "GEMS",
    "GGRM", "GGRP", "GHON", "GIAA", "GJTL", "GLOB", "GLVA", "GMFI", "GMTD",
    "GOLD", "GOLF", "GOLL", "GOOD", "GOTO", "GPRA", "GPSO", "GRIA", "GRPH",
    "GRPM", "GSMF", "GTBO", "GTRA", "GTSI", "GULA", "GUNA", "GWSA", "GZCO",
    "HADE", "HAIS", "HAJJ", "HALO", "HATM", "HBAT", "HDFA", "HDIT", "HEAL",
    "HELI", "HERO", "HEXA", "HGII", "HILL", "HITS", "HKMU", "HMSP", "HOKI",
    "HOME", "HOMI", "HOPE", "HOTL", "HRME", "HRTA", "HRUM", "HUMI", "HYGN",
    "IATA", "IBFN", "IBOS", "IBST", "ICBP", "ICON", "IDEA", "IDPR", "IFII",
    "IFSH", "IGAR", "IIKP", "IKAI", "IKAN", "IKBI", "IKPM", "IMAS", "IMJS",
    "IMPC", "INAF", "INAI", "INCF", "INCI", "INCO", "INDF", "INDO", "INDR",
    "INDS", "INDX", "INDY", "INET", "INKP", "INOV", "INPC", "INPP", "INPS",
    "INRU", "INTA", "INTD", "INTP", "IOTF", "IPAC", "IPCC", "IPCM", "IPOL",
    "IPPE", "IPTV", "IRRA", "IRSX", "ISAP", "ISAT", "ISEA", "ISSP", "ITIC",
    "ITMA", "ITMG", "JARR", "JAST", "JATI", "JAWA", "JAYA", "JECC", "JGLE",
    "JIHD", "JKON", "JMAS", "JPFA", "JRPT", "JSKY", "JSMR", "JSPT", "JTPE",
    "KAEF", "KAQI", "KARW", "KAYU", "KBAG", "KBLI", "KBLM", "KBLV", "KBRI",
    "KDSI", "KDTN", "KEEN", "KEJU", "KETR", "KIAS", "KICI",
    "KIJA", "KING", "KINO", "KIOS", "KJEN", "KKES", "KKGI", "KLAS", "KLBF",
    "KLIN", "KMDS", "KMTR", "KOBX", "KOCI", "KOIN", "KOKA", "KONI", "KOPI",
    "KOTA", "KPIG", "KRAS", "KREN", "KRYA", "KSIX", "KUAS", "LABA", "LABS",
    "LAJU", "LAND", "LAPD", "LCGP", "LCKM", "LEAD", "LFLO", "LIFE", "LINK",
    "LION", "LIVE", "LMAS", "LMAX", "LMPI", "LMSH", "LOPI", "LPCK", "LPGI",
    "LPIN", "LPKR", "LPLI", "LPPF", "LPPS", "LRNA", "LSIP", "LTLS", "LUCK",
    "LUCY", "MABA", "MAGP", "MAHA", "MAIN", "MANG", "MAPA", "MAPB", "MAPI",
    "MARI", "MARK", "MASB", "MAXI", "MAYA", "MBAP", "MBMA", "MBSS", "MBTO",
    "MCAS", "MCOL", "MCOR", "MDIA", "MDIY", "MDKA", "MDKI", "MDLA", "MDLN",
    "MDRN", "MEDC", "MEDS", "MEGA", "MEJA", "MENN", "MERI", "MERK", "META",
    "MFMI", "MGLV", "MGNA", "MGRO", "MHKI", "MICE", "MIDI", "MIKA", "MINA",
    "MINE", "MIRA", "MITI", "MKAP", "MKNT", "MKPI", "MKTR", "MLBI", "MLIA",
    "MLPL", "MLPT", "MMIX", "MMLP", "MNCN", "MOLI", "MORA", "MPIX", "MPMX",
    "MPOW", "MPPA", "MPRO", "MPXL", "MRAT", "MREI", "MSIE", "MSIN", "MSJA",
    "MSKY", "MSTI", "MTDL", "MTEL", "MTFN", "MTLA", "MTMH", "MTPS", "MTRA",
    "MTSM", "MTWI", "MUTU", "MYOH", "MYOR", "MYTX", "NAIK", "NANO", "NASA",
    "NASI", "NATO", "NAYZ", "NCKL", "NELY", "NEST", "NETV", "NFCX", "NICE",
    "NICK", "NICL", "NIKL", "NINE", "NIRO", "NISP", "NOBU", "NPGF", "NRCA",
    "NSSS", "NTBK", "NUSA", "NZIA", "OASA", "OBAT", "OBMD", "OCAP", "OILS",
    "OKAS", "OLIV", "OMED", "OMRE", "OPMS", "PACK", "PADA", "PADI", "PALM",
    "PAMG", "PANI", "PANR", "PANS", "PART", "PBID", "PBRX", "PBSA", "PCAR",
    "PDES", "PDPP", "PEGE", "PEHA", "PEVE", "PGAS", "PGEO", "PGJO", "PGLI",
    "PGUN", "PICO", "PIPA", "PJAA", "PJHB", "PKPK", "PLAN", "PLAS", "PLIN",
    "PMJS", "PMMP", "PMUI", "PNBN", "PNBS", "PNGO", "PNIN", "PNLF", "PNSE",
    "POLA", "POLI", "POLL", "POLU", "POLY", "POOL", "PORT", "POSA", "POWR",
    "PPGL", "PPRE", "PPRI", "PPRO", "PRAY", "PRDA", "PRIM", "PSAB", "PSAT",
    "PSDN", "PSGO", "PSKT", "PSSI", "PTBA", "PTDU", "PTIS", "PTMP", "PTMR",
    "PTPP", "PTPS", "PTPW", "PTRO", "PTSN", "PTSP", "PUDP",
    "PURA", "PURE", "PURI", "PWON", "PYFA", "PZZA", "RAAM", "RAFI", "RAJA",
    "RALS", "RANC", "RATU", "RBMS", "RCCC", "RDTX", "REAL", "RELF", "RELI",
    "RGAS", "RICY", "RIGS", "RIMO", "RISE", "RLCO", "RMKE", "RMKO", "ROCK",
    "RODA", "RONY", "ROTI", "RSCH", "RSGK", "RUIS", "RUNS", "SAFE", "SAGE",
    "SAME", "SAMF", "SAPX", "SATU", "SBAT", "SBMA", "SCCO", "SCMA", "SCNP",
    "SCPI", "SDMU", "SDPC", "SDRA", "SEMA", "SFAN", "SGER", "SGRO", "SHID",
    "SHIP", "SICO", "SIDO", "SILO", "SIMA", "SIMP", "SINI", "SIPD", "SKBM",
    "SKLT", "SKRN", "SKYB", "SLIS", "SMAR", "SMBR", "SMCB", "SMDM", "SMDR",
    "SMGA", "SMGR", "SMIL", "SMKL", "SMKM", "SMLE", "SMMA", "SMMT", "SMRA",
    "SMRU", "SMSM", "SNLK", "SOCI", "SOFA", "SOHO", "SOLA", "SONA", "SOSS",
    "SOTS", "SOUL", "SPMA", "SPRE", "SPTO", "SQMI", "SRAJ", "SRIL", "SRSN",
    "SRTG", "SSIA", "SSMS", "SSTM", "STAA", "STAR", "STRK", "STTP", "SUGI",
    "SULI", "SUNI", "SUPA", "SUPR", "SURE", "SURI", "SWAT", "SWID", "TALF",
    "TAMA", "TAMU", "TAPG", "TARA", "TAXI", "TAYS", "TBIG", "TBLA", "TBMS",
    "TCID", "TCPI", "TDPM", "TEBE", "TECH", "TELE", "TFAS", "TFCO", "TGKA",
    "TGRA", "TGUK", "TIFA", "TINS", "TIRA", "TIRT", "TKIM", "TLDN", "TLKM",
    "TMAS", "TMPO", "TNCA", "TOBA", "TOOL", "TOPS", "TOSK", "TOTL", "TOTO",
    "TOWR", "TOYS", "TPIA", "TPMA", "TRAM", "TRGU", "TRIL", "TRIM", "TRIN",
    "TRIO", "TRIS", "TRJA", "TRON", "TRST", "TRUE", "TRUK", "TRUS", "TSPC",
    "TUGU", "TYRE", "UANG", "UCID", "UDNG", "UFOE", "ULTJ", "UNIC", "UNIQ",
    "UNIT", "UNSP", "UNTD", "UNTR", "UNVR", "URBN", "UVCR", "VAST", "VERN",
    "VICI", "VICO", "VINS", "VISI", "VIVA", "VKTR", "VOKS", "VRNA", "VTNY",
    "WAPO", "WEGE", "WEHA", "WGSH", "WICO", "WIDI", "WIFI", "WIIM", "WIKA",
    "WINE", "WINR", "WINS", "WIRG", "WMPP", "WMUU", "WOMF", "WOOD", "WOWS",
    "WSBP", "WSKT", "WTON", "YELO", "YOII", "YPAS", "YULE", "YUPI", "ZATA",
    "ZBRA", "ZINC", "ZONE", "ZYRX",
]

INDOPREMIER_URL = "https://www.indopremier.com/module/saham/include/json-charting.php"

# Shared HTTP session with browser-like headers; reused across all fetches.
_session = requests.Session()
_session.headers.update({
    "User-Agent":
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Chrome/124.0.0.0 Safari/537.36", "Accept": "application/json, text/plain, */*", "Referer": "https://www.indopremier.com/", }) # ── Data fetching ───────────────────────────────────────────────────────────── def fetch_ohlcv(symbol: str, years: int = 5) -> Optional[pd.DataFrame]: """Fetch OHLCV from IndoPremier, return DataFrame with date/open/high/low/close/volume.""" end = datetime.now() start = end - timedelta(days=years * 365) fmt = lambda d: d.strftime("%m/%d/%Y") try: resp = _session.get( INDOPREMIER_URL, params={"code": symbol, "start": fmt(start), "end": fmt(end)}, timeout=15, ) resp.raise_for_status() raw = resp.json() if not isinstance(raw, list) or len(raw) < 60: return None df = pd.DataFrame(raw, columns=["timestamp_ms", "open", "high", "low", "close", "volume"]) df["date"] = pd.to_datetime(df["timestamp_ms"], unit="ms", utc=True).dt.tz_localize(None) df = df[df["close"] > 0].drop(columns=["timestamp_ms"]) df = df.sort_values("date").reset_index(drop=True) return df if len(df) >= 60 else None except Exception as e: return None # ── Feature engineering ─────────────────────────────────────────────────────── def add_features(df: pd.DataFrame) -> pd.DataFrame: """Add normalized technical indicators as columns.""" c = df["close"].values.astype(np.float64) v = df["volume"].values.astype(np.float64) # Rolling 30-day z-score normalization for price s = pd.Series(c) rm = s.rolling(30, min_periods=1).mean() rs = s.rolling(30, min_periods=1).std(ddof=0).fillna(1).clip(lower=1e-6) df["close_norm"] = ((s - rm) / rs).values # Volume norm sv = pd.Series(v) vm = sv.rolling(30, min_periods=1).mean() vs = sv.rolling(30, min_periods=1).std(ddof=0).fillna(1).clip(lower=1e-6) df["volume_norm"] = ((sv - vm) / vs).values # RSI (normalized 0-1) delta = np.diff(c, prepend=c[0]) gain = pd.Series(np.where(delta > 0, delta, 0.0)).ewm(alpha=1/14, adjust=False).mean().values loss = pd.Series(np.where(delta < 0, -delta, 
0.0)).ewm(alpha=1/14, adjust=False).mean().values rs_ratio = np.where(loss == 0, 100.0, gain / (loss + 1e-9)) df["rsi"] = np.clip(rs_ratio / (1 + rs_ratio), 0, 1) # MACD histogram (normalized) sp = pd.Series(c) macd_line = sp.ewm(span=12, adjust=False).mean() - sp.ewm(span=26, adjust=False).mean() signal = macd_line.ewm(span=9, adjust=False).mean() macd_raw = (macd_line - signal).values df["macd_norm"] = macd_raw / (np.std(macd_raw) or 1) # Bollinger band width sma = sp.rolling(20, min_periods=1).mean() std20 = sp.rolling(20, min_periods=1).std(ddof=0).fillna(0) df["bb_width"] = (2 * std20 / sma.clip(lower=1e-6)).fillna(0).values # ATR normalized h, l_col = df["high"].values, df["low"].values prev_c = np.roll(c, 1); prev_c[0] = c[0] tr = np.maximum(h - l_col, np.maximum(np.abs(h - prev_c), np.abs(l_col - prev_c))) atr = pd.Series(tr).ewm(alpha=1/14, adjust=False).mean().values df["atr_norm"] = atr / (c + 1e-9) # OBV normalized direction = np.sign(np.diff(c, prepend=c[0])) obv = np.cumsum(direction * v) obv_std = np.std(obv) or 1 df["obv_norm"] = (obv - np.mean(obv)) / obv_std # Cyclical encodings df["day_sin"] = np.sin(2 * np.pi * df["date"].dt.dayofweek / 5) df["day_cos"] = np.cos(2 * np.pi * df["date"].dt.dayofweek / 5) df["month_sin"] = np.sin(2 * np.pi * df["date"].dt.month / 12) df["month_cos"] = np.cos(2 * np.pi * df["date"].dt.month / 12) return df.dropna().reset_index(drop=True) # ── Dataset assembly ────────────────────────────────────────────────────────── KNOWN_REALS = ["day_sin", "day_cos", "month_sin", "month_cos"] UNKNOWN_REALS = ["close_norm", "volume_norm", "rsi", "macd_norm", "bb_width", "atr_norm", "obv_norm"] TARGET = "close_norm" def collect_data( tickers: list[str], years: int = 5, delay: float = 0.2, max_failures: int = 400, ) -> pd.DataFrame: """Fetch and concatenate all tickers into a single DataFrame for TimeSeriesDataSet.""" all_dfs = [] failures = 0 print(f"Fetching {len(tickers)} tickers from IndoPremier ({years}y history)...") for i, 
ticker in enumerate(tickers): df = fetch_ohlcv(ticker, years=years) if df is None: failures += 1 if (i + 1) % 50 == 0: print(f" [{i+1}/{len(tickers)}] {len(all_dfs)} valid, {failures} failed") if failures >= max_failures: print(f" Too many failures ({failures}), stopping early.") break time.sleep(delay) continue df = add_features(df) df["ticker"] = ticker df["time_idx"] = np.arange(len(df)) all_dfs.append(df[["ticker", "time_idx", "date", TARGET] + KNOWN_REALS + UNKNOWN_REALS]) if (i + 1) % 50 == 0: print(f" [{i+1}/{len(tickers)}] {len(all_dfs)} valid, {failures} failed") time.sleep(delay) if not all_dfs: raise RuntimeError("No data collected from IndoPremier.") combined = pd.concat(all_dfs, ignore_index=True) print(f"\nTotal rows: {len(combined):,} across {len(all_dfs)} tickers") return combined # ── pytorch-forecasting DataSet + Model ─────────────────────────────────────── def build_dataset(df: pd.DataFrame, predict: bool = False) -> TimeSeriesDataSet: """Build pytorch-forecasting TimeSeriesDataSet.""" # Use last PREDICTION_LENGTH rows of each ticker for validation training_cutoff = df.groupby("ticker")["time_idx"].transform( lambda x: x.max() - PREDICTION_LENGTH ) subset = df[df["time_idx"] <= training_cutoff] if not predict else df return TimeSeriesDataSet( subset, time_idx="time_idx", target=TARGET, group_ids=["ticker"], min_encoder_length=ENCODER_LENGTH // 2, max_encoder_length=ENCODER_LENGTH, min_prediction_length=1, max_prediction_length=PREDICTION_LENGTH, static_categoricals=["ticker"], time_varying_known_reals=KNOWN_REALS + ["time_idx"], time_varying_unknown_reals=UNKNOWN_REALS, target_normalizer=None, # already normalized add_relative_time_idx=True, add_target_scales=True, add_encoder_length=True, allow_missing_timesteps=True, ) def build_tft(training_dataset: TimeSeriesDataSet) -> TemporalFusionTransformer: return TemporalFusionTransformer.from_dataset( training_dataset, learning_rate=LR, hidden_size=HIDDEN_SIZE, 
        attention_head_size=ATTENTION_HEAD_SIZE,
        dropout=DROPOUT,
        hidden_continuous_size=HIDDEN_CONTINUOUS_SIZE,
        loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),  # P10/P50/P90 forecast bands
        log_interval=10,
        reduce_on_plateau_patience=4,
        optimizer="adam",
    )


# ── DDG-DA drift predictor training (reused from train.py logic) ──────────────
def train_ddg_da_predictor(df: pd.DataFrame) -> None:
    """Train the lightweight MLP drift predictor for DDG-DA."""
    import torch.nn as nn
    from torch.utils.data import TensorDataset, DataLoader

    SNAPSHOT_WINDOW = 20  # trading days summarized per distribution snapshot
    K_HISTORY = 8         # consecutive snapshots fed to the MLP as input
    N_FEATURES = len(UNKNOWN_REALS) + len(KNOWN_REALS)  # 11
    SNAPSHOT_DIM = N_FEATURES * 4  # 44 — four moments per feature

    def compute_snapshot(arr: np.ndarray) -> np.ndarray:
        # Summarize one window: per-feature mean, std, skewness, excess kurtosis.
        stats = []
        for f in range(arr.shape[1]):
            col = arr[:, f].astype(np.float64)
            mu = col.mean(); sigma = col.std() or 1e-8  # floor std to avoid /0
            skew = float(((col - mu) ** 3).mean() / sigma ** 3)
            kurt = float(((col - mu) ** 4).mean() / sigma ** 4) - 3.0
            stats.extend([float(mu), float(sigma), skew, kurt])
        return np.array(stats, dtype=np.float32)

    feature_cols = UNKNOWN_REALS + KNOWN_REALS
    all_X, all_y = [], []
    for ticker, grp in df.groupby("ticker"):
        feat = grp[feature_cols].values.astype(np.float32)
        k = len(feat) // SNAPSHOT_WINDOW
        # Need at least K_HISTORY inputs plus one target snapshot.
        if k < K_HISTORY + 1:
            continue
        snaps = np.stack([compute_snapshot(feat[i*SNAPSHOT_WINDOW:(i+1)*SNAPSHOT_WINDOW]) for i in range(k)])
        # Supervised pairs: K_HISTORY consecutive snapshots -> the next snapshot.
        for i in range(len(snaps) - K_HISTORY):
            all_X.append(snaps[i:i+K_HISTORY].flatten())
            all_y.append(snaps[i+K_HISTORY])
    if not all_X:
        print("Not enough data for DDG-DA training, skipping.")
        return
    X = np.array(all_X, dtype=np.float32)
    y = np.array(all_y, dtype=np.float32)
    # Shuffle pairs before the 90/10 train/val split.
    idx = np.random.permutation(len(X))
    X, y = X[idx], y[idx]
    split = int(len(X) * 0.9)
    train_ds = TensorDataset(torch.tensor(X[:split]), torch.tensor(y[:split]))
    val_ds = TensorDataset(torch.tensor(X[split:]), torch.tensor(y[split:]))
    train_dl = DataLoader(train_ds, batch_size=256, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=256)
    mlp = nn.Sequential(
        nn.Linear(K_HISTORY * SNAPSHOT_DIM, 128),
        nn.ELU(),
nn.Dropout(0.1), nn.Linear(128, SNAPSHOT_DIM), ) opt = torch.optim.Adam(mlp.parameters(), lr=1e-3) crit = nn.MSELoss() best_val = float("inf") print(f"\n--- Training DDG-DA MLP ({len(X):,} pairs) ---") for epoch in range(1, 31): mlp.train() tl = sum(crit(mlp(xb), yb).item() * len(xb) for xb, yb in train_dl) / len(X[:split]) mlp.eval() with torch.no_grad(): vl = sum(crit(mlp(xb), yb).item() * len(xb) for xb, yb in val_dl) / len(X[split:]) print(f" Epoch {epoch:2d}/30 train={tl:.6f} val={vl:.6f}") if vl < best_val: best_val = vl torch.save(mlp.state_dict(), DDG_DA_PATH) print(f"DDG-DA saved → {DDG_DA_PATH}") # ── Upload to HF Hub ────────────────────────────────────────────────────────── def upload_to_hub(checkpoint_path: str) -> None: hf_token = os.environ.get("HF_TOKEN", "") hf_repo = os.environ.get("HF_MODEL_REPO", "") if not hf_token or not hf_repo: print("HF_TOKEN / HF_MODEL_REPO not set — skipping upload.") return from huggingface_hub import HfApi api = HfApi(token=hf_token) api.create_repo(repo_id=hf_repo, repo_type="model", exist_ok=True, private=True) for local, remote in [ (checkpoint_path, "tft_stock.ckpt"), (DATASET_PARAMS_PATH, "dataset_params.pt"), (DDG_DA_PATH, "ddg_da.pt"), ]: if os.path.exists(local): api.upload_file(path_or_fileobj=local, path_in_repo=remote, repo_id=hf_repo, repo_type="model", commit_message=f"Colab retrain: {remote}") print(f"Uploaded {remote} → {hf_repo}") # ── Main ────────────────────────────────────────────────────────────────────── def main(): print("=" * 60) print("StockPro TFT Training (pytorch-forecasting)") print("=" * 60) # 1. Collect data df = collect_data(IDX_TICKERS, years=5) df.to_parquet(os.path.join(MODEL_DIR, "training_data.parquet"), index=False) print(f"Data saved → {MODEL_DIR}/training_data.parquet") # 2. Build datasets training_ds = build_dataset(df, predict=False) # Save dataset parameters (fitted categorical encoders, scalers, config) # These are required by the HF Spaces inference server at runtime. 
torch.save(training_ds.get_parameters(), DATASET_PARAMS_PATH) print(f"Dataset params saved → {DATASET_PARAMS_PATH}") val_ds = training_ds.from_dataset(training_ds, df, predict=True, stop_randomization=True) train_dl = training_ds.to_dataloader(train=True, batch_size=BATCH_SIZE, num_workers=2) val_dl = val_ds.to_dataloader(train=False, batch_size=BATCH_SIZE * 2, num_workers=2) # 3. Build TFT model tft = build_tft(training_ds) print(f"\nTFT parameters: {sum(p.numel() for p in tft.parameters()):,}") # 4. Train with Lightning from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint callbacks = [ EarlyStopping(monitor="val_loss", patience=5, mode="min"), ModelCheckpoint( dirpath=CHECKPOINT_DIR, filename="tft-{epoch:02d}-{val_loss:.4f}", monitor="val_loss", mode="min", save_top_k=1, ), ] trainer = pl.Trainer( max_epochs=MAX_EPOCHS, accelerator="auto", # GPU if available gradient_clip_val=GRADIENT_CLIP, callbacks=callbacks, enable_progress_bar=True, log_every_n_steps=10, ) print("\n--- Training TFT ---") trainer.fit(tft, train_dataloaders=train_dl, val_dataloaders=val_dl) # 5. Copy best checkpoint best_ckpt = callbacks[1].best_model_path if best_ckpt and os.path.exists(best_ckpt): import shutil shutil.copy2(best_ckpt, FINAL_MODEL_PATH) print(f"\nBest model → {FINAL_MODEL_PATH}") # 6. Train DDG-DA train_ddg_da_predictor(df) # 7. Upload to HF Hub upload_to_hub(FINAL_MODEL_PATH) print("\nDone!") if __name__ == "__main__": main()