| """ | |
| StockPro TFT Training Script β Google Colab Edition | |
| ===================================================== | |
| Uses pytorch-forecasting's production TFT implementation. | |
| Usage in Colab: | |
| # 1. Install deps | |
| !pip install pytorch-forecasting pytorch-lightning huggingface_hub requests pandas numpy | |
| # 2. Mount Drive (optional, for checkpoint persistence) | |
| from google.colab import drive | |
| drive.mount('/content/drive') | |
| # 3. Run | |
| !python train_colab.py | |
| # Or set env vars for HF Hub upload: | |
| import os | |
| os.environ["HF_TOKEN"] = "hf_xxx" | |
| os.environ["HF_MODEL_REPO"] = "username/stockpro-tft" | |
| !python train_colab.py | |
| """ | |
import os
import time
import requests
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import Optional

# ── Install check ──────────────────────────────────────────────────────────────
try:
    import pytorch_forecasting  # noqa: F401
    import lightning.pytorch as pl
except ImportError:
    print("Installing pytorch-forecasting and lightning...")
    # The `lightning` package provides the `lightning.pytorch` namespace imported above.
    os.system("pip install -q pytorch-forecasting lightning")
    import pytorch_forecasting  # noqa: F401
    import lightning.pytorch as pl

import torch
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# ── Config ─────────────────────────────────────────────────────────────────────
ENCODER_LENGTH = 60            # lookback window
PREDICTION_LENGTH = 30         # max forecast horizon
BATCH_SIZE = 64
MAX_EPOCHS = 50
LR = 1e-3
GRADIENT_CLIP = 0.1
HIDDEN_SIZE = 64               # TFT hidden state
ATTENTION_HEAD_SIZE = 4
DROPOUT = 0.1
HIDDEN_CONTINUOUS_SIZE = 16

# Save paths - adjust for Colab/Drive as needed
MODEL_DIR = os.environ.get("MODEL_DIR", "/content/models")
CHECKPOINT_DIR = os.path.join(MODEL_DIR, "checkpoints")
FINAL_MODEL_PATH = os.path.join(MODEL_DIR, "tft_stock.ckpt")
DATASET_PARAMS_PATH = os.path.join(MODEL_DIR, "dataset_params.pt")
DDG_DA_PATH = os.path.join(MODEL_DIR, "ddg_da.pt")
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
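
# Sketch (optional): to persist artifacts across Colab sessions, point MODEL_DIR
# at a mounted Drive folder before this script runs, e.g.
#   os.environ["MODEL_DIR"] = "/content/drive/MyDrive/stockpro/models"
# The path above is illustrative, not part of the script's defaults.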

# ── Full IDX ticker list (956 stocks, source: IDX Daftar Saham 2026-03-08) ─────
IDX_TICKERS = [
    "AADI", "AALI", "ABBA", "ABDA", "ABMM", "ACES", "ACRO", "ACST", "ADCP", "ADES",
    "ADHI", "ADMF", "ADMG", "ADMR", "ADRO", "AEGS", "AGAR", "AGII", "AGRO", "AGRS",
    "AHAP", "AIMS", "AISA", "AKKU", "AKPI", "AKRA", "AKSI", "ALDO", "ALII", "ALKA",
    "ALMI", "ALTO", "AMAG", "AMAN", "AMAR", "AMFG", "AMIN", "AMMN", "AMMS", "AMOR",
    "AMRT", "ANDI", "ANJT", "ANTM", "APEX", "APIC", "APII", "APLI", "APLN", "ARCI",
    "AREA", "ARGO", "ARII", "ARKA", "ARKO", "ARMY", "ARNA", "ARTA", "ARTI", "ARTO",
    "ASBI", "ASDM", "ASGR", "ASHA", "ASII", "ASJT", "ASLC", "ASLI", "ASMI", "ASPI",
    "ASPR", "ASRI", "ASRM", "ASSA", "ATAP", "ATIC", "ATLA", "AUTO", "AVIA", "AWAN",
    "AXIO", "AYAM", "AYLS", "BABP", "BABY", "BACA", "BAIK", "BAJA", "BALI", "BANK",
    "BAPA", "BAPI", "BATA", "BATR", "BAUT", "BAYU", "BBCA", "BBHI", "BBKP", "BBLD",
    "BBMD", "BBNI", "BBRI", "BBRM", "BBSI", "BBSS", "BBTN", "BBYB", "BCAP", "BCIC",
    "BCIP", "BDKR", "BDMN", "BEBS", "BEEF", "BEER", "BEKS", "BELI", "BELL", "BESS",
    "BEST", "BFIN", "BGTG", "BHAT", "BHIT", "BIKA", "BIKE", "BIMA", "BINA", "BINO",
    "BIPI", "BIPP", "BIRD", "BISI", "BJBR", "BJTM", "BKDP", "BKSL", "BKSW", "BLES",
    "BLOG", "BLTA", "BLTZ", "BLUE", "BMAS", "BMBL", "BMHS", "BMRI", "BMSR", "BMTR",
    "BNBA", "BNBR", "BNGA", "BNII", "BNLI", "BOAT", "BOBA", "BOGA", "BOLA", "BOLT",
    "BOSS", "BPFI", "BPII", "BPTR", "BRAM", "BREN", "BRIS", "BRMS", "BRNA", "BRPT",
    "BRRC", "BSBK", "BSDE", "BSIM", "BSML", "BSSR", "BSWD", "BTEK", "BTEL", "BTON",
    "BTPN", "BTPS", "BUAH", "BUDI", "BUKA", "BUKK", "BULL", "BUMI", "BUVA", "BVIC",
    "BWPT", "BYAN", "CAKK", "CAMP", "CANI", "CARE", "CARS", "CASA", "CASH", "CASS",
    "CBDK", "CBMF", "CBPE", "CBRE", "CBUT", "CCSI", "CDIA", "CEKA", "CENT", "CFIN",
    "CGAS", "CHEK", "CHEM", "CHIP", "CINT", "CITA", "CITY", "CLAY", "CLEO", "CLPI",
    "CMNP", "CMNT", "CMPP", "CMRY", "CNKO", "CNMA", "CNTX", "COAL", "COCO", "COIN",
    "COWL", "CPIN", "CPRI", "CPRO", "CRAB", "CRSN", "CSAP", "CSIS", "CSMI", "CSRA",
    "CTBN", "CTRA", "CTTH", "CUAN", "CYBR", "DAAZ", "DADA", "DART", "DATA", "DAYA",
    "DCII", "DEAL", "DEFI", "DEPO", "DEWA", "DEWI", "DFAM", "DGIK", "DGNS", "DGWG",
    "DIGI", "DILD", "DIVA", "DKFT", "DKHH", "DLTA", "DMAS", "DMMX", "DMND", "DNAR",
    "DNET", "DOID", "DOOH", "DOSS", "DPNS", "DPUM", "DRMA", "DSFI", "DSNG", "DSSA",
    "DUCK", "DUTI", "DVLA", "DWGL", "DYAN", "EAST", "ECII", "EDGE", "EKAD", "ELIT",
    "ELPI", "ELSA", "ELTY", "EMAS", "EMDE", "EMTK", "ENAK", "ENRG", "ENVY", "ENZO",
    "EPAC", "EPMT", "ERAA", "ERAL", "ERTX", "ESIP", "ESSA", "ESTA", "ESTI", "ETWA",
    "EURO", "EXCL", "FAPA", "FAST", "FASW", "FILM", "FIMP", "FIRE", "FISH", "FITT",
    "FLMC", "FMII", "FOLK", "FOOD", "FORE", "FORU", "FPNI", "FUJI", "FUTR", "FWCT",
    "GAMA", "GDST", "GDYR", "GEMA", "GEMS", "GGRM", "GGRP", "GHON", "GIAA", "GJTL",
    "GLOB", "GLVA", "GMFI", "GMTD", "GOLD", "GOLF", "GOLL", "GOOD", "GOTO", "GPRA",
    "GPSO", "GRIA", "GRPH", "GRPM", "GSMF", "GTBO", "GTRA", "GTSI", "GULA", "GUNA",
    "GWSA", "GZCO", "HADE", "HAIS", "HAJJ", "HALO", "HATM", "HBAT", "HDFA", "HDIT",
    "HEAL", "HELI", "HERO", "HEXA", "HGII", "HILL", "HITS", "HKMU", "HMSP", "HOKI",
    "HOME", "HOMI", "HOPE", "HOTL", "HRME", "HRTA", "HRUM", "HUMI", "HYGN", "IATA",
    "IBFN", "IBOS", "IBST", "ICBP", "ICON", "IDEA", "IDPR", "IFII", "IFSH", "IGAR",
    "IIKP", "IKAI", "IKAN", "IKBI", "IKPM", "IMAS", "IMJS", "IMPC", "INAF", "INAI",
    "INCF", "INCI", "INCO", "INDF", "INDO", "INDR", "INDS", "INDX", "INDY", "INET",
    "INKP", "INOV", "INPC", "INPP", "INPS", "INRU", "INTA", "INTD", "INTP", "IOTF",
    "IPAC", "IPCC", "IPCM", "IPOL", "IPPE", "IPTV", "IRRA", "IRSX", "ISAP", "ISAT",
    "ISEA", "ISSP", "ITIC", "ITMA", "ITMG", "JARR", "JAST", "JATI", "JAWA", "JAYA",
    "JECC", "JGLE", "JIHD", "JKON", "JMAS", "JPFA", "JRPT", "JSKY", "JSMR", "JSPT",
    "JTPE", "KAEF", "KAQI", "KARW", "KAYU", "KBAG", "KBLI", "KBLM", "KBLV", "KBRI",
    "KDSI", "KDTN", "KEEN", "KEJU", "KETR", "KIAS", "KICI", "KIJA", "KING", "KINO",
    "KIOS", "KJEN", "KKES", "KKGI", "KLAS", "KLBF", "KLIN", "KMDS", "KMTR", "KOBX",
    "KOCI", "KOIN", "KOKA", "KONI", "KOPI", "KOTA", "KPIG", "KRAS", "KREN", "KRYA",
    "KSIX", "KUAS", "LABA", "LABS", "LAJU", "LAND", "LAPD", "LCGP", "LCKM", "LEAD",
    "LFLO", "LIFE", "LINK", "LION", "LIVE", "LMAS", "LMAX", "LMPI", "LMSH", "LOPI",
    "LPCK", "LPGI", "LPIN", "LPKR", "LPLI", "LPPF", "LPPS", "LRNA", "LSIP", "LTLS",
    "LUCK", "LUCY", "MABA", "MAGP", "MAHA", "MAIN", "MANG", "MAPA", "MAPB", "MAPI",
    "MARI", "MARK", "MASB", "MAXI", "MAYA", "MBAP", "MBMA", "MBSS", "MBTO", "MCAS",
    "MCOL", "MCOR", "MDIA", "MDIY", "MDKA", "MDKI", "MDLA", "MDLN", "MDRN", "MEDC",
    "MEDS", "MEGA", "MEJA", "MENN", "MERI", "MERK", "META", "MFMI", "MGLV", "MGNA",
    "MGRO", "MHKI", "MICE", "MIDI", "MIKA", "MINA", "MINE", "MIRA", "MITI", "MKAP",
    "MKNT", "MKPI", "MKTR", "MLBI", "MLIA", "MLPL", "MLPT", "MMIX", "MMLP", "MNCN",
    "MOLI", "MORA", "MPIX", "MPMX", "MPOW", "MPPA", "MPRO", "MPXL", "MRAT", "MREI",
    "MSIE", "MSIN", "MSJA", "MSKY", "MSTI", "MTDL", "MTEL", "MTFN", "MTLA", "MTMH",
    "MTPS", "MTRA", "MTSM", "MTWI", "MUTU", "MYOH", "MYOR", "MYTX", "NAIK", "NANO",
    "NASA", "NASI", "NATO", "NAYZ", "NCKL", "NELY", "NEST", "NETV", "NFCX", "NICE",
    "NICK", "NICL", "NIKL", "NINE", "NIRO", "NISP", "NOBU", "NPGF", "NRCA", "NSSS",
    "NTBK", "NUSA", "NZIA", "OASA", "OBAT", "OBMD", "OCAP", "OILS", "OKAS", "OLIV",
    "OMED", "OMRE", "OPMS", "PACK", "PADA", "PADI", "PALM", "PAMG", "PANI", "PANR",
    "PANS", "PART", "PBID", "PBRX", "PBSA", "PCAR", "PDES", "PDPP", "PEGE", "PEHA",
    "PEVE", "PGAS", "PGEO", "PGJO", "PGLI", "PGUN", "PICO", "PIPA", "PJAA", "PJHB",
    "PKPK", "PLAN", "PLAS", "PLIN", "PMJS", "PMMP", "PMUI", "PNBN", "PNBS", "PNGO",
    "PNIN", "PNLF", "PNSE", "POLA", "POLI", "POLL", "POLU", "POLY", "POOL", "PORT",
    "POSA", "POWR", "PPGL", "PPRE", "PPRI", "PPRO", "PRAY", "PRDA", "PRIM", "PSAB",
    "PSAT", "PSDN", "PSGO", "PSKT", "PSSI", "PTBA", "PTDU", "PTIS", "PTMP", "PTMR",
    "PTPP", "PTPS", "PTPW", "PTRO", "PTSN", "PTSP", "PUDP", "PURA", "PURE", "PURI",
    "PWON", "PYFA", "PZZA", "RAAM", "RAFI", "RAJA", "RALS", "RANC", "RATU", "RBMS",
    "RCCC", "RDTX", "REAL", "RELF", "RELI", "RGAS", "RICY", "RIGS", "RIMO", "RISE",
    "RLCO", "RMKE", "RMKO", "ROCK", "RODA", "RONY", "ROTI", "RSCH", "RSGK", "RUIS",
    "RUNS", "SAFE", "SAGE", "SAME", "SAMF", "SAPX", "SATU", "SBAT", "SBMA", "SCCO",
    "SCMA", "SCNP", "SCPI", "SDMU", "SDPC", "SDRA", "SEMA", "SFAN", "SGER", "SGRO",
    "SHID", "SHIP", "SICO", "SIDO", "SILO", "SIMA", "SIMP", "SINI", "SIPD", "SKBM",
    "SKLT", "SKRN", "SKYB", "SLIS", "SMAR", "SMBR", "SMCB", "SMDM", "SMDR", "SMGA",
    "SMGR", "SMIL", "SMKL", "SMKM", "SMLE", "SMMA", "SMMT", "SMRA", "SMRU", "SMSM",
    "SNLK", "SOCI", "SOFA", "SOHO", "SOLA", "SONA", "SOSS", "SOTS", "SOUL", "SPMA",
    "SPRE", "SPTO", "SQMI", "SRAJ", "SRIL", "SRSN", "SRTG", "SSIA", "SSMS", "SSTM",
    "STAA", "STAR", "STRK", "STTP", "SUGI", "SULI", "SUNI", "SUPA", "SUPR", "SURE",
    "SURI", "SWAT", "SWID", "TALF", "TAMA", "TAMU", "TAPG", "TARA", "TAXI", "TAYS",
    "TBIG", "TBLA", "TBMS", "TCID", "TCPI", "TDPM", "TEBE", "TECH", "TELE", "TFAS",
    "TFCO", "TGKA", "TGRA", "TGUK", "TIFA", "TINS", "TIRA", "TIRT", "TKIM", "TLDN",
    "TLKM", "TMAS", "TMPO", "TNCA", "TOBA", "TOOL", "TOPS", "TOSK", "TOTL", "TOTO",
    "TOWR", "TOYS", "TPIA", "TPMA", "TRAM", "TRGU", "TRIL", "TRIM", "TRIN", "TRIO",
    "TRIS", "TRJA", "TRON", "TRST", "TRUE", "TRUK", "TRUS", "TSPC", "TUGU", "TYRE",
    "UANG", "UCID", "UDNG", "UFOE", "ULTJ", "UNIC", "UNIQ", "UNIT", "UNSP", "UNTD",
    "UNTR", "UNVR", "URBN", "UVCR", "VAST", "VERN", "VICI", "VICO", "VINS", "VISI",
    "VIVA", "VKTR", "VOKS", "VRNA", "VTNY", "WAPO", "WEGE", "WEHA", "WGSH", "WICO",
    "WIDI", "WIFI", "WIIM", "WIKA", "WINE", "WINR", "WINS", "WIRG", "WMPP", "WMUU",
    "WOMF", "WOOD", "WOWS", "WSBP", "WSKT", "WTON", "YELO", "YOII", "YPAS", "YULE",
    "YUPI", "ZATA", "ZBRA", "ZINC", "ZONE", "ZYRX",
]

INDOPREMIER_URL = "https://www.indopremier.com/module/saham/include/json-charting.php"
_session = requests.Session()
_session.headers.update({
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Chrome/124.0.0.0 Safari/537.36",
    "Accept": "application/json, text/plain, */*",
    "Referer": "https://www.indopremier.com/",
})

# ── Data fetching ──────────────────────────────────────────────────────────────
def fetch_ohlcv(symbol: str, years: int = 5) -> Optional[pd.DataFrame]:
    """Fetch OHLCV from IndoPremier; return a DataFrame with date/open/high/low/close/volume."""
    end = datetime.now()
    start = end - timedelta(days=years * 365)
    try:
        resp = _session.get(
            INDOPREMIER_URL,
            params={
                "code": symbol,
                "start": start.strftime("%m/%d/%Y"),
                "end": end.strftime("%m/%d/%Y"),
            },
            timeout=15,
        )
        resp.raise_for_status()
        raw = resp.json()
        if not isinstance(raw, list) or len(raw) < 60:
            return None
        df = pd.DataFrame(raw, columns=["timestamp_ms", "open", "high", "low", "close", "volume"])
        df["date"] = pd.to_datetime(df["timestamp_ms"], unit="ms", utc=True).dt.tz_localize(None)
        df = df[df["close"] > 0].drop(columns=["timestamp_ms"])
        df = df.sort_values("date").reset_index(drop=True)
        return df if len(df) >= 60 else None
    except Exception:
        return None
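
# Sketch: a quick, inert smoke test for the fetcher (not called by main()).
# "BBCA" is just an illustrative symbol taken from IDX_TICKERS above.
def _smoke_test_fetch(symbol: str = "BBCA") -> None:
    df = fetch_ohlcv(symbol, years=1)
    if df is None:
        print(f"{symbol}: no data (endpoint unreachable or <60 valid rows)")
    else:
        print(f"{symbol}: {len(df)} rows, {df['date'].min().date()} → {df['date'].max().date()}")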

# ── Feature engineering ────────────────────────────────────────────────────────
def add_features(df: pd.DataFrame) -> pd.DataFrame:
    """Add normalized technical indicators as columns."""
    c = df["close"].values.astype(np.float64)
    v = df["volume"].values.astype(np.float64)
    # Rolling 30-day z-score normalization for price
    s = pd.Series(c)
    rm = s.rolling(30, min_periods=1).mean()
    rs = s.rolling(30, min_periods=1).std(ddof=0).fillna(1).clip(lower=1e-6)
    df["close_norm"] = ((s - rm) / rs).values
    # Volume z-score over the same window
    sv = pd.Series(v)
    vm = sv.rolling(30, min_periods=1).mean()
    vs = sv.rolling(30, min_periods=1).std(ddof=0).fillna(1).clip(lower=1e-6)
    df["volume_norm"] = ((sv - vm) / vs).values
    # RSI, rescaled to [0, 1] (rs/(1+rs) equals the classic RSI/100)
    delta = np.diff(c, prepend=c[0])
    gain = pd.Series(np.where(delta > 0, delta, 0.0)).ewm(alpha=1 / 14, adjust=False).mean().values
    loss = pd.Series(np.where(delta < 0, -delta, 0.0)).ewm(alpha=1 / 14, adjust=False).mean().values
    rs_ratio = np.where(loss == 0, 100.0, gain / (loss + 1e-9))
    df["rsi"] = np.clip(rs_ratio / (1 + rs_ratio), 0, 1)
    # MACD histogram, normalized by its own standard deviation
    sp = pd.Series(c)
    macd_line = sp.ewm(span=12, adjust=False).mean() - sp.ewm(span=26, adjust=False).mean()
    signal = macd_line.ewm(span=9, adjust=False).mean()
    macd_raw = (macd_line - signal).values
    df["macd_norm"] = macd_raw / (np.std(macd_raw) or 1)
    # Bollinger band width relative to the 20-day SMA
    sma = sp.rolling(20, min_periods=1).mean()
    std20 = sp.rolling(20, min_periods=1).std(ddof=0).fillna(0)
    df["bb_width"] = (2 * std20 / sma.clip(lower=1e-6)).fillna(0).values
    # ATR as a fraction of price
    h, l_col = df["high"].values, df["low"].values
    prev_c = np.roll(c, 1)
    prev_c[0] = c[0]
    tr = np.maximum(h - l_col, np.maximum(np.abs(h - prev_c), np.abs(l_col - prev_c)))
    atr = pd.Series(tr).ewm(alpha=1 / 14, adjust=False).mean().values
    df["atr_norm"] = atr / (c + 1e-9)
    # On-balance volume, z-scored over the full history
    direction = np.sign(np.diff(c, prepend=c[0]))
    obv = np.cumsum(direction * v)
    obv_std = np.std(obv) or 1
    df["obv_norm"] = (obv - np.mean(obv)) / obv_std
    # Cyclical encodings of weekday (5 trading days) and month
    df["day_sin"] = np.sin(2 * np.pi * df["date"].dt.dayofweek / 5)
    df["day_cos"] = np.cos(2 * np.pi * df["date"].dt.dayofweek / 5)
    df["month_sin"] = np.sin(2 * np.pi * df["date"].dt.month / 12)
    df["month_cos"] = np.cos(2 * np.pi * df["date"].dt.month / 12)
    return df.dropna().reset_index(drop=True)
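
# Sketch: sanity-check the engineered columns on one frame (inert helper, not
# called by main()). The column names mirror KNOWN_REALS/UNKNOWN_REALS below;
# they are listed explicitly so the helper is self-contained.
def _check_features(df: pd.DataFrame) -> None:
    cols = ["close_norm", "volume_norm", "rsi", "macd_norm", "bb_width", "atr_norm", "obv_norm",
            "day_sin", "day_cos", "month_sin", "month_cos"]
    missing = [col for col in cols if col not in df.columns]
    assert not missing, f"missing engineered columns: {missing}"
    assert np.isfinite(df[cols].to_numpy()).all(), "non-finite feature values"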

# ── Dataset assembly ───────────────────────────────────────────────────────────
KNOWN_REALS = ["day_sin", "day_cos", "month_sin", "month_cos"]
UNKNOWN_REALS = ["close_norm", "volume_norm", "rsi", "macd_norm", "bb_width", "atr_norm", "obv_norm"]
TARGET = "close_norm"


def collect_data(
    tickers: list[str],
    years: int = 5,
    delay: float = 0.2,
    max_failures: int = 400,
) -> pd.DataFrame:
    """Fetch and concatenate all tickers into a single DataFrame for TimeSeriesDataSet."""
    all_dfs = []
    failures = 0
    print(f"Fetching {len(tickers)} tickers from IndoPremier ({years}y history)...")
    for i, ticker in enumerate(tickers):
        df = fetch_ohlcv(ticker, years=years)
        if df is None:
            failures += 1
            if (i + 1) % 50 == 0:
                print(f"  [{i+1}/{len(tickers)}] {len(all_dfs)} valid, {failures} failed")
            if failures >= max_failures:
                print(f"  Too many failures ({failures}), stopping early.")
                break
            time.sleep(delay)
            continue
        df = add_features(df)
        df["ticker"] = ticker
        df["time_idx"] = np.arange(len(df))
        # TARGET is also in UNKNOWN_REALS, so deduplicate to avoid duplicate columns
        cols = ["ticker", "time_idx", "date"] + list(dict.fromkeys([TARGET] + KNOWN_REALS + UNKNOWN_REALS))
        all_dfs.append(df[cols])
        if (i + 1) % 50 == 0:
            print(f"  [{i+1}/{len(tickers)}] {len(all_dfs)} valid, {failures} failed")
        time.sleep(delay)
    if not all_dfs:
        raise RuntimeError("No data collected from IndoPremier.")
    combined = pd.concat(all_dfs, ignore_index=True)
    print(f"\nTotal rows: {len(combined):,} across {len(all_dfs)} tickers")
    return combined
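
# Sketch: for a fast dry run in Colab, collect a small subset first, e.g.
#   df_small = collect_data(IDX_TICKERS[:20], years=1)
# (the slice size and horizon here are illustrative, not tuned values).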


# ── pytorch-forecasting DataSet + Model ────────────────────────────────────────
def build_dataset(df: pd.DataFrame, predict: bool = False) -> TimeSeriesDataSet:
    """Build a pytorch-forecasting TimeSeriesDataSet."""
    # Hold out the last PREDICTION_LENGTH rows of each ticker for validation
    training_cutoff = df.groupby("ticker")["time_idx"].transform(
        lambda x: x.max() - PREDICTION_LENGTH
    )
    subset = df if predict else df[df["time_idx"] <= training_cutoff]
    return TimeSeriesDataSet(
        subset,
        time_idx="time_idx",
        target=TARGET,
        group_ids=["ticker"],
        min_encoder_length=ENCODER_LENGTH // 2,
        max_encoder_length=ENCODER_LENGTH,
        min_prediction_length=1,
        max_prediction_length=PREDICTION_LENGTH,
        static_categoricals=["ticker"],
        time_varying_known_reals=KNOWN_REALS + ["time_idx"],
        time_varying_unknown_reals=UNKNOWN_REALS,
        target_normalizer=None,  # close_norm is already z-scored
        add_relative_time_idx=True,
        add_target_scales=True,
        add_encoder_length=True,
        allow_missing_timesteps=True,
    )


def build_tft(training_dataset: TimeSeriesDataSet) -> TemporalFusionTransformer:
    return TemporalFusionTransformer.from_dataset(
        training_dataset,
        learning_rate=LR,
        hidden_size=HIDDEN_SIZE,
        attention_head_size=ATTENTION_HEAD_SIZE,
        dropout=DROPOUT,
        hidden_continuous_size=HIDDEN_CONTINUOUS_SIZE,
        loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
        log_interval=10,
        reduce_on_plateau_patience=4,
        optimizer="adam",
    )
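
# Sketch: how the saved artifacts are meant to be reloaded at inference time,
# assuming the same pytorch-forecasting version on the serving side (on newer
# torch, torch.load may additionally need weights_only=False for the pickled
# params). TimeSeriesDataSet.from_parameters and load_from_checkpoint are the
# library's standard reload entry points; `new_df` is any frame with the
# training columns. Inert helper, not called by main().
def _load_for_inference(new_df: pd.DataFrame):
    params = torch.load(DATASET_PARAMS_PATH)
    ds = TimeSeriesDataSet.from_parameters(params, new_df, predict=True, stop_randomization=True)
    model = TemporalFusionTransformer.load_from_checkpoint(FINAL_MODEL_PATH)
    return ds, model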


# ── DDG-DA drift predictor training (reused from train.py logic) ──────────────
def train_ddg_da_predictor(df: pd.DataFrame) -> None:
    """Train the lightweight MLP drift predictor for DDG-DA."""
    import torch.nn as nn
    from torch.utils.data import TensorDataset, DataLoader

    SNAPSHOT_WINDOW = 20
    K_HISTORY = 8
    N_FEATURES = len(UNKNOWN_REALS) + len(KNOWN_REALS)  # 11
    SNAPSHOT_DIM = N_FEATURES * 4  # 44: mean, std, skew, kurtosis per feature

    def compute_snapshot(arr: np.ndarray) -> np.ndarray:
        stats = []
        for f in range(arr.shape[1]):
            col = arr[:, f].astype(np.float64)
            mu = col.mean()
            sigma = col.std() or 1e-8
            skew = float(((col - mu) ** 3).mean() / sigma ** 3)
            kurt = float(((col - mu) ** 4).mean() / sigma ** 4) - 3.0
            stats.extend([float(mu), float(sigma), skew, kurt])
        return np.array(stats, dtype=np.float32)

    feature_cols = UNKNOWN_REALS + KNOWN_REALS
    all_X, all_y = [], []
    for ticker, grp in df.groupby("ticker"):
        feat = grp[feature_cols].values.astype(np.float32)
        k = len(feat) // SNAPSHOT_WINDOW
        if k < K_HISTORY + 1:
            continue
        snaps = np.stack([compute_snapshot(feat[i * SNAPSHOT_WINDOW:(i + 1) * SNAPSHOT_WINDOW]) for i in range(k)])
        for i in range(len(snaps) - K_HISTORY):
            all_X.append(snaps[i:i + K_HISTORY].flatten())
            all_y.append(snaps[i + K_HISTORY])
    if not all_X:
        print("Not enough data for DDG-DA training, skipping.")
        return

    X = np.array(all_X, dtype=np.float32)
    y = np.array(all_y, dtype=np.float32)
    idx = np.random.permutation(len(X))
    X, y = X[idx], y[idx]
    split = int(len(X) * 0.9)
    train_ds = TensorDataset(torch.tensor(X[:split]), torch.tensor(y[:split]))
    val_ds = TensorDataset(torch.tensor(X[split:]), torch.tensor(y[split:]))
    train_dl = DataLoader(train_ds, batch_size=256, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=256)

    mlp = nn.Sequential(
        nn.Linear(K_HISTORY * SNAPSHOT_DIM, 128),
        nn.ELU(),
        nn.Dropout(0.1),
        nn.Linear(128, SNAPSHOT_DIM),
    )
    opt = torch.optim.Adam(mlp.parameters(), lr=1e-3)
    crit = nn.MSELoss()
    best_val = float("inf")
    print(f"\n--- Training DDG-DA MLP ({len(X):,} pairs) ---")
    for epoch in range(1, 31):
        mlp.train()
        tl = 0.0
        for xb, yb in train_dl:
            # The original loop only summed losses; the backward/step calls
            # below are required for the MLP to actually learn.
            opt.zero_grad()
            loss = crit(mlp(xb), yb)
            loss.backward()
            opt.step()
            tl += loss.item() * len(xb)
        tl /= max(split, 1)
        mlp.eval()
        with torch.no_grad():
            vl = sum(crit(mlp(xb), yb).item() * len(xb) for xb, yb in val_dl) / max(len(X) - split, 1)
        print(f"  Epoch {epoch:2d}/30  train={tl:.6f}  val={vl:.6f}")
        if vl < best_val:
            best_val = vl
            torch.save(mlp.state_dict(), DDG_DA_PATH)
    print(f"DDG-DA saved → {DDG_DA_PATH}")


# ── Upload to HF Hub ───────────────────────────────────────────────────────────
def upload_to_hub(checkpoint_path: str) -> None:
    hf_token = os.environ.get("HF_TOKEN", "")
    hf_repo = os.environ.get("HF_MODEL_REPO", "")
    if not hf_token or not hf_repo:
        print("HF_TOKEN / HF_MODEL_REPO not set; skipping upload.")
        return
    from huggingface_hub import HfApi
    api = HfApi(token=hf_token)
    api.create_repo(repo_id=hf_repo, repo_type="model", exist_ok=True, private=True)
    for local, remote in [
        (checkpoint_path, "tft_stock.ckpt"),
        (DATASET_PARAMS_PATH, "dataset_params.pt"),
        (DDG_DA_PATH, "ddg_da.pt"),
    ]:
        if os.path.exists(local):
            api.upload_file(
                path_or_fileobj=local,
                path_in_repo=remote,
                repo_id=hf_repo,
                repo_type="model",
                commit_message=f"Colab retrain: {remote}",
            )
            print(f"Uploaded {remote} → {hf_repo}")


# ── Main ───────────────────────────────────────────────────────────────────────
def main():
    print("=" * 60)
    print("StockPro TFT Training (pytorch-forecasting)")
    print("=" * 60)

    # 1. Collect data
    df = collect_data(IDX_TICKERS, years=5)
    df.to_parquet(os.path.join(MODEL_DIR, "training_data.parquet"), index=False)
    print(f"Data saved → {MODEL_DIR}/training_data.parquet")

    # 2. Build datasets
    training_ds = build_dataset(df, predict=False)
    # Save dataset parameters (fitted categorical encoders, scalers, config).
    # These are required by the HF Spaces inference server at runtime.
    torch.save(training_ds.get_parameters(), DATASET_PARAMS_PATH)
    print(f"Dataset params saved → {DATASET_PARAMS_PATH}")
    val_ds = TimeSeriesDataSet.from_dataset(training_ds, df, predict=True, stop_randomization=True)
    train_dl = training_ds.to_dataloader(train=True, batch_size=BATCH_SIZE, num_workers=2)
    val_dl = val_ds.to_dataloader(train=False, batch_size=BATCH_SIZE * 2, num_workers=2)

    # 3. Build TFT model
    tft = build_tft(training_ds)
    print(f"\nTFT parameters: {sum(p.numel() for p in tft.parameters()):,}")

    # 4. Train with Lightning
    from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
    checkpoint_cb = ModelCheckpoint(
        dirpath=CHECKPOINT_DIR,
        filename="tft-{epoch:02d}-{val_loss:.4f}",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
    )
    callbacks = [
        EarlyStopping(monitor="val_loss", patience=5, mode="min"),
        checkpoint_cb,
    ]
    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        accelerator="auto",  # GPU if available
        gradient_clip_val=GRADIENT_CLIP,
        callbacks=callbacks,
        enable_progress_bar=True,
        log_every_n_steps=10,
    )
    print("\n--- Training TFT ---")
    trainer.fit(tft, train_dataloaders=train_dl, val_dataloaders=val_dl)

    # 5. Copy best checkpoint
    best_ckpt = checkpoint_cb.best_model_path
    if best_ckpt and os.path.exists(best_ckpt):
        import shutil
        shutil.copy2(best_ckpt, FINAL_MODEL_PATH)
        print(f"\nBest model → {FINAL_MODEL_PATH}")

    # 6. Train DDG-DA
    train_ddg_da_predictor(df)

    # 7. Upload to HF Hub
    upload_to_hub(FINAL_MODEL_PATH)
    print("\nDone!")


if __name__ == "__main__":
    main()