Spaces:
Running
Running
| """ | |
| Weekly retraining script for the TFT stock predictor. | |
| Run: python -m scripts.train | |
| Trains on all IDX tickers fetched from IndoPremier. | |
| Uses multi-horizon quantile loss (pinball loss) for 3 quantiles. | |
| Saves best model to models/tft_stock.pt and uploads to HF Hub. | |
| """ | |
| import os | |
| import sys | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader, TensorDataset | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| from app.services.data_fetcher import IDX_TICKERS, fetch_ohlcv | |
| from app.services.feature_engineer import ( | |
| build_features, | |
| make_sequences, | |
| SEQUENCE_LEN, | |
| FORECAST_HORIZON, | |
| N_FEATURES, | |
| ) | |
| from app.services.concept_drift import ( | |
| extract_snapshots_from_series, | |
| SNAPSHOT_DIM, | |
| K_HISTORY, | |
| ) | |
| from app.models.tft_predictor import StockTFT | |
| from app.models.ddg_da import DriftPredictorMLP | |
# Local directory where checkpoints are written before the optional Hub upload.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "..", "models")
MODEL_PATH = os.path.join(MODEL_DIR, "tft_stock.pt")
os.makedirs(MODEL_DIR, exist_ok=True)
# TFT training hyperparameters.
EPOCHS = 50
BATCH_SIZE = 32
LR = 5e-4
# Train on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def quantile_loss(
    preds: torch.Tensor,
    targets: torch.Tensor,
    quantiles: tuple[float, ...] = (0.1, 0.5, 0.9),
) -> torch.Tensor:
    """
    Pinball (quantile) loss averaged over the given quantile levels.

    Args:
        preds: (batch, horizon, n_quantiles) predicted quantile values;
            the last axis must line up with ``quantiles``.
        targets: (batch, horizon) realized values.
        quantiles: quantile levels. Default is a tuple, not a list — the
            original used a mutable default argument, which is shared
            across calls and a classic Python pitfall.

    Returns:
        Scalar tensor: mean pinball loss across all quantiles.
    """
    total = torch.tensor(0.0, device=preds.device)
    for i, q in enumerate(quantiles):
        errors = targets - preds[:, :, i]
        # Pinball loss: under-prediction costs q*error, over-prediction (1-q)*|error|.
        total += torch.where(errors >= 0, q * errors, (q - 1) * errors).mean()
    return total / len(quantiles)
def make_multihorizon_targets(close_norm: np.ndarray, horizon: int) -> np.ndarray:
    """
    Build multi-horizon targets: for each time step t, the target row is
    close_norm[t + 1 : t + 1 + horizon].

    Args:
        close_norm: 1-D array of normalized close prices, length T.
        horizon: number of future steps per target row.

    Returns:
        (max(T - horizon, 0), horizon) float32 array. Always 2-D: the
        original returned a 1-D ``(0,)`` array when the series was shorter
        than the horizon, which breaks downstream shape-based alignment.
    """
    n_rows = max(len(close_norm) - horizon, 0)
    if n_rows == 0:
        return np.empty((0, horizon), dtype=np.float32)
    rows = [close_norm[t + 1 : t + 1 + horizon] for t in range(n_rows)]
    return np.array(rows, dtype=np.float32)
