# scripts/train.py — StockPro ML backend with pytorch-forecasting TFT (commit 9334ec6)
"""
Weekly retraining script for the TFT stock predictor.
Run: python -m scripts.train
Trains on all IDX tickers fetched from IndoPremier.
Uses multi-horizon quantile loss (pinball loss) for 3 quantiles.
Saves best model to models/tft_stock.pt and uploads to HF Hub.
"""
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from app.services.data_fetcher import IDX_TICKERS, fetch_ohlcv
from app.services.feature_engineer import (
build_features,
make_sequences,
SEQUENCE_LEN,
FORECAST_HORIZON,
N_FEATURES,
)
from app.services.concept_drift import (
extract_snapshots_from_series,
SNAPSHOT_DIM,
K_HISTORY,
)
from app.models.tft_predictor import StockTFT
from app.models.ddg_da import DriftPredictorMLP
# Checkpoint locations: repo-root /models, created eagerly at import time.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "..", "models")
MODEL_PATH = os.path.join(MODEL_DIR, "tft_stock.pt")
os.makedirs(MODEL_DIR, exist_ok=True)
# TFT training hyperparameters.
EPOCHS = 50
BATCH_SIZE = 32
LR = 5e-4
# Train on GPU when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def quantile_loss(
    preds: torch.Tensor,
    targets: torch.Tensor,
    quantiles: tuple[float, ...] = (0.1, 0.5, 0.9),
) -> torch.Tensor:
    """Average pinball (quantile) loss across the given quantile levels.

    Args:
        preds: (batch, horizon, n_quantiles) predicted quantile values;
            preds[..., i] corresponds to quantiles[i].
        targets: (batch, horizon) ground-truth values.
        quantiles: quantile levels matching preds' last axis. A tuple default
            replaces the original mutable-list default (shared-default pitfall).

    Returns:
        Scalar tensor: mean pinball loss over all quantiles.
    """
    total = torch.tensor(0.0, device=preds.device)
    for i, q in enumerate(quantiles):
        errors = targets - preds[:, :, i]
        # Pinball: penalize under-prediction by q, over-prediction by (1 - q).
        total += torch.where(errors >= 0, q * errors, (q - 1) * errors).mean()
    return total / len(quantiles)
def make_multihorizon_targets(close_norm: np.ndarray, horizon: int) -> np.ndarray:
    """Build multi-horizon targets: for step t, the next `horizon` values.

    Args:
        close_norm: 1-D array of (normalized) close values, length T.
        horizon: number of future steps per target row.

    Returns:
        (T - horizon, horizon) float32 array where row i equals
        close_norm[i + 1 : i + 1 + horizon]. Always rank-2: a too-short
        series yields shape (0, horizon) — the original returned shape (0,),
        which broke downstream shape assumptions.
    """
    n = len(close_norm) - horizon
    if n <= 0:
        return np.empty((0, horizon), dtype=np.float32)
    # Vectorized replacement for the per-row append loop: windows over the
    # series shifted by one give exactly close_norm[i+1 : i+1+horizon] per row.
    windows = np.lib.stride_tricks.sliding_window_view(close_norm[1:], horizon)
    return windows[:n].astype(np.float32)
def collect_training_data() -> tuple[np.ndarray, np.ndarray]:
    """Build the TFT training set from every IDX ticker.

    Returns:
        (X, y): stacked input windows (N, SEQUENCE_LEN, N_FEATURES) and
        multi-horizon targets (N, FORECAST_HORIZON).

    Raises:
        RuntimeError: if no ticker produced a single usable sequence.
    """
    xs: list[np.ndarray] = []
    ys: list[np.ndarray] = []
    print(f"Fetching data for {len(IDX_TICKERS)} IDX tickers from IndoPremier...")
    for i, ticker in enumerate(IDX_TICKERS):
        data = fetch_ohlcv(ticker, period="5y")
        if data is None:
            print(f" [{i+1}/{len(IDX_TICKERS)}] {ticker}: no data, skipping")
            continue
        feats = build_features(data["closes"], data["volumes"], data["timestamps"])
        # Normalized close is feature column 0; its future values are the targets.
        targets = make_multihorizon_targets(feats[:, 0], FORECAST_HORIZON)
        # Trailing rows lack a complete future window, so truncate features to match.
        X, y = make_sequences(feats[: len(targets)], targets, SEQUENCE_LEN)
        if len(X) == 0:
            print(f" [{i+1}/{len(IDX_TICKERS)}] {ticker}: too short, skipping")
            continue
        xs.append(X)
        ys.append(y)
        print(f" [{i+1}/{len(IDX_TICKERS)}] {ticker}: {len(X)} sequences")
    if not xs:
        raise RuntimeError("No training data collected")
    return np.concatenate(xs), np.concatenate(ys)
