stockpro-ml / scripts /train_colab.py
will702's picture
StockPro ML backend with pytorch-forecasting TFT
9334ec6
"""
StockPro TFT Training Script — Google Colab Edition
=====================================================
Uses pytorch-forecasting's production TFT implementation.
Usage in Colab:
# 1. Install deps
!pip install pytorch-forecasting lightning huggingface_hub requests pandas numpy
# 2. Mount Drive (optional, for checkpoint persistence)
from google.colab import drive
drive.mount('/content/drive')
# 3. Run
!python train_colab.py
# Or set env vars for HF Hub upload:
import os
os.environ["HF_TOKEN"] = "hf_xxx"
os.environ["HF_MODEL_REPO"] = "username/stockpro-tft"
!python train_colab.py
"""
import os
import sys
import time
import requests
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import Optional
# ── Install check ─────────────────────────────────────────────────────────────
try:
    import pytorch_forecasting  # noqa: F401
    import lightning.pytorch as pl
except ImportError:
    # Self-install fallback for fresh Colab runtimes.
    # - `lightning.pytorch` is shipped by the `lightning` distribution, not by
    #   `pytorch-lightning` (which only provides the `pytorch_lightning` module).
    # - Invoke pip through the running interpreter so the packages land in the
    #   environment this script actually executes in.
    # - check_call raises CalledProcessError if the install fails, instead of
    #   continuing to a confusing second ImportError.
    import importlib
    import subprocess

    print("Installing pytorch-forecasting and lightning...")
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-q", "pytorch-forecasting", "lightning"]
    )
    # Make the import system notice packages installed after interpreter start.
    importlib.invalidate_caches()
    import pytorch_forecasting  # noqa: F401
    import lightning.pytorch as pl