def collect_training_data() -> tuple[np.ndarray, np.ndarray]:
    """
    Fetch OHLCV history for every IDX ticker and assemble the pooled
    (X, y) training set for the TFT.

    Returns:
        X: stacked input sequences, one row per sliding window.
        y: stacked multi-horizon targets aligned with X.

    Raises:
        RuntimeError: when no ticker produced any usable sequences.
    """
    xs: list[np.ndarray] = []
    ys: list[np.ndarray] = []
    n_tickers = len(IDX_TICKERS)
    print(f"Fetching data for {n_tickers} IDX tickers from IndoPremier...")
    for pos, ticker in enumerate(IDX_TICKERS, start=1):
        data = fetch_ohlcv(ticker, period="5y")
        if data is None:
            print(f" [{pos}/{n_tickers}] {ticker}: no data, skipping")
            continue
        features = build_features(data["closes"], data["volumes"], data["timestamps"])
        # Feature column 0 is the normalized close price — it doubles as the target.
        close_norm = features[:, 0]
        targets = make_multihorizon_targets(close_norm, FORECAST_HORIZON)
        # Trailing rows lack a full-horizon target, so truncate features to match.
        X, y = make_sequences(features[: len(targets)], targets, SEQUENCE_LEN)
        if len(X) == 0:
            print(f" [{pos}/{n_tickers}] {ticker}: too short, skipping")
            continue
        xs.append(X)
        ys.append(y)
        print(f" [{pos}/{n_tickers}] {ticker}: {len(X)} sequences")
    if not xs:
        raise RuntimeError("No training data collected")
    return np.concatenate(xs), np.concatenate(ys)
def train() -> None:
    """
    Train the TFT stock predictor end-to-end.

    Collects sequences for all IDX tickers, trains with multi-horizon
    quantile (pinball) loss, checkpoints the best model by validation
    loss to MODEL_PATH, then attempts an HF Hub upload.
    """
    print(f"Training TFT on {DEVICE}")
    X, y = collect_training_data()
    print(f"Total sequences: {len(X)}, features: {N_FEATURES}, horizon: {FORECAST_HORIZON}")
    # Shuffle once, then hold out the last 10% as the validation split.
    # NOTE(review): a random split leaks temporally-adjacent sequences between
    # train and val; a time-based split would be stricter — confirm intent.
    idx = np.random.permutation(len(X))
    X, y = X[idx], y[idx]
    split = int(len(X) * 0.9)
    X_train, X_val = X[:split], X[split:]
    y_train, y_val = y[:split], y[split:]
    train_ds = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
    val_ds = TensorDataset(torch.tensor(X_val), torch.tensor(y_val))
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE)
    model = StockTFT(input_size=N_FEATURES, forecast_horizon=FORECAST_HORIZON).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    # Cosine decay of the learning rate over the full training run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
    best_val_loss = float("inf")
    for epoch in range(1, EPOCHS + 1):
        model.train()
        train_loss = 0.0
        for xb, yb in train_dl:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad()
            preds, _ = model(xb)  # (B, HORIZON, 3)
            loss = quantile_loss(preds, yb)
            loss.backward()
            # Clip gradient norm to stabilize training.
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            # Weight by batch size so the epoch average is exact even when the
            # last batch is smaller.
            train_loss += loss.item() * len(xb)
        train_loss /= len(X_train)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_dl:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                preds, _ = model(xb)
                val_loss += quantile_loss(preds, yb).item() * len(xb)
        val_loss /= len(X_val)
        scheduler.step()
        print(f"Epoch {epoch:2d}/{EPOCHS} train={train_loss:.4f} val={val_loss:.4f}")
        # Checkpoint only when validation improves (best-model selection).
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), MODEL_PATH)
            print(f" β Saved best model (val={val_loss:.4f})")
    print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")
    print(f"Model saved to {MODEL_PATH}")
    _upload_to_hub(MODEL_PATH)
def _upload_to_hub(model_path: str) -> None:
    """
    Best-effort upload of the trained TFT weights to the HF Hub.

    Skips (with a log line) when credentials are missing; any upload
    error is caught and printed rather than raised, so training never
    fails because of the upload step.
    """
    from app.config import MODEL_REPO, HF_TOKEN

    if not (MODEL_REPO and HF_TOKEN):
        print("HF_MODEL_REPO / HF_TOKEN not set β skipping HF Hub upload.")
        return
    try:
        from huggingface_hub import HfApi

        hub = HfApi(token=HF_TOKEN)
        # Idempotent: creates the private model repo on first run only.
        hub.create_repo(repo_id=MODEL_REPO, repo_type="model", exist_ok=True, private=True)
        hub.upload_file(
            path_or_fileobj=model_path,
            repo_id=MODEL_REPO,
            repo_type="model",
            path_in_repo="tft_stock.pt",
            commit_message="Weekly TFT retrain via GitHub Actions",
        )
        print(f"Model uploaded to HF Hub: {MODEL_REPO}/tft_stock.pt")
    except Exception as e:
        print(f"HF Hub upload failed: {e}")