def train():
    """Train the TFT quantile forecaster on all IDX tickers.

    Shuffles the pooled sequences, holds out 10% for validation, trains with
    pinball loss + cosine LR decay, checkpoints the best-validation weights
    to MODEL_PATH, then attempts an HF Hub upload.
    """
    print(f"Training TFT on {DEVICE}")
    X, y = collect_training_data()
    print(f"Total sequences: {len(X)}, features: {N_FEATURES}, horizon: {FORECAST_HORIZON}")
    # Shuffle once up front so the 90/10 split is random across tickers.
    order = np.random.permutation(len(X))
    X, y = X[order], y[order]
    cut = int(len(X) * 0.9)
    X_train, y_train = X[:cut], y[:cut]
    X_val, y_val = X[cut:], y[cut:]
    train_dl = DataLoader(
        TensorDataset(torch.tensor(X_train), torch.tensor(y_train)),
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    val_dl = DataLoader(
        TensorDataset(torch.tensor(X_val), torch.tensor(y_val)),
        batch_size=BATCH_SIZE,
    )
    model = StockTFT(input_size=N_FEATURES, forecast_horizon=FORECAST_HORIZON).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
    best_val_loss = float("inf")
    for epoch in range(1, EPOCHS + 1):
        # --- training pass ---
        model.train()
        running = 0.0
        for xb, yb in train_dl:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad()
            preds, _ = model(xb)  # (B, HORIZON, 3 quantiles)
            loss = quantile_loss(preds, yb)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            running += loss.item() * len(xb)
        train_loss = running / len(X_train)
        # --- validation pass ---
        model.eval()
        running = 0.0
        with torch.no_grad():
            for xb, yb in val_dl:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                preds, _ = model(xb)
                running += quantile_loss(preds, yb).item() * len(xb)
        val_loss = running / len(X_val)
        scheduler.step()
        print(f"Epoch {epoch:2d}/{EPOCHS} train={train_loss:.4f} val={val_loss:.4f}")
        # Keep only the best-validation checkpoint.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), MODEL_PATH)
            print(f" βœ“ Saved best model (val={val_loss:.4f})")
    print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")
    print(f"Model saved to {MODEL_PATH}")
    _upload_to_hub(MODEL_PATH)
def _upload_to_hub(model_path: str) -> None:
    """Push the trained TFT checkpoint to the configured HF model repo.

    Best-effort: missing configuration skips the upload; any hub error is
    caught and printed rather than failing the training run.
    """
    from app.config import MODEL_REPO, HF_TOKEN

    if not (MODEL_REPO and HF_TOKEN):
        print("HF_MODEL_REPO / HF_TOKEN not set β€” skipping HF Hub upload.")
        return
    try:
        from huggingface_hub import HfApi

        hub = HfApi(token=HF_TOKEN)
        # Idempotent: create the private repo if it doesn't exist yet.
        hub.create_repo(repo_id=MODEL_REPO, repo_type="model", exist_ok=True, private=True)
        hub.upload_file(
            path_or_fileobj=model_path,
            path_in_repo="tft_stock.pt",
            repo_id=MODEL_REPO,
            repo_type="model",
            commit_message="Weekly TFT retrain via GitHub Actions",
        )
        print(f"Model uploaded to HF Hub: {MODEL_REPO}/tft_stock.pt")
    except Exception as e:
        print(f"HF Hub upload failed: {e}")
# ── DDG-DA training ───────────────────────────────────────────────────────────
# Checkpoint path and hyperparameters for the drift-predictor MLP.
DDG_DA_PATH = os.path.join(MODEL_DIR, "ddg_da.pt")
DDG_DA_EPOCHS = 30
DDG_DA_BATCH = 256
DDG_DA_LR = 1e-3
def collect_drift_snapshots() -> list[np.ndarray]:
    """Fetch per-ticker history and distill it into drift snapshots.

    Returns:
        One (K_ticker, SNAPSHOT_DIM) array per ticker with at least
        K_HISTORY + 1 snapshots. Ticker boundaries stay separate so no
        training pair later straddles two instruments.
    """
    valid: list[np.ndarray] = []
    print(f"Collecting drift snapshots for {len(IDX_TICKERS)} tickers...")
    for i, ticker in enumerate(IDX_TICKERS):
        data = fetch_ohlcv(ticker, period="5y")
        if data is None:
            continue
        feats = build_features(data["closes"], data["volumes"], data["timestamps"])
        snaps = extract_snapshots_from_series(feats)
        # One (X, y) pair needs K_HISTORY inputs plus one target snapshot.
        if len(snaps) > K_HISTORY:
            valid.append(snaps)
        if (i + 1) % 100 == 0:
            print(f" Snapshots: {i+1}/{len(IDX_TICKERS)} tickers processed, {len(valid)} valid")
    print(f" Collected {len(valid)} tickers with sufficient snapshot history")
    return valid
def build_ddg_da_dataset(
    per_ticker_snapshots: list[np.ndarray],
) -> tuple[np.ndarray, np.ndarray]:
    """Slide a K_HISTORY window over each ticker's snapshots to form (X, y).

    The window never crosses a ticker boundary, so no pair mixes two
    instruments.

    Returns:
        X: (N, K_HISTORY * SNAPSHOT_DIM) flattened history windows.
        y: (N, SNAPSHOT_DIM) the snapshot immediately following each window.

    Raises:
        RuntimeError: if no ticker produced a single pair.
    """
    xs, ys = [], []
    for snaps in per_ticker_snapshots:
        for start in range(len(snaps) - K_HISTORY):
            stop = start + K_HISTORY
            xs.append(snaps[start:stop].reshape(-1))  # flattened K-step history
            ys.append(snaps[stop])                    # next snapshot is the target
    if not xs:
        raise RuntimeError("No DDG-DA training pairs collected")
    return np.array(xs, dtype=np.float32), np.array(ys, dtype=np.float32)
def train_ddg_da(per_ticker_snapshots: list[np.ndarray]) -> None:
    """Fit the DDG-DA drift-predictor MLP on snapshot-transition pairs.

    Uses MSE regression of the next snapshot from the flattened K-step
    history, a 90/10 random split, cosine LR decay, and best-validation
    checkpointing to DDG_DA_PATH, followed by an HF Hub upload attempt.
    """
    print(f"\n--- Training DDG-DA drift predictor (device={DEVICE}) ---")
    X, y = build_ddg_da_dataset(per_ticker_snapshots)
    print(f" DDG-DA training pairs: {len(X)}, snapshot_dim={SNAPSHOT_DIM}, k_history={K_HISTORY}")
    # Shuffle once, then hold out the final 10% as validation.
    order = np.random.permutation(len(X))
    X, y = X[order], y[order]
    cut = int(len(X) * 0.9)
    X_train, y_train = X[:cut], y[:cut]
    X_val, y_val = X[cut:], y[cut:]
    train_dl = DataLoader(
        TensorDataset(torch.tensor(X_train), torch.tensor(y_train)),
        batch_size=DDG_DA_BATCH,
        shuffle=True,
    )
    val_dl = DataLoader(
        TensorDataset(torch.tensor(X_val), torch.tensor(y_val)),
        batch_size=DDG_DA_BATCH,
    )
    mlp = DriftPredictorMLP(k_history=K_HISTORY, snapshot_dim=SNAPSHOT_DIM).to(DEVICE)
    opt = torch.optim.Adam(mlp.parameters(), lr=DDG_DA_LR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=DDG_DA_EPOCHS)
    criterion = nn.MSELoss()
    best_val = float("inf")
    for epoch in range(1, DDG_DA_EPOCHS + 1):
        # --- training pass ---
        mlp.train()
        running = 0.0
        for xb, yb in train_dl:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad()
            loss = criterion(mlp(xb), yb)
            loss.backward()
            nn.utils.clip_grad_norm_(mlp.parameters(), 1.0)
            opt.step()
            running += loss.item() * len(xb)
        train_loss = running / len(X_train)
        # --- validation pass ---
        mlp.eval()
        running = 0.0
        with torch.no_grad():
            for xb, yb in val_dl:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                running += criterion(mlp(xb), yb).item() * len(xb)
        val_loss = running / len(X_val)
        scheduler.step()
        print(f" DDG-DA Epoch {epoch:2d}/{DDG_DA_EPOCHS} train={train_loss:.6f} val={val_loss:.6f}")
        # Keep only the best-validation checkpoint.
        if val_loss < best_val:
            best_val = val_loss
            torch.save(mlp.state_dict(), DDG_DA_PATH)
            print(f" βœ“ Saved best DDG-DA model (val={val_loss:.6f})")
    print(f"DDG-DA training complete. Saved to {DDG_DA_PATH}")
    _upload_ddg_da_to_hub(DDG_DA_PATH)
def _upload_ddg_da_to_hub(model_path: str) -> None:
    """Push the DDG-DA checkpoint to the configured HF model repo.

    Best-effort: missing configuration skips the upload; hub errors are
    printed, never raised. Assumes the repo already exists (the TFT upload
    creates it earlier in the same run).
    """
    from app.config import MODEL_REPO, HF_TOKEN

    if not (MODEL_REPO and HF_TOKEN):
        print("HF_MODEL_REPO / HF_TOKEN not set β€” skipping DDG-DA HF Hub upload.")
        return
    try:
        from huggingface_hub import HfApi

        hub = HfApi(token=HF_TOKEN)
        hub.upload_file(
            path_or_fileobj=model_path,
            path_in_repo="ddg_da.pt",
            repo_id=MODEL_REPO,
            repo_type="model",
            commit_message="Weekly DDG-DA retrain via GitHub Actions",
        )
        print(f"DDG-DA model uploaded to HF Hub: {MODEL_REPO}/ddg_da.pt")
    except Exception as e:
        print(f"DDG-DA HF Hub upload failed: {e}")
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Weekly retrain: TFT forecaster first, then the DDG-DA drift predictor
    # (both fetch their own data; snapshots reuse the same ticker universe).
    train()
    # Train DDG-DA drift predictor after TFT
    snapshots = collect_drift_snapshots()
    train_ddg_da(snapshots)