import torch
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss
from torch.utils.data import DataLoader
# ── Config ────────────────────────────────────────────────────────────────────
# Sequence geometry. ENCODER_LENGTH is passed to the dataset as
# max_encoder_length (half of it as min_encoder_length); PREDICTION_LENGTH is
# both max_prediction_length and the number of trailing rows per ticker that
# build_dataset() holds out of the training subset.
ENCODER_LENGTH = 60 # lookback window
PREDICTION_LENGTH = 30 # max forecast horizon
# Optimisation settings consumed by main() (Trainer / dataloaders) and
# build_tft(). The validation dataloader uses BATCH_SIZE * 2.
BATCH_SIZE = 64
MAX_EPOCHS = 50
LR = 1e-3
GRADIENT_CLIP = 0.1
# TFT architecture hyper-parameters forwarded to
# TemporalFusionTransformer.from_dataset() in build_tft().
HIDDEN_SIZE = 64 # TFT hidden state
ATTENTION_HEAD_SIZE = 4
DROPOUT = 0.1
HIDDEN_CONTINUOUS_SIZE = 16
# Save paths — adjust for Colab/Drive as needed.
# MODEL_DIR can be overridden with the MODEL_DIR environment variable; every
# other artifact path is derived from it.
MODEL_DIR = os.environ.get("MODEL_DIR", "/content/models")
CHECKPOINT_DIR = os.path.join(MODEL_DIR, "checkpoints")
FINAL_MODEL_PATH = os.path.join(MODEL_DIR, "tft_stock.ckpt")
DATASET_PARAMS_PATH = os.path.join(MODEL_DIR, "dataset_params.pt")
DDG_DA_PATH = os.path.join(MODEL_DIR, "ddg_da.pt")
# Import-time side effect: the output directories are created as soon as the
# module is loaded.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# ── Full IDX ticker list (956 stocks, source: IDX Daftar Saham 2026-03-08) ────
IDX_TICKERS = [
"AADI", "AALI", "ABBA", "ABDA", "ABMM", "ACES", "ACRO", "ACST", "ADCP", "ADES",
"ADHI", "ADMF", "ADMG", "ADMR", "ADRO", "AEGS", "AGAR", "AGII", "AGRO", "AGRS",
"AHAP", "AIMS", "AISA", "AKKU", "AKPI", "AKRA", "AKSI", "ALDO", "ALII", "ALKA",
"ALMI", "ALTO", "AMAG", "AMAN", "AMAR", "AMFG", "AMIN", "AMMN", "AMMS", "AMOR",
"AMRT", "ANDI", "ANJT", "ANTM", "APEX", "APIC", "APII", "APLI", "APLN", "ARCI",
"AREA", "ARGO", "ARII", "ARKA", "ARKO", "ARMY", "ARNA", "ARTA", "ARTI", "ARTO",
"ASBI", "ASDM", "ASGR", "ASHA", "ASII", "ASJT", "ASLC", "ASLI", "ASMI", "ASPI",
"ASPR", "ASRI", "ASRM", "ASSA", "ATAP", "ATIC", "ATLA", "AUTO", "AVIA", "AWAN",
"AXIO", "AYAM", "AYLS", "BABP", "BABY", "BACA", "BAIK", "BAJA", "BALI", "BANK",
"BAPA", "BAPI", "BATA", "BATR", "BAUT", "BAYU", "BBCA", "BBHI", "BBKP", "BBLD",
"BBMD", "BBNI", "BBRI", "BBRM", "BBSI", "BBSS", "BBTN", "BBYB", "BCAP", "BCIC",
"BCIP", "BDKR", "BDMN", "BEBS", "BEEF", "BEER", "BEKS", "BELI", "BELL", "BESS",
"BEST", "BFIN", "BGTG", "BHAT", "BHIT", "BIKA", "BIKE", "BIMA", "BINA", "BINO",
"BIPI", "BIPP", "BIRD", "BISI", "BJBR", "BJTM", "BKDP", "BKSL", "BKSW", "BLES",
"BLOG", "BLTA", "BLTZ", "BLUE", "BMAS", "BMBL", "BMHS", "BMRI", "BMSR", "BMTR",
"BNBA", "BNBR", "BNGA", "BNII", "BNLI", "BOAT", "BOBA", "BOGA", "BOLA", "BOLT",
"BOSS", "BPFI", "BPII", "BPTR", "BRAM", "BREN", "BRIS", "BRMS", "BRNA", "BRPT",
"BRRC", "BSBK", "BSDE", "BSIM", "BSML", "BSSR", "BSWD", "BTEK", "BTEL", "BTON",
"BTPN", "BTPS", "BUAH", "BUDI", "BUKA", "BUKK", "BULL", "BUMI", "BUVA", "BVIC",
"BWPT", "BYAN", "CAKK", "CAMP", "CANI", "CARE", "CARS", "CASA", "CASH", "CASS",
"CBDK", "CBMF", "CBPE", "CBRE", "CBUT", "CCSI", "CDIA", "CEKA", "CENT", "CFIN",
"CGAS", "CHEK", "CHEM", "CHIP", "CINT", "CITA", "CITY", "CLAY", "CLEO", "CLPI",
"CMNP", "CMNT", "CMPP", "CMRY", "CNKO", "CNMA", "CNTX", "COAL", "COCO", "COIN",
"COWL", "CPIN", "CPRI", "CPRO", "CRAB", "CRSN", "CSAP", "CSIS", "CSMI", "CSRA",
"CTBN", "CTRA", "CTTH", "CUAN", "CYBR", "DAAZ", "DADA", "DART", "DATA", "DAYA",
"DCII", "DEAL", "DEFI", "DEPO", "DEWA", "DEWI", "DFAM", "DGIK", "DGNS", "DGWG",
"DIGI", "DILD", "DIVA", "DKFT", "DKHH", "DLTA", "DMAS", "DMMX", "DMND", "DNAR",
"DNET", "DOID", "DOOH", "DOSS", "DPNS", "DPUM", "DRMA", "DSFI", "DSNG", "DSSA",
"DUCK", "DUTI", "DVLA", "DWGL", "DYAN", "EAST", "ECII", "EDGE", "EKAD", "ELIT",
"ELPI", "ELSA", "ELTY", "EMAS", "EMDE", "EMTK", "ENAK", "ENRG", "ENVY", "ENZO",
"EPAC", "EPMT", "ERAA", "ERAL", "ERTX", "ESIP", "ESSA", "ESTA", "ESTI", "ETWA",
"EURO", "EXCL", "FAPA", "FAST", "FASW", "FILM", "FIMP", "FIRE", "FISH", "FITT",
"FLMC", "FMII", "FOLK", "FOOD", "FORE", "FORU", "FPNI", "FUJI", "FUTR", "FWCT",
"GAMA", "GDST", "GDYR", "GEMA", "GEMS", "GGRM", "GGRP", "GHON", "GIAA", "GJTL",
"GLOB", "GLVA", "GMFI", "GMTD", "GOLD", "GOLF", "GOLL", "GOOD", "GOTO", "GPRA",
"GPSO", "GRIA", "GRPH", "GRPM", "GSMF", "GTBO", "GTRA", "GTSI", "GULA", "GUNA",
"GWSA", "GZCO", "HADE", "HAIS", "HAJJ", "HALO", "HATM", "HBAT", "HDFA", "HDIT",
"HEAL", "HELI", "HERO", "HEXA", "HGII", "HILL", "HITS", "HKMU", "HMSP", "HOKI",
"HOME", "HOMI", "HOPE", "HOTL", "HRME", "HRTA", "HRUM", "HUMI", "HYGN", "IATA",
"IBFN", "IBOS", "IBST", "ICBP", "ICON", "IDEA", "IDPR", "IFII", "IFSH", "IGAR",
"IIKP", "IKAI", "IKAN", "IKBI", "IKPM", "IMAS", "IMJS", "IMPC", "INAF", "INAI",
"INCF", "INCI", "INCO", "INDF", "INDO", "INDR", "INDS", "INDX", "INDY", "INET",
"INKP", "INOV", "INPC", "INPP", "INPS", "INRU", "INTA", "INTD", "INTP", "IOTF",
"IPAC", "IPCC", "IPCM", "IPOL", "IPPE", "IPTV", "IRRA", "IRSX", "ISAP", "ISAT",
"ISEA", "ISSP", "ITIC", "ITMA", "ITMG", "JARR", "JAST", "JATI", "JAWA", "JAYA",
"JECC", "JGLE", "JIHD", "JKON", "JMAS", "JPFA", "JRPT", "JSKY", "JSMR", "JSPT",
"JTPE", "KAEF", "KAQI", "KARW", "KAYU", "KBAG", "KBLI", "KBLM", "KBLV", "KBRI",
"KDSI", "KDTN", "KEEN", "KEJU", "KETR", "KIAS", "KICI", "KIJA", "KING", "KINO",
"KIOS", "KJEN", "KKES", "KKGI", "KLAS", "KLBF", "KLIN", "KMDS", "KMTR", "KOBX",
"KOCI", "KOIN", "KOKA", "KONI", "KOPI", "KOTA", "KPIG", "KRAS", "KREN", "KRYA",
"KSIX", "KUAS", "LABA", "LABS", "LAJU", "LAND", "LAPD", "LCGP", "LCKM", "LEAD",
"LFLO", "LIFE", "LINK", "LION", "LIVE", "LMAS", "LMAX", "LMPI", "LMSH", "LOPI",
"LPCK", "LPGI", "LPIN", "LPKR", "LPLI", "LPPF", "LPPS", "LRNA", "LSIP", "LTLS",
"LUCK", "LUCY", "MABA", "MAGP", "MAHA", "MAIN", "MANG", "MAPA", "MAPB", "MAPI",
"MARI", "MARK", "MASB", "MAXI", "MAYA", "MBAP", "MBMA", "MBSS", "MBTO", "MCAS",
"MCOL", "MCOR", "MDIA", "MDIY", "MDKA", "MDKI", "MDLA", "MDLN", "MDRN", "MEDC",
"MEDS", "MEGA", "MEJA", "MENN", "MERI", "MERK", "META", "MFMI", "MGLV", "MGNA",
"MGRO", "MHKI", "MICE", "MIDI", "MIKA", "MINA", "MINE", "MIRA", "MITI", "MKAP",
"MKNT", "MKPI", "MKTR", "MLBI", "MLIA", "MLPL", "MLPT", "MMIX", "MMLP", "MNCN",
"MOLI", "MORA", "MPIX", "MPMX", "MPOW", "MPPA", "MPRO", "MPXL", "MRAT", "MREI",
"MSIE", "MSIN", "MSJA", "MSKY", "MSTI", "MTDL", "MTEL", "MTFN", "MTLA", "MTMH",
"MTPS", "MTRA", "MTSM", "MTWI", "MUTU", "MYOH", "MYOR", "MYTX", "NAIK", "NANO",
"NASA", "NASI", "NATO", "NAYZ", "NCKL", "NELY", "NEST", "NETV", "NFCX", "NICE",
"NICK", "NICL", "NIKL", "NINE", "NIRO", "NISP", "NOBU", "NPGF", "NRCA", "NSSS",
"NTBK", "NUSA", "NZIA", "OASA", "OBAT", "OBMD", "OCAP", "OILS", "OKAS", "OLIV",
"OMED", "OMRE", "OPMS", "PACK", "PADA", "PADI", "PALM", "PAMG", "PANI", "PANR",
"PANS", "PART", "PBID", "PBRX", "PBSA", "PCAR", "PDES", "PDPP", "PEGE", "PEHA",
"PEVE", "PGAS", "PGEO", "PGJO", "PGLI", "PGUN", "PICO", "PIPA", "PJAA", "PJHB",
"PKPK", "PLAN", "PLAS", "PLIN", "PMJS", "PMMP", "PMUI", "PNBN", "PNBS", "PNGO",
"PNIN", "PNLF", "PNSE", "POLA", "POLI", "POLL", "POLU", "POLY", "POOL", "PORT",
"POSA", "POWR", "PPGL", "PPRE", "PPRI", "PPRO", "PRAY", "PRDA", "PRIM", "PSAB",
"PSAT", "PSDN", "PSGO", "PSKT", "PSSI", "PTBA", "PTDU", "PTIS", "PTMP", "PTMR",
"PTPP", "PTPS", "PTPW", "PTRO", "PTSN", "PTSP", "PUDP", "PURA", "PURE", "PURI",
"PWON", "PYFA", "PZZA", "RAAM", "RAFI", "RAJA", "RALS", "RANC", "RATU", "RBMS",
"RCCC", "RDTX", "REAL", "RELF", "RELI", "RGAS", "RICY", "RIGS", "RIMO", "RISE",
"RLCO", "RMKE", "RMKO", "ROCK", "RODA", "RONY", "ROTI", "RSCH", "RSGK", "RUIS",
"RUNS", "SAFE", "SAGE", "SAME", "SAMF", "SAPX", "SATU", "SBAT", "SBMA", "SCCO",
"SCMA", "SCNP", "SCPI", "SDMU", "SDPC", "SDRA", "SEMA", "SFAN", "SGER", "SGRO",
"SHID", "SHIP", "SICO", "SIDO", "SILO", "SIMA", "SIMP", "SINI", "SIPD", "SKBM",
"SKLT", "SKRN", "SKYB", "SLIS", "SMAR", "SMBR", "SMCB", "SMDM", "SMDR", "SMGA",
"SMGR", "SMIL", "SMKL", "SMKM", "SMLE", "SMMA", "SMMT", "SMRA", "SMRU", "SMSM",
"SNLK", "SOCI", "SOFA", "SOHO", "SOLA", "SONA", "SOSS", "SOTS", "SOUL", "SPMA",
"SPRE", "SPTO", "SQMI", "SRAJ", "SRIL", "SRSN", "SRTG", "SSIA", "SSMS", "SSTM",
"STAA", "STAR", "STRK", "STTP", "SUGI", "SULI", "SUNI", "SUPA", "SUPR", "SURE",
"SURI", "SWAT", "SWID", "TALF", "TAMA", "TAMU", "TAPG", "TARA", "TAXI", "TAYS",
"TBIG", "TBLA", "TBMS", "TCID", "TCPI", "TDPM", "TEBE", "TECH", "TELE", "TFAS",
"TFCO", "TGKA", "TGRA", "TGUK", "TIFA", "TINS", "TIRA", "TIRT", "TKIM", "TLDN",
"TLKM", "TMAS", "TMPO", "TNCA", "TOBA", "TOOL", "TOPS", "TOSK", "TOTL", "TOTO",
"TOWR", "TOYS", "TPIA", "TPMA", "TRAM", "TRGU", "TRIL", "TRIM", "TRIN", "TRIO",
"TRIS", "TRJA", "TRON", "TRST", "TRUE", "TRUK", "TRUS", "TSPC", "TUGU", "TYRE",
"UANG", "UCID", "UDNG", "UFOE", "ULTJ", "UNIC", "UNIQ", "UNIT", "UNSP", "UNTD",
"UNTR", "UNVR", "URBN", "UVCR", "VAST", "VERN", "VICI", "VICO", "VINS", "VISI",
"VIVA", "VKTR", "VOKS", "VRNA", "VTNY", "WAPO", "WEGE", "WEHA", "WGSH", "WICO",
"WIDI", "WIFI", "WIIM", "WIKA", "WINE", "WINR", "WINS", "WIRG", "WMPP", "WMUU",
"WOMF", "WOOD", "WOWS", "WSBP", "WSKT", "WTON", "YELO", "YOII", "YPAS", "YULE",
"YUPI", "ZATA", "ZBRA", "ZINC", "ZONE", "ZYRX",
]
# Charting endpoint queried by fetch_ohlcv() with `code`, `start` and `end`
# query parameters.
INDOPREMIER_URL = "https://www.indopremier.com/module/saham/include/json-charting.php"
# One module-level Session shared by every request so the headers below are
# sent each time. The headers present the client as a desktop browser coming
# from the IndoPremier site.
# NOTE(review): presumably the endpoint rejects non-browser clients — confirm.
_session = requests.Session()
_session.headers.update({
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Chrome/124.0.0.0 Safari/537.36",
    "Accept": "application/json, text/plain, */*",
    "Referer": "https://www.indopremier.com/",
})
# ── Data fetching ─────────────────────────────────────────────────────────────
def fetch_ohlcv(symbol: str, years: int = 5, min_rows: int = 60) -> Optional[pd.DataFrame]:
    """Fetch daily OHLCV history for one ticker from IndoPremier.

    Args:
        symbol: Ticker code sent as the ``code`` query parameter.
        years: Length of the requested history, counted back from now in
            blocks of 365 days.
        min_rows: Minimum number of rows required both in the raw response
            and after filtering; shorter histories are rejected.

    Returns:
        A DataFrame with columns ``open, high, low, close, volume, date``
        sorted by ascending ``date`` with a fresh integer index, or ``None``
        when the history is too short or the request/parsing failed.

    This function is deliberately best-effort: it never raises, because the
    caller iterates over hundreds of tickers and counts ``None`` results as
    failures. Errors are reported on stderr so they are not invisible.
    """
    end = datetime.now()
    start = end - timedelta(days=years * 365)
    date_format = "%m/%d/%Y"
    try:
        resp = _session.get(
            INDOPREMIER_URL,
            params={
                "code": symbol,
                "start": start.strftime(date_format),
                "end": end.strftime(date_format),
            },
            timeout=15,
        )
        resp.raise_for_status()
        raw = resp.json()
        # Anything that is not a list of rows (or is too short) is treated as
        # "no usable data" rather than as an error.
        if not isinstance(raw, list) or len(raw) < min_rows:
            return None
        df = pd.DataFrame(raw, columns=["timestamp_ms", "open", "high", "low", "close", "volume"])
        # Timestamps are epoch milliseconds; parse as UTC, then drop the tz so
        # downstream code works with naive datetimes.
        df["date"] = pd.to_datetime(df["timestamp_ms"], unit="ms", utc=True).dt.tz_localize(None)
        # Rows with a non-positive close are discarded.
        df = df[df["close"] > 0].drop(columns=["timestamp_ms"])
        df = df.sort_values("date").reset_index(drop=True)
        return df if len(df) >= min_rows else None
    except Exception as exc:  # per-ticker boundary: report and keep going
        print(f"  [fetch_ohlcv] {symbol}: {type(exc).__name__}: {exc}", file=sys.stderr)
        return None