# -- DDG-DA training ----------------------------------------------------------
# Drift-predictor (DDG-DA) checkpoint path and training hyperparameters.
DDG_DA_PATH = os.path.join(MODEL_DIR, "ddg_da.pt")
DDG_DA_EPOCHS = 30
DDG_DA_BATCH = 256
DDG_DA_LR = 1e-3
def collect_drift_snapshots() -> list[np.ndarray]:
    """
    Gather concept-drift snapshot arrays, one array per ticker.

    Tickers are kept in separate arrays so that later sliding windows
    never cross a ticker boundary (which would create contaminated
    training pairs).

    Returns:
        List of (K_ticker, SNAPSHOT_DIM) arrays for tickers that have at
        least K_HISTORY + 1 snapshots.
    """
    valid: list[np.ndarray] = []
    total = len(IDX_TICKERS)
    print(f"Collecting drift snapshots for {total} tickers...")
    for done, ticker in enumerate(IDX_TICKERS, start=1):
        data = fetch_ohlcv(ticker, period="5y")
        if data is None:
            continue
        feats = build_features(data["closes"], data["volumes"], data["timestamps"])
        snaps = extract_snapshots_from_series(feats)
        # Need at least K_HISTORY + 1 snapshots to form a single (X, y) pair.
        if len(snaps) > K_HISTORY:
            valid.append(snaps)
        # Progress heartbeat every 100 tickers.
        if done % 100 == 0:
            print(f" Snapshots: {done}/{total} tickers processed, {len(valid)} valid")
    print(f" Collected {len(valid)} tickers with sufficient snapshot history")
    return valid
| def build_ddg_da_dataset( | |
| per_ticker_snapshots: list[np.ndarray], | |
| ) -> tuple[np.ndarray, np.ndarray]: | |
| """ | |
| Build (X, y) pairs for DDG-DA MLP training. | |
| Window slides WITHIN each ticker's snapshots to avoid cross-ticker contamination. | |
| X: (N, K_HISTORY * SNAPSHOT_DIM) | |
| y: (N, SNAPSHOT_DIM) | |
| """ | |
| all_X, all_y = [], [] | |
| for snaps in per_ticker_snapshots: | |
| k = len(snaps) | |
| for i in range(k - K_HISTORY): | |
| x_window = snaps[i : i + K_HISTORY].flatten() # (K * 44,) | |
| y_target = snaps[i + K_HISTORY] # (44,) | |
| all_X.append(x_window) | |
| all_y.append(y_target) | |
| if not all_X: | |
| raise RuntimeError("No DDG-DA training pairs collected") | |
| return np.array(all_X, dtype=np.float32), np.array(all_y, dtype=np.float32) | |
def train_ddg_da(per_ticker_snapshots: list[np.ndarray]) -> None:
    """
    Train the DDG-DA drift-predictor MLP.

    Builds sliding-window (history -> next-snapshot) pairs from the given
    per-ticker snapshot arrays, trains with MSE loss, checkpoints the
    best model by validation loss to DDG_DA_PATH, then attempts an HF
    Hub upload.

    Args:
        per_ticker_snapshots: one (K_ticker, SNAPSHOT_DIM) array per ticker.
    """
    print(f"\n--- Training DDG-DA drift predictor (device={DEVICE}) ---")
    X, y = build_ddg_da_dataset(per_ticker_snapshots)
    print(f" DDG-DA training pairs: {len(X)}, snapshot_dim={SNAPSHOT_DIM}, k_history={K_HISTORY}")
    # Shuffle once, then hold out the last 10% as the validation split.
    idx = np.random.permutation(len(X))
    X, y = X[idx], y[idx]
    split = int(len(X) * 0.9)
    X_train, X_val = X[:split], X[split:]
    y_train, y_val = y[:split], y[split:]
    train_ds = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
    val_ds = TensorDataset(torch.tensor(X_val), torch.tensor(y_val))
    train_dl = DataLoader(train_ds, batch_size=DDG_DA_BATCH, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=DDG_DA_BATCH)
    mlp = DriftPredictorMLP(k_history=K_HISTORY, snapshot_dim=SNAPSHOT_DIM).to(DEVICE)
    opt = torch.optim.Adam(mlp.parameters(), lr=DDG_DA_LR)
    # Cosine decay of the learning rate over the full run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=DDG_DA_EPOCHS)
    criterion = nn.MSELoss()
    best_val = float("inf")
    for epoch in range(1, DDG_DA_EPOCHS + 1):
        mlp.train()
        train_loss = 0.0
        for xb, yb in train_dl:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad()
            pred = mlp(xb)
            loss = criterion(pred, yb)
            loss.backward()
            # Clip gradient norm to stabilize training.
            nn.utils.clip_grad_norm_(mlp.parameters(), 1.0)
            opt.step()
            # Weight by batch size so the epoch average is exact for a
            # smaller final batch.
            train_loss += loss.item() * len(xb)
        train_loss /= len(X_train)
        mlp.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_dl:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                val_loss += criterion(mlp(xb), yb).item() * len(xb)
        val_loss /= len(X_val)
        scheduler.step()
        print(f" DDG-DA Epoch {epoch:2d}/{DDG_DA_EPOCHS} train={train_loss:.6f} val={val_loss:.6f}")
        # Checkpoint only when validation improves (best-model selection).
        if val_loss < best_val:
            best_val = val_loss
            torch.save(mlp.state_dict(), DDG_DA_PATH)
            print(f" β Saved best DDG-DA model (val={val_loss:.6f})")
    print(f"DDG-DA training complete. Saved to {DDG_DA_PATH}")
    _upload_ddg_da_to_hub(DDG_DA_PATH)
def _upload_ddg_da_to_hub(model_path: str) -> None:
    """
    Best-effort upload of the trained DDG-DA weights to the HF Hub.

    Mirrors _upload_to_hub: skips when credentials are missing, and
    upload errors are caught and printed rather than raised.

    Args:
        model_path: local path to the saved DDG-DA state dict.
    """
    from app.config import MODEL_REPO, HF_TOKEN
    if not MODEL_REPO or not HF_TOKEN:
        print("HF_MODEL_REPO / HF_TOKEN not set β skipping DDG-DA HF Hub upload.")
        return
    try:
        from huggingface_hub import HfApi
        api = HfApi(token=HF_TOKEN)
        # Fix: ensure the repo exists before uploading. The TFT upload path
        # already does this; without it, the first-ever upload (or a run where
        # train() was skipped) fails with a missing-repo error.
        api.create_repo(repo_id=MODEL_REPO, repo_type="model", exist_ok=True, private=True)
        api.upload_file(
            path_or_fileobj=model_path,
            path_in_repo="ddg_da.pt",
            repo_id=MODEL_REPO,
            repo_type="model",
            commit_message="Weekly DDG-DA retrain via GitHub Actions",
        )
        print(f"DDG-DA model uploaded to HF Hub: {MODEL_REPO}/ddg_da.pt")
    except Exception as e:
        print(f"DDG-DA HF Hub upload failed: {e}")
# -- Entry point ---------------------------------------------------------------
if __name__ == "__main__":
    # Train the TFT first; its checkpoint/upload happens inside train().
    train()
    # Then train the DDG-DA drift predictor on freshly collected snapshots.
    snapshots = collect_drift_snapshots()
    train_ddg_da(snapshots)