StockPro TFT – Indonesian Stock Price Forecasting

Temporal Fusion Transformer (TFT) for short-term quantile price forecasting on Indonesian Stock Exchange (IDX) equities, with a companion DDG-DA concept drift detector.

Trained on 956 IDX stocks (2021–2026) using pytorch-forecasting and PyTorch Lightning.


Model Description

This repository contains three model artifacts used by the StockPro ML backend:

| File | Size | Description |
|---|---|---|
| `tft_stock.ckpt` | 4.6 MB | PyTorch Lightning TFT checkpoint |
| `dataset_params.pt` | 63 KB | `TimeSeriesDataSet` parameters (fitted categorical encoders, configuration) |
| `ddg_da.pt` | 187 KB | DDG-DA drift predictor MLP weights |

What the model does

Given the last 60 trading days of OHLCV data for any IDX stock, the model predicts the next 1–30 trading days with three quantile outputs (10th, 50th, 90th percentile), giving calibrated prediction intervals alongside the point forecast.

The companion DDG-DA (Data Distribution Generation for Predictable Concept Drift) module detects when the current market regime has shifted from the training distribution, and adjusts forecast confidence accordingly.


Architecture

TFT (Temporal Fusion Transformer)

Built with pytorch-forecasting.TemporalFusionTransformer:

| Hyperparameter | Value |
|---|---|
| `hidden_size` | 64 |
| `attention_head_size` | 4 |
| `dropout` | 0.1 |
| `hidden_continuous_size` | 16 |
| Encoder length | 60 days |
| Max prediction length | 30 days |
| Loss | `QuantileLoss([0.1, 0.5, 0.9])` |
| Optimizer | Adam, lr=1e-3 |
| Gradient clip | 0.1 |
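
`QuantileLoss` is the standard pinball loss averaged over the three target quantiles. A minimal NumPy sketch of the per-sample computation (an illustration of the formula, not the library's internals):

```python
import numpy as np

def pinball_loss(y_true, y_pred, quantiles=(0.1, 0.5, 0.9)):
    """Mean pinball (quantile) loss.
    y_true: (N,) actuals; y_pred: (N, Q), one column per quantile."""
    y_true = np.asarray(y_true, dtype=float)[:, None]
    y_pred = np.asarray(y_pred, dtype=float)
    q = np.asarray(quantiles, dtype=float)[None, :]
    err = y_true - y_pred
    # Under-prediction is penalized by q, over-prediction by (1 - q)
    return float(np.mean(np.maximum(q * err, (q - 1) * err)))

# Exact median hit, symmetric 10/90 band around the actual
loss = pinball_loss([100.0], [[90.0, 100.0, 110.0]])
```

The asymmetric penalty is what pushes the 10th/90th-percentile outputs apart and makes the band a calibrated interval rather than a point estimate.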

DDG-DA Drift Predictor MLP

A lightweight 2-layer MLP (~51K parameters) that predicts the next distribution snapshot from 8 rolling windows of statistical moments:

Input: (8 × 44) = 352-dim  →  Linear(352, 128) → ELU → Dropout(0.1) → Linear(128, 44)

Each snapshot is a 44-dim vector of [mean, std, skewness, kurtosis] per feature, computed over a 20-day rolling window. Drift is flagged when the mean absolute z-score across all 44 dimensions exceeds 1.8σ.
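
The 1.8σ flagging rule can be illustrated with a toy reference history (synthetic numbers, not model output):

```python
import numpy as np

DRIFT_THRESHOLD = 1.8  # sigma, as described above

def drift_score(current, reference):
    """Mean |z-score| of a 44-dim snapshot vs. a stack of reference snapshots."""
    ref_mean = reference.mean(axis=0)
    ref_std = reference.std(axis=0) + 1e-8  # guard against zero variance
    return float(np.abs((current - ref_mean) / ref_std).mean())

# Toy history: each dimension has mean 0.5 and std 0.5
reference = np.vstack([np.zeros(44), np.ones(44)])

score_ok    = drift_score(np.full(44, 0.6), reference)  # |0.1| / 0.5 = 0.2
score_drift = drift_score(np.full(44, 2.0), reference)  # |1.5| / 0.5 = 3.0
```

A snapshot near the historical mean scores well under the threshold; one shifted by several reference standard deviations trips the flag.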


Input Features

The model uses 11 normalized input features split into two groups:

Unknown Reals (encoder only – past observations)

| Feature | Description | Normalization |
|---|---|---|
| `close_norm` | Closing price | Rolling 30-day z-score |
| `volume_norm` | Trading volume | Rolling 30-day z-score |
| `rsi` | RSI(14) | Normalized to [0, 1] |
| `macd_norm` | MACD histogram | Divided by rolling std |
| `bb_width` | Bollinger Band width | 2σ / SMA(20) |
| `atr_norm` | ATR(14) | Divided by close price |
| `obv_norm` | On-Balance Volume | Z-score normalized |
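
The rolling z-score used for `close_norm` and `volume_norm` can be sketched as follows (toy 3-day window on made-up prices; the real pipeline uses 30 days, as in the Usage section):

```python
import numpy as np
import pandas as pd

def rolling_zscore(values, window=30):
    """Z-score each point against its trailing `window` (population std)."""
    s = pd.Series(values, dtype=float)
    mean = s.rolling(window, min_periods=1).mean()
    std = s.rolling(window, min_periods=1).std(ddof=0).fillna(1).clip(lower=1e-6)
    return ((s - mean) / std).to_numpy()

z = rolling_zscore([100.0, 102.0, 101.0, 105.0], window=3)
```

The first point z-scores to 0 (a single-element window has zero deviation), and later points measure how far the latest price sits from its recent mean in units of recent volatility.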

Known Reals (encoder + decoder – calendar features)

| Feature | Description |
|---|---|
| `day_sin` | Sine encoding of day-of-week (period=5) |
| `day_cos` | Cosine encoding of day-of-week (period=5) |
| `month_sin` | Sine encoding of month (period=12) |
| `month_cos` | Cosine encoding of month (period=12) |
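
These cyclical encodings are fully determined by the calendar, which is why they can be supplied for future decoder steps. A self-contained sketch matching the table (Mon=0 … Fri=4):

```python
import numpy as np
from datetime import date

def calendar_features(d: date):
    """Return (day_sin, day_cos, month_sin, month_cos) for a trading date."""
    dow, month = d.weekday(), d.month
    return (np.sin(2 * np.pi * dow / 5), np.cos(2 * np.pi * dow / 5),
            np.sin(2 * np.pi * month / 12), np.cos(2 * np.pi * month / 12))

day_sin, day_cos, month_sin, month_cos = calendar_features(date(2025, 1, 6))  # a Monday in January
```

The sin/cos pair keeps adjacent days (and Friday→Monday across the 5-day period) close together on the unit circle, unlike a raw integer index.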

Training Details

  • Data: 956 IDX stocks, 5 years of daily OHLCV data from IndoPremier (2021–2026)
  • Platform: Google Colab (T4 GPU)
  • Framework: pytorch-forecasting + lightning
  • Epochs: up to 50 (EarlyStopping patience=5, monitor=val_loss)
  • Batch size: 64
  • Validation: Last 30 days of each ticker held out
  • Categorical: ticker ID encoded as a static categorical (UNK embedding for unseen tickers)

DDG-DA was trained jointly on the same ticker population, using non-overlapping 20-day windows to construct snapshot prediction pairs.
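
The snapshot prediction pairs described above can be constructed roughly like this (a sketch under the stated 8-window history and 44-dim snapshot setup; the actual training script is not published here):

```python
import numpy as np

K_HISTORY, SNAP_DIM = 8, 44  # 8 trailing snapshots, 44 stats each

def build_pairs(snapshots):
    """snapshots: (S, 44) array of stats from consecutive non-overlapping 20-day windows.
    Returns (X, y): X stacks each flattened 8-snapshot history (352-dim),
    y is the snapshot that immediately follows it (the MLP's target)."""
    X, y = [], []
    for t in range(K_HISTORY, len(snapshots)):
        X.append(snapshots[t - K_HISTORY:t].reshape(-1))  # last 8 snapshots -> 352-dim
        y.append(snapshots[t])                            # next snapshot to predict
    return np.stack(X), np.stack(y)

# Synthetic snapshots just to show the shapes
snaps = np.arange(12 * SNAP_DIM, dtype=np.float32).reshape(12, SNAP_DIM)
X, y = build_pairs(snaps)  # X: (4, 352), y: (4, 44)
```

Twelve snapshots yield four training pairs; each pair teaches the MLP to extrapolate the distribution one window ahead.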


Usage

Installation

pip install pytorch-forecasting lightning huggingface_hub pandas numpy
# CPU-only torch (smaller, sufficient for inference):
pip install torch --index-url https://download.pytorch.org/whl/cpu

Load model and run inference

import torch
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from huggingface_hub import hf_hub_download
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

REPO_ID = "will702/stockpro-tft"
ENCODER_LENGTH = 60
FORECAST_HORIZON = 30

# 1. Download artifacts
ckpt_path    = hf_hub_download(REPO_ID, "tft_stock.ckpt")
params_path  = hf_hub_download(REPO_ID, "dataset_params.pt")

# 2. Load model (CPU)
model = TemporalFusionTransformer.load_from_checkpoint(
    ckpt_path,
    map_location=lambda storage, loc: storage.cpu(),
)
model.eval()

ds_params = torch.load(params_path, map_location="cpu", weights_only=False)

# 3. Prepare input β€” replace with your real OHLCV data
# closes, volumes: numpy arrays of length >= 60
# timestamps: unix seconds (int64)
closes     = np.random.uniform(3000, 5000, size=120).astype(np.float32)
volumes    = np.random.uniform(1e7, 5e7, size=120).astype(np.float64)
timestamps = np.array([
    int((datetime(2025, 1, 1) + timedelta(days=i)).timestamp())
    for i in range(120)
], dtype=np.int64)

# 4. Feature engineering (rolling z-score normalization)
def build_features(closes, volumes, timestamps):
    s = pd.Series(closes)
    rm = s.rolling(30, min_periods=1).mean()
    rs = s.rolling(30, min_periods=1).std(ddof=0).fillna(1).clip(lower=1e-6)
    close_norm = ((s - rm) / rs).values

    sv = pd.Series(volumes)
    vm = sv.rolling(30, min_periods=1).mean()
    vs = sv.rolling(30, min_periods=1).std(ddof=0).fillna(1).clip(lower=1e-6)
    volume_norm = ((sv - vm) / vs).values

    delta = np.diff(closes, prepend=closes[0])
    gain = pd.Series(np.where(delta > 0, delta, 0.)).ewm(alpha=1/14, adjust=False).mean().values
    loss = pd.Series(np.where(delta < 0, -delta, 0.)).ewm(alpha=1/14, adjust=False).mean().values
    rs_r = np.where(loss == 0, 100., gain / (loss + 1e-9))
    rsi = np.clip(rs_r / (1 + rs_r), 0, 1)

    sp = pd.Series(closes)
    # MACD histogram = MACD line (EMA12 - EMA26) minus its 9-period signal line
    macd_line = sp.ewm(span=12, adjust=False).mean() - sp.ewm(span=26, adjust=False).mean()
    macd_raw = (macd_line - macd_line.ewm(span=9, adjust=False).mean()).values
    macd_norm = macd_raw / (np.std(macd_raw) or 1)

    sma = sp.rolling(20, min_periods=1).mean()
    std20 = sp.rolling(20, min_periods=1).std(ddof=0).fillna(0)
    bb_width = (2 * std20 / sma.clip(lower=1e-6)).fillna(0).values

    prev_c = np.roll(closes, 1); prev_c[0] = closes[0]
    # With close-only data the true range reduces to |close - prev_close|
    tr = np.abs(closes - prev_c)
    atr_norm = pd.Series(tr).ewm(alpha=1/14, adjust=False).mean().values / (closes + 1e-9)

    obv = np.cumsum(np.sign(np.diff(closes, prepend=closes[0])) * volumes)
    obv_norm = (obv - np.mean(obv)) / (np.std(obv) or 1)

    dt = pd.to_datetime(timestamps, unit="s")
    day_sin   = np.sin(2 * np.pi * dt.dayofweek.values / 5)
    day_cos   = np.cos(2 * np.pi * dt.dayofweek.values / 5)
    month_sin = np.sin(2 * np.pi * dt.month.values / 12)
    month_cos = np.cos(2 * np.pi * dt.month.values / 12)

    return np.stack([close_norm, volume_norm, rsi, macd_norm, bb_width,
                     day_sin, day_cos, atr_norm, obv_norm, month_sin, month_cos], axis=1).astype(np.float32)

features = build_features(closes, volumes, timestamps)[-ENCODER_LENGTH:]
ts_slice = timestamps[-ENCODER_LENGTH:]
dates = [datetime.utcfromtimestamp(int(ts)) for ts in ts_slice]

UNKNOWN_REALS = ["close_norm", "volume_norm", "rsi", "macd_norm", "bb_width", "atr_norm", "obv_norm"]
KNOWN_REALS   = ["day_sin", "day_cos", "month_sin", "month_cos"]
FEAT_IDX      = {c: i for i, c in enumerate(["close_norm","volume_norm","rsi","macd_norm","bb_width","day_sin","day_cos","atr_norm","obv_norm","month_sin","month_cos"])}

# Build encoder rows
rows = []
for i, (feat_row, dt) in enumerate(zip(features, dates)):
    row = {"ticker": "BBRI", "time_idx": i, "date": dt}
    for col in UNKNOWN_REALS + KNOWN_REALS:
        row[col] = float(feat_row[FEAT_IDX[col]])
    rows.append(row)
encoder_df = pd.DataFrame(rows)

# Build future decoder rows
last_date = dates[-1]
future_rows = []
for i in range(1, FORECAST_HORIZON + 1):
    fd = last_date + timedelta(days=i)
    future_rows.append({
        "ticker": "BBRI", "time_idx": ENCODER_LENGTH + i - 1, "date": fd,
        "close_norm": 0., "volume_norm": 0., "rsi": 0.5,
        "macd_norm": 0., "bb_width": 0., "atr_norm": 0., "obv_norm": 0.,
        "day_sin":   float(np.sin(2 * np.pi * fd.weekday() / 5)),
        "day_cos":   float(np.cos(2 * np.pi * fd.weekday() / 5)),
        "month_sin": float(np.sin(2 * np.pi * fd.month / 12)),
        "month_cos": float(np.cos(2 * np.pi * fd.month / 12)),
    })
full_df = pd.concat([encoder_df, pd.DataFrame(future_rows)], ignore_index=True)

# 5. Reconstruct dataset and predict
days = 7  # how many of the 30 predicted days to report (1–30)
pred_ds = TimeSeriesDataSet.from_parameters(
    ds_params, full_df, predict=True, stop_randomization=True,
    min_encoder_length=ENCODER_LENGTH, max_encoder_length=ENCODER_LENGTH,
    min_prediction_length=FORECAST_HORIZON, max_prediction_length=FORECAST_HORIZON,
)
pred_dl = pred_ds.to_dataloader(train=False, batch_size=1, num_workers=0)

with torch.no_grad():
    raw = model.predict(pred_dl, mode="quantiles", return_x=False)

preds = raw.squeeze(0).cpu().numpy()[:days]  # (days, 3)

# Denormalize from rolling z-score
roll_mean = float(np.mean(closes[-30:]))
roll_std  = float(np.std(closes[-30:])) or 1.0
q10 = [round(float(z * roll_std + roll_mean), 2) for z in preds[:, 0]]
q50 = [round(float(z * roll_std + roll_mean), 2) for z in preds[:, 1]]
q90 = [round(float(z * roll_std + roll_mean), 2) for z in preds[:, 2]]

print(f"{days}-day forecast for BBRI:")
for i, (lo, mid, hi) in enumerate(zip(q10, q50, q90), 1):
    print(f"  D+{i}: {mid:,.0f}  [{lo:,.0f} – {hi:,.0f}]")

Drift Detection with DDG-DA

from huggingface_hub import hf_hub_download
import torch, torch.nn as nn, numpy as np

ddg_path = hf_hub_download("will702/stockpro-tft", "ddg_da.pt")

class DriftMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(352, 128), nn.ELU(), nn.Dropout(0.1), nn.Linear(128, 44))
    def forward(self, x):
        return self.net(x)

mlp = DriftMLP()
mlp.load_state_dict(torch.load(ddg_path, map_location="cpu", weights_only=True))
mlp.eval()

# Compute feature distribution snapshots
# features: (T, 11) matrix from build_features() above
def compute_snapshot(arr):  # arr: (window, 11)
    stats = []
    for f in range(arr.shape[1]):
        col = arr[:, f].astype(np.float64)
        mu, sigma = col.mean(), col.std() or 1e-8
        skew = float(((col - mu)**3).mean() / sigma**3)
        kurt = float(((col - mu)**4).mean() / sigma**4) - 3.0
        stats.extend([float(mu), float(sigma), skew, kurt])
    return np.array(stats, dtype=np.float32)

SNAPSHOT_WINDOW, K_HISTORY = 20, 8
T = len(features)
snapshots = [compute_snapshot(features[i * SNAPSHOT_WINDOW:(i + 1) * SNAPSHOT_WINDOW])
             for i in range(T // SNAPSHOT_WINDOW)]

if len(snapshots) >= K_HISTORY + 1:
    history = np.stack(snapshots[-(K_HISTORY + 1):-1])   # (8, 44) — the MLP's input
    current = snapshots[-1]
    ref     = np.stack(snapshots[:-1])                   # (K, 44) reference population

    # Predict the next snapshot from the last 8 (flattened to 352-dim);
    # comparing `predicted` against `current` gives an alternative drift signal
    with torch.no_grad():
        predicted = mlp(torch.from_numpy(history.reshape(1, -1))).numpy()[0]

    # Drift score: mean |z-score| of the current snapshot vs. the reference history
    ref_mean = ref.mean(0); ref_std = ref.std(0) + 1e-8
    drift_score = float(np.abs((current - ref_mean) / ref_std).mean())
    drift_detected = drift_score > 1.8

    print(f"Drift score: {drift_score:.3f}  |  Detected: {drift_detected}")

Evaluation Results

Preliminary backtest on 2 tickers (BBCA, BBRI) over a 3-month test window (2025-12 – 2026-03), 26 samples per horizon:

| Horizon | MAPE | Directional Acc. | Quantile Coverage (10–90%) |
|---|---|---|---|
| 1 day | 1.92% | 38.5% | 61.5% |
| 3 days | 1.96% | 50.0% | 65.4% |
| 7 days | 3.14% | 57.7% | 69.2% |
| 14 days | 4.13% | 50.0% | 65.4% |
| 30 days | 6.83% | 57.7% | 61.5% |

⚠️ Caveat: Results are preliminary – only 2 tickers were included in this backtest run. A broader evaluation across the full 956-stock universe is needed before drawing conclusions about generalization.
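
"Quantile Coverage (10–90%)" here is the fraction of realized prices that land inside the predicted [q10, q90] band (ideally ≈80% for well-calibrated quantiles). A minimal sketch of the metric, on made-up numbers:

```python
import numpy as np

def interval_coverage(actuals, q10, q90):
    """Fraction of actuals inside the predicted 10th–90th percentile band."""
    a, lo, hi = (np.asarray(x, dtype=float) for x in (actuals, q10, q90))
    return float(np.mean((a >= lo) & (a <= hi)))

# Two of four actuals fall inside their predicted bands
cov = interval_coverage([100, 105, 99, 120],
                        [95, 100, 98, 100],
                        [110, 104, 103, 115])
```

The 61–69% figures in the table therefore indicate the bands are somewhat too narrow relative to the nominal 80%.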


Limitations

  • IDX-only: Trained exclusively on Indonesian equities. Not validated on other markets.
  • No guarantee: Stock prices are influenced by news, macroeconomic events, and factors not captured by OHLCV + technical indicators alone.
  • Backtest coverage: The evaluation above covers only 2 tickers. Performance varies significantly by stock.
  • Normalization dependency: Outputs are rolling z-scores – callers must denormalize using the same 30-day rolling mean/std window applied to input prices.
  • Cold-start: Unknown tickers (not seen during training) fall back to the UNK embedding. Quality may degrade for small-cap, illiquid stocks.
  • Not financial advice: This model is intended for research and educational purposes only.
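
The normalization dependency above can be verified with a round trip (same z*std + mean inversion as in the Usage section, on toy prices):

```python
import numpy as np

def denormalize(z, window_closes):
    """Invert the rolling z-score using the trailing window of input closes."""
    mean = float(np.mean(window_closes))
    std = float(np.std(window_closes)) or 1.0  # guard against a flat window
    return z * std + mean

closes = np.array([4000.0, 4100.0, 4050.0, 4200.0])  # stands in for the last 30 closes
z = (4200.0 - closes.mean()) / closes.std()          # normalize the latest close
price = denormalize(z, closes)                       # recovers 4200.0
```

Using a different window (or sample instead of population std) at denormalization time silently shifts and rescales every forecast, so both ends of the pipeline must agree.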

License

Apache 2.0 – see LICENSE.