# ── Feature engineering ───────────────────────────────────────────────────────
def add_features(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* extended with normalized technical indicators.

    The input must be non-empty and contain ``date`` (datetime), ``high``,
    ``low``, ``close`` and ``volume`` columns. The caller's frame is left
    untouched; all new columns are written to a copy.

    Added columns:
        close_norm, volume_norm: rolling 30-row z-scores.
        rsi: Wilder-smoothed (alpha=1/14) gain/loss ratio mapped into [0, 1].
        macd_norm: MACD(12, 26, 9) histogram divided by its std.
        bb_width: 2 * rolling-20 std / rolling-20 mean of close.
        atr_norm: Wilder-smoothed true range divided by close.
        obv_norm: on-balance volume, z-scored.
        day_sin, day_cos, month_sin, month_cos: cyclical calendar encodings.

    Returns:
        The augmented frame with rows containing NaN dropped and the index
        reset.
    """
    # Work on a copy so the caller's DataFrame is not modified as a side effect.
    df = df.copy()

    c = df["close"].values.astype(np.float64)
    v = df["volume"].values.astype(np.float64)
    close = pd.Series(c)
    volume = pd.Series(v)

    # Rolling 30-day z-score normalization for price. The std is floored at
    # 1e-6 so flat windows do not divide by zero.
    close_mean = close.rolling(30, min_periods=1).mean()
    close_std = close.rolling(30, min_periods=1).std(ddof=0).fillna(1).clip(lower=1e-6)
    df["close_norm"] = ((close - close_mean) / close_std).values

    # Same normalization for volume.
    vol_mean = volume.rolling(30, min_periods=1).mean()
    vol_std = volume.rolling(30, min_periods=1).std(ddof=0).fillna(1).clip(lower=1e-6)
    df["volume_norm"] = ((volume - vol_mean) / vol_std).values

    # Day-over-day change; the first element is 0 because c[0] is prepended.
    delta = np.diff(c, prepend=c[0])

    # RSI scaled to 0-1: RS / (1 + RS), with RS pinned to 100 when there is no loss.
    gain = pd.Series(np.where(delta > 0, delta, 0.0)).ewm(alpha=1 / 14, adjust=False).mean().values
    loss = pd.Series(np.where(delta < 0, -delta, 0.0)).ewm(alpha=1 / 14, adjust=False).mean().values
    rs_ratio = np.where(loss == 0, 100.0, gain / (loss + 1e-9))
    df["rsi"] = np.clip(rs_ratio / (1 + rs_ratio), 0, 1)

    # MACD histogram, scaled by its standard deviation.
    # NOTE: the std is taken over the whole series, so every row's value
    # depends on later rows as well.
    macd_line = close.ewm(span=12, adjust=False).mean() - close.ewm(span=26, adjust=False).mean()
    signal = macd_line.ewm(span=9, adjust=False).mean()
    macd_raw = (macd_line - signal).values
    df["macd_norm"] = macd_raw / (np.std(macd_raw) or 1)

    # Bollinger band width relative to the 20-row moving average.
    sma = close.rolling(20, min_periods=1).mean()
    std20 = close.rolling(20, min_periods=1).std(ddof=0).fillna(0)
    df["bb_width"] = (2 * std20 / sma.clip(lower=1e-6)).fillna(0).values

    # ATR normalized by the close price.
    h, l_col = df["high"].values, df["low"].values
    prev_c = np.roll(c, 1)
    prev_c[0] = c[0]  # np.roll wraps around; the first row has no previous close
    tr = np.maximum(h - l_col, np.maximum(np.abs(h - prev_c), np.abs(l_col - prev_c)))
    atr = pd.Series(tr).ewm(alpha=1 / 14, adjust=False).mean().values
    df["atr_norm"] = atr / (c + 1e-9)

    # OBV, z-scored over the whole series (same whole-series caveat as MACD).
    obv = np.cumsum(np.sign(delta) * v)
    obv_std = np.std(obv) or 1
    df["obv_norm"] = (obv - np.mean(obv)) / obv_std

    # Cyclical encodings: weekday over a 5-day period, month over 12.
    day_of_week = df["date"].dt.dayofweek
    month = df["date"].dt.month
    df["day_sin"] = np.sin(2 * np.pi * day_of_week / 5)
    df["day_cos"] = np.cos(2 * np.pi * day_of_week / 5)
    df["month_sin"] = np.sin(2 * np.pi * month / 12)
    df["month_cos"] = np.cos(2 * np.pi * month / 12)

    return df.dropna().reset_index(drop=True)
# ── Dataset assembly ──────────────────────────────────────────────────────────
# Calendar encodings produced by add_features(); passed to the dataset as
# time-varying *known* reals.
KNOWN_REALS = ["day_sin", "day_cos", "month_sin", "month_cos"]
# Price/volume-derived indicators produced by add_features(); passed to the
# dataset as time-varying *unknown* reals.
UNKNOWN_REALS = ["close_norm", "volume_norm", "rsi", "macd_norm", "bb_width", "atr_norm", "obv_norm"]
# Forecast target. Note that it is also a member of UNKNOWN_REALS, so any
# column list built as [TARGET] + UNKNOWN_REALS contains "close_norm" twice.
TARGET = "close_norm"
def collect_data(
    tickers: list[str],
    years: int = 5,
    delay: float = 0.2,
    max_failures: int = 400,
) -> pd.DataFrame:
    """Fetch and concatenate all tickers into a single DataFrame for TimeSeriesDataSet.

    Args:
        tickers: Ticker codes to download, processed in order.
        years: History length forwarded to ``fetch_ohlcv``.
        delay: Seconds to sleep after each ticker (simple rate limiting).
        max_failures: Stop early once this many tickers have failed.

    Returns:
        One long-format frame with ``ticker``, a per-ticker ``time_idx``
        (0..n-1), ``date`` and every feature column, each exactly once.

    Raises:
        RuntimeError: If no ticker produced usable data.
    """
    # TARGET is also listed in UNKNOWN_REALS. Selecting the naive concatenation
    # would duplicate the "close_norm" column, which breaks parquet export and
    # dataset construction. dict.fromkeys de-duplicates while keeping order.
    columns = list(
        dict.fromkeys(["ticker", "time_idx", "date", TARGET, *KNOWN_REALS, *UNKNOWN_REALS])
    )

    all_dfs = []
    failures = 0
    total = len(tickers)
    print(f"Fetching {total} tickers from IndoPremier ({years}y history)...")
    for done, ticker in enumerate(tickers, start=1):
        df = fetch_ohlcv(ticker, years=years)
        failed = df is None
        if failed:
            failures += 1
        else:
            df = add_features(df)
            df["ticker"] = ticker
            # Row position within the ticker's own history, not a shared calendar index.
            df["time_idx"] = np.arange(len(df))
            all_dfs.append(df[columns])

        if done % 50 == 0:
            print(f" [{done}/{total}] {len(all_dfs)} valid, {failures} failed")
        # Only a fresh failure can push the counter over the limit.
        if failed and failures >= max_failures:
            print(f" Too many failures ({failures}), stopping early.")
            break
        time.sleep(delay)

    if not all_dfs:
        raise RuntimeError("No data collected from IndoPremier.")
    combined = pd.concat(all_dfs, ignore_index=True)
    print(f"\nTotal rows: {len(combined):,} across {len(all_dfs)} tickers")
    return combined
# ── pytorch-forecasting DataSet + Model ───────────────────────────────────────
def build_dataset(df: pd.DataFrame, predict: bool = False) -> TimeSeriesDataSet:
    """Wrap the long-format frame in a pytorch-forecasting TimeSeriesDataSet.

    With ``predict=False`` the last PREDICTION_LENGTH rows of every ticker are
    held back (they are reserved for validation); with ``predict=True`` the
    frame is used unchanged.
    """
    # Per-row cutoff: the ticker's last time_idx minus the forecast horizon.
    last_idx = df.groupby("ticker")["time_idx"].transform("max")
    training_cutoff = last_idx - PREDICTION_LENGTH

    if predict:
        subset = df
    else:
        subset = df[df["time_idx"] <= training_cutoff]

    dataset_kwargs = dict(
        time_idx="time_idx",
        target=TARGET,
        group_ids=["ticker"],
        min_encoder_length=ENCODER_LENGTH // 2,
        max_encoder_length=ENCODER_LENGTH,
        min_prediction_length=1,
        max_prediction_length=PREDICTION_LENGTH,
        static_categoricals=["ticker"],
        time_varying_known_reals=KNOWN_REALS + ["time_idx"],
        time_varying_unknown_reals=UNKNOWN_REALS,
        target_normalizer=None,  # the target column is already normalized
        add_relative_time_idx=True,
        add_target_scales=True,
        add_encoder_length=True,
        allow_missing_timesteps=True,
    )
    return TimeSeriesDataSet(subset, **dataset_kwargs)
def build_tft(training_dataset: TimeSeriesDataSet) -> TemporalFusionTransformer:
    """Create a TFT sized from the module-level hyper-parameters.

    Input/output dimensions are derived from *training_dataset*; the model is
    trained with a quantile loss over the 10th, 50th and 90th percentiles.
    """
    quantile_loss = QuantileLoss(quantiles=[0.1, 0.5, 0.9])
    model_kwargs = {
        "learning_rate": LR,
        "hidden_size": HIDDEN_SIZE,
        "attention_head_size": ATTENTION_HEAD_SIZE,
        "dropout": DROPOUT,
        "hidden_continuous_size": HIDDEN_CONTINUOUS_SIZE,
        "loss": quantile_loss,
        "log_interval": 10,
        "reduce_on_plateau_patience": 4,
        "optimizer": "adam",
    }
    return TemporalFusionTransformer.from_dataset(training_dataset, **model_kwargs)
# ── DDG-DA drift predictor training (reused from train.py logic) ──────────────
def train_ddg_da_predictor(df: pd.DataFrame) -> None:
    """Train the lightweight MLP drift predictor for DDG-DA.

    Each ticker's feature history is cut into consecutive, non-overlapping
    windows of SNAPSHOT_WINDOW rows. Every window is summarised by four
    moments per feature (mean, std, skewness, excess kurtosis). The MLP learns
    to predict the next snapshot from the previous K_HISTORY snapshots.

    Args:
        df: Long-format frame containing a ``ticker`` column and every column
            named in UNKNOWN_REALS and KNOWN_REALS.

    Side effects:
        Writes the state dict with the lowest validation loss to DDG_DA_PATH.
        Returns without saving when there is not enough data.
    """
    import torch.nn as nn
    from torch.utils.data import TensorDataset, DataLoader

    SNAPSHOT_WINDOW = 20
    K_HISTORY = 8
    N_FEATURES = len(UNKNOWN_REALS) + len(KNOWN_REALS)  # 11
    SNAPSHOT_DIM = N_FEATURES * 4  # 44
    N_EPOCHS = 30

    def compute_snapshot(arr: np.ndarray) -> np.ndarray:
        """Return [mean, std, skew, excess kurtosis] for each column of *arr*."""
        stats = []
        for f in range(arr.shape[1]):
            col = arr[:, f].astype(np.float64)
            mu = col.mean()
            sigma = col.std() or 1e-8  # constant columns: avoid dividing by zero
            centered = col - mu
            skew = float((centered ** 3).mean() / sigma ** 3)
            kurt = float((centered ** 4).mean() / sigma ** 4) - 3.0
            stats.extend([float(mu), float(sigma), skew, kurt])
        return np.array(stats, dtype=np.float32)

    feature_cols = UNKNOWN_REALS + KNOWN_REALS
    all_X, all_y = [], []
    for _ticker, grp in df.groupby("ticker"):
        feat = grp[feature_cols].values.astype(np.float32)
        k = len(feat) // SNAPSHOT_WINDOW
        # Need K_HISTORY input snapshots plus at least one target snapshot.
        if k < K_HISTORY + 1:
            continue
        snaps = np.stack([
            compute_snapshot(feat[i * SNAPSHOT_WINDOW:(i + 1) * SNAPSHOT_WINDOW])
            for i in range(k)
        ])
        for i in range(len(snaps) - K_HISTORY):
            all_X.append(snaps[i:i + K_HISTORY].flatten())
            all_y.append(snaps[i + K_HISTORY])

    if not all_X:
        print("Not enough data for DDG-DA training, skipping.")
        return

    X = np.array(all_X, dtype=np.float32)
    y = np.array(all_y, dtype=np.float32)
    idx = np.random.permutation(len(X))
    X, y = X[idx], y[idx]
    split = int(len(X) * 0.9)
    n_train, n_val = split, len(X) - split
    if n_train == 0:
        # A single pair leaves the training split empty; a shuffling
        # DataLoader cannot be built over an empty dataset.
        print("Not enough data for DDG-DA training, skipping.")
        return

    train_ds = TensorDataset(torch.tensor(X[:split]), torch.tensor(y[:split]))
    val_ds = TensorDataset(torch.tensor(X[split:]), torch.tensor(y[split:]))
    train_dl = DataLoader(train_ds, batch_size=256, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=256)

    mlp = nn.Sequential(
        nn.Linear(K_HISTORY * SNAPSHOT_DIM, 128),
        nn.ELU(),
        nn.Dropout(0.1),
        nn.Linear(128, SNAPSHOT_DIM),
    )
    opt = torch.optim.Adam(mlp.parameters(), lr=1e-3)
    crit = nn.MSELoss()
    best_val = float("inf")

    print(f"\n--- Training DDG-DA MLP ({len(X):,} pairs) ---")
    for epoch in range(1, N_EPOCHS + 1):
        # Training pass: the weights must actually be updated here. Evaluating
        # the loss alone would leave the MLP at its random initialisation.
        mlp.train()
        train_total = 0.0
        for xb, yb in train_dl:
            opt.zero_grad()
            loss = crit(mlp(xb), yb)
            loss.backward()
            opt.step()
            train_total += loss.item() * len(xb)
        tl = train_total / n_train

        # Validation pass: dropout disabled, no gradients.
        mlp.eval()
        with torch.no_grad():
            vl = sum(crit(mlp(xb), yb).item() * len(xb) for xb, yb in val_dl) / n_val
        print(f" Epoch {epoch:2d}/{N_EPOCHS} train={tl:.6f} val={vl:.6f}")
        if vl < best_val:
            best_val = vl
            torch.save(mlp.state_dict(), DDG_DA_PATH)

    if best_val < float("inf"):
        print(f"DDG-DA saved → {DDG_DA_PATH}")
    else:
        # Every validation loss was NaN/inf, so no checkpoint was written.
        print("DDG-DA training produced no finite validation loss; nothing saved.")
# ── Upload to HF Hub ──────────────────────────────────────────────────────────
def upload_to_hub(checkpoint_path: str) -> None:
    """Upload the trained artifacts to a private Hugging Face model repo.

    The target is configured through the ``HF_TOKEN`` and ``HF_MODEL_REPO``
    environment variables; if either is unset the upload is skipped. The repo
    is created (private) when it does not exist yet.

    Args:
        checkpoint_path: Local TFT checkpoint, uploaded as ``tft_stock.ckpt``.
            The dataset parameters and the DDG-DA weights are uploaded
            alongside it when present on disk.
    """
    hf_token = os.environ.get("HF_TOKEN", "")
    hf_repo = os.environ.get("HF_MODEL_REPO", "")
    if not hf_token or not hf_repo:
        print("HF_TOKEN / HF_MODEL_REPO not set — skipping upload.")
        return

    # Imported lazily so the dependency is only needed when uploading.
    from huggingface_hub import HfApi

    api = HfApi(token=hf_token)
    api.create_repo(repo_id=hf_repo, repo_type="model", exist_ok=True, private=True)

    artifacts = [
        (checkpoint_path, "tft_stock.ckpt"),
        (DATASET_PARAMS_PATH, "dataset_params.pt"),
        (DDG_DA_PATH, "ddg_da.pt"),
    ]
    for local, remote in artifacts:
        if not os.path.exists(local):
            # Say so explicitly: a silently missing artifact would otherwise
            # only surface later, wherever the file is consumed.
            print(f"Skipping {remote}: {local} not found.")
            continue
        api.upload_file(
            path_or_fileobj=local,
            path_in_repo=remote,
            repo_id=hf_repo,
            repo_type="model",
            commit_message=f"Colab retrain: {remote}",
        )
        print(f"Uploaded {remote} → {hf_repo}")
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
    """Run the full pipeline: fetch data, train the TFT, train DDG-DA, upload.

    Artifacts written under MODEL_DIR: the raw training table (parquet), the
    dataset parameters, Lightning checkpoints, the best TFT checkpoint copied
    to FINAL_MODEL_PATH, and the DDG-DA weights.
    """
    import shutil

    from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

    print("=" * 60)
    print("StockPro TFT Training (pytorch-forecasting)")
    print("=" * 60)

    # 1. Collect data
    df = collect_data(IDX_TICKERS, years=5)
    data_path = os.path.join(MODEL_DIR, "training_data.parquet")
    df.to_parquet(data_path, index=False)
    print(f"Data saved → {data_path}")

    # 2. Build datasets
    training_ds = build_dataset(df, predict=False)
    # Save dataset parameters (fitted categorical encoders, scalers, config)
    # These are required by the HF Spaces inference server at runtime.
    torch.save(training_ds.get_parameters(), DATASET_PARAMS_PATH)
    print(f"Dataset params saved → {DATASET_PARAMS_PATH}")
    # Validation reuses the training dataset's parameters on the full frame;
    # predict=True selects the final window of each ticker.
    val_ds = TimeSeriesDataSet.from_dataset(training_ds, df, predict=True, stop_randomization=True)
    train_dl = training_ds.to_dataloader(train=True, batch_size=BATCH_SIZE, num_workers=2)
    val_dl = val_ds.to_dataloader(train=False, batch_size=BATCH_SIZE * 2, num_workers=2)

    # 3. Build TFT model
    tft = build_tft(training_ds)
    print(f"\nTFT parameters: {sum(p.numel() for p in tft.parameters()):,}")

    # 4. Train with Lightning
    # Keep a named handle on the checkpoint callback instead of looking it up
    # by its position in the callbacks list.
    early_stopping = EarlyStopping(monitor="val_loss", patience=5, mode="min")
    checkpoint_cb = ModelCheckpoint(
        dirpath=CHECKPOINT_DIR,
        filename="tft-{epoch:02d}-{val_loss:.4f}",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
    )
    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        accelerator="auto",  # GPU if available
        gradient_clip_val=GRADIENT_CLIP,
        callbacks=[early_stopping, checkpoint_cb],
        enable_progress_bar=True,
        log_every_n_steps=10,
    )
    print("\n--- Training TFT ---")
    trainer.fit(tft, train_dataloaders=train_dl, val_dataloaders=val_dl)

    # 5. Copy best checkpoint
    best_ckpt = checkpoint_cb.best_model_path
    if best_ckpt and os.path.exists(best_ckpt):
        shutil.copy2(best_ckpt, FINAL_MODEL_PATH)
        print(f"\nBest model → {FINAL_MODEL_PATH}")
    else:
        # Without this notice the run would finish with "Done!" while
        # FINAL_MODEL_PATH is missing or left over from an earlier run.
        print("\nWARNING: no best checkpoint was produced; the final model file was not updated.")

    # 6. Train DDG-DA
    train_ddg_da_predictor(df)

    # 7. Upload to HF Hub
    upload_to_hub(FINAL_MODEL_PATH)
    print("\nDone!")


if __name__ == "__main__":
    main()