stockpro-ml / app /models /tft_predictor.py
will702's picture
fix: patch torchmetrics._apply using CPU probe tensor instead of self.device
ec4688a
"""
pytorch-forecasting TFT inference for IDX stock price prediction.
Loads from Lightning checkpoint (.ckpt) produced by train_colab.py.
Uses pytorch-forecasting's TimeSeriesDataSet + TemporalFusionTransformer.
"""
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

import numpy as np
import pandas as pd
import torch

from app.services.feature_engineer import SEQUENCE_LEN, FEATURE_COLS, build_features
# ── Forecast / feature configuration ──────────────────────────────────────────
FORECAST_HORIZON = 30  # decoder steps requested from the model per inference call
ENCODER_LENGTH = SEQUENCE_LEN  # encoder window; the original note gives 60
QUANTILES = [0.1, 0.5, 0.9]  # quantile levels the prediction columns are assumed to follow
N_QUANTILES = len(QUANTILES)
TARGET = "close_norm"
# Calendar reals are known for future steps; the remaining reals are only
# observed up to the last candle.
KNOWN_REALS = ["day_sin", "day_cos", "month_sin", "month_cos"]
UNKNOWN_REALS = [
    "close_norm",
    "volume_norm",
    "rsi",
    "macd_norm",
    "bb_width",
    "atr_norm",
    "obv_norm",
]
# Feature name -> column position in the build_features() output.
_FEAT_IDX = dict((name, pos) for pos, name in enumerate(FEATURE_COLS))

# ── Model / params caching ────────────────────────────────────────────────────
# Populated lazily by load_model() and load_dataset_params(); each cache also
# remembers the path it was loaded from, so a different path forces a reload.
_model = None
_model_path_cached: Optional[str] = None
_ds_params: Optional[dict] = None
_ds_params_path_cached: Optional[str] = None
def _maybe_download(filename: str, local_path: str) -> bool:
"""Download a file from HF Hub if not present locally."""
if os.path.exists(local_path):
return True
import app.config as cfg
if not cfg.MODEL_REPO or not cfg.HF_TOKEN:
return False
try:
from huggingface_hub import hf_hub_download
local = hf_hub_download(
repo_id=cfg.MODEL_REPO,
filename=filename,
token=cfg.HF_TOKEN,
local_dir=os.path.dirname(local_path),
)
if local != local_path:
import shutil
shutil.copy2(local, local_path)
return os.path.exists(local_path)
except Exception as e:
print(f"[tft] Could not download {filename} from HF Hub: {e}")
return False
def _patch_torchmetrics_cpu():
"""Patch torchmetrics.Metric._apply to avoid CUDA errors on CPU-only servers.
When a GPU-trained checkpoint is loaded on CPU-only hardware, the torchmetrics
Metric._apply method does `fn(torch.zeros(1, device=self.device))` where
self.device may still be "cuda:0" from the checkpoint. We replace that with
a safe CPU probe so the destination device is inferred without touching CUDA.
"""
try:
import torchmetrics
import torch.nn as nn
_orig = torchmetrics.Metric._apply
def _safe_apply(self, fn):
# Probe destination device via a CPU tensor β€” never touches CUDA.
self._device = fn(torch.zeros(1, device="cpu")).device
return nn.Module._apply(self, fn)
torchmetrics.Metric._apply = _safe_apply
print("[tft] torchmetrics._apply patched for CPU-only inference")
except Exception as e:
print(f"[tft] torchmetrics patch skipped: {e}")
def load_model(model_path: str):
    """Return the pytorch-forecasting TFT stored at ``model_path``, loading it once.

    The loaded model is kept in a module-level cache. A later call with the
    same path returns the cached instance; a different path triggers a fresh
    load. If the checkpoint is missing locally, a download of
    ``tft_stock.ckpt`` from the HF Hub is attempted first. The model is put in
    eval mode before it is cached.

    Raises:
        FileNotFoundError: The checkpoint is still missing after the download attempt.
    """
    global _model, _model_path_cached

    if _model is not None and model_path == _model_path_cached:
        return _model

    _maybe_download("tft_stock.ckpt", model_path)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

    # The torchmetrics patch must be active before Lightning restores metric
    # state from the checkpoint, hence before pytorch_forecasting is imported.
    _patch_torchmetrics_cpu()
    from pytorch_forecasting import TemporalFusionTransformer

    def _storage_to_cpu(storage, _location):
        # Using a callable map_location has two effects. It moves every
        # storage to CPU, and it skips Lightning's str/torch.device branch,
        # which would call model.to(map_location) and hit the CUDA error again.
        return storage.cpu()

    loaded = TemporalFusionTransformer.load_from_checkpoint(
        model_path,
        map_location=_storage_to_cpu,
    )
    loaded.eval()

    _model = loaded
    _model_path_cached = model_path
    print(f"[tft] Loaded pytorch-forecasting TFT from {model_path}")
    return loaded
def load_dataset_params(params_path: str) -> dict:
    """Return the TimeSeriesDataSet parameters saved at training time.

    The result is cached at module level and keyed by ``params_path``. If the
    file is missing locally, a download of ``dataset_params.pt`` from the HF
    Hub is attempted first.

    Raises:
        FileNotFoundError: The params file is still missing after the download attempt.
    """
    global _ds_params, _ds_params_path_cached

    cached = _ds_params
    if cached is not None and params_path == _ds_params_path_cached:
        return cached

    _maybe_download("dataset_params.pt", params_path)
    if not os.path.exists(params_path):
        raise FileNotFoundError(f"Dataset params not found: {params_path}")

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this at files from a trusted repository/path.
    loaded = torch.load(params_path, map_location="cpu", weights_only=False)

    _ds_params = loaded
    _ds_params_path_cached = params_path
    print(f"[tft] Loaded dataset params from {params_path}")
    return loaded
# ── Inference DataFrame builder ───────────────────────────────────────────────
def _next_trading_day(day: datetime) -> datetime:
    """Return the first Monday-Friday date strictly after ``day`` (holidays are not modelled)."""
    nxt = day + timedelta(days=1)
    while nxt.weekday() >= 5:  # 5 = Saturday, 6 = Sunday
        nxt += timedelta(days=1)
    return nxt


def _calendar_features(day: datetime) -> dict:
    """Cyclical weekday (5-day period) and month (12-month period) encodings for ``day``."""
    return {
        "day_sin": float(np.sin(2 * np.pi * day.weekday() / 5)),
        "day_cos": float(np.cos(2 * np.pi * day.weekday() / 5)),
        "month_sin": float(np.sin(2 * np.pi * day.month / 12)),
        "month_cos": float(np.cos(2 * np.pi * day.month / 12)),
    }


def _build_inference_df(
    closes: np.ndarray,
    volumes: np.ndarray,
    timestamps: np.ndarray,
    symbol: str,
) -> pd.DataFrame:
    """Build the encoder + decoder frame consumed by ``TimeSeriesDataSet.from_parameters``.

    Output layout:

    * ENCODER_LENGTH encoder rows, holding the real feature values of the most
      recent candles.
    * FORECAST_HORIZON future rows. These carry the known calendar reals plus
      constant placeholders for the unknown reals; the decoder does not use
      future values of unknown reals.

    ``time_idx`` runs continuously from 0 across both parts. Future dates step
    over Monday-Friday only. This keeps them consistent with the 5-day weekday
    encoding and with one decoder step per trading day. Exchange holidays are
    not skipped.

    Raises:
        ValueError: build_features() yields fewer than ENCODER_LENGTH rows.
    """
    features = build_features(closes, volumes, timestamps)
    if len(features) < ENCODER_LENGTH:
        raise ValueError(f"Need at least {ENCODER_LENGTH} candles, got {len(features)}")
    features = features[-ENCODER_LENGTH:]
    ts_slice = timestamps[-len(features):]

    # Naive UTC datetimes. datetime.utcfromtimestamp() is deprecated since
    # Python 3.12; this expression is its documented replacement.
    dates = [
        datetime.fromtimestamp(int(ts), tz=timezone.utc).replace(tzinfo=None)
        for ts in ts_slice
    ]

    encoder_cols = UNKNOWN_REALS + KNOWN_REALS
    encoder_rows = []
    for i, (feat_row, dt) in enumerate(zip(features, dates)):
        row: dict = {"ticker": symbol, "time_idx": i, "date": dt}
        for col in encoder_cols:
            row[col] = float(feat_row[_FEAT_IDX[col]])
        encoder_rows.append(row)
    encoder_df = pd.DataFrame(encoder_rows)

    future_rows = []
    future_date = dates[-1]
    for step in range(FORECAST_HORIZON):
        future_date = _next_trading_day(future_date)
        future_rows.append({
            "ticker": symbol,
            "time_idx": ENCODER_LENGTH + step,
            "date": future_date,
            # Unknown reals: placeholder values (not used in decoder future steps)
            TARGET: 0.0,
            "volume_norm": 0.0,
            "rsi": 0.5,
            "macd_norm": 0.0,
            "bb_width": 0.0,
            "atr_norm": 0.0,
            "obv_norm": 0.0,
            # Known reals: calendar features of the future trading day
            **_calendar_features(future_date),
        })

    return pd.concat([encoder_df, pd.DataFrame(future_rows)], ignore_index=True)
# ── Inference ─────────────────────────────────────────────────────────────────
def _to_quantile_array(raw) -> np.ndarray:
    """Normalise ``model.predict(..., mode="quantiles")`` output to a float ``(H, Q)`` array.

    Accepts a tensor or array, or a list/tuple whose first element holds the
    prediction. Leading singleton batch dimensions are dropped, so ``(1, H, Q)``
    and ``(H, Q)`` both give ``(H, Q)``.

    Raises:
        ValueError: The result is not 2-D, has no steps, or has fewer than
            N_QUANTILES columns.
    """
    if isinstance(raw, (list, tuple)):
        raw = raw[0]
    if isinstance(raw, torch.Tensor):
        raw = raw.detach().cpu().numpy()
    arr = np.asarray(raw, dtype=float)
    while arr.ndim > 2 and arr.shape[0] == 1:
        arr = arr[0]
    if arr.ndim != 2 or arr.shape[0] == 0 or arr.shape[1] < N_QUANTILES:
        raise ValueError(f"Unexpected quantile prediction shape {arr.shape}")
    return arr


def _denormalize(zs, mean: float, std: float) -> list:
    """Undo the rolling z-score, then clamp at 0 and round each value to 2 decimals."""
    return [max(0.0, round(float(z) * std + mean, 2)) for z in zs]


def predict_quantiles(
    closes: np.ndarray,
    volumes: np.ndarray,
    timestamps: np.ndarray,
    days: int,
    model_path: str,
    dataset_params_path: Optional[str] = None,
    symbol: str = "UNKNOWN",
) -> dict:
    """Run TFT inference and return the quantile forecast as price levels.

    ``days`` is clamped to [1, FORECAST_HORIZON]. Predictions are mapped back
    to prices with the mean/std of the last 30 closes (std falls back to 1.0
    when it is 0). The prediction columns are assumed to be the 0.1 / 0.5 / 0.9
    quantiles, in QUANTILES order; TODO confirm against the trained loss.

    When ``dataset_params_path`` is omitted, "tft_stock.ckpt" in ``model_path``
    is replaced by "dataset_params.pt". If the name does not contain
    "tft_stock.ckpt", the file "dataset_params.pt" next to the checkpoint is
    used instead.
    """
    if dataset_params_path is None:
        dataset_params_path = model_path.replace("tft_stock.ckpt", "dataset_params.pt")
        if dataset_params_path == model_path:
            # No substitution happened: never hand the checkpoint itself to
            # load_dataset_params().
            dataset_params_path = os.path.join(os.path.dirname(model_path), "dataset_params.pt")

    model = load_model(model_path)
    ds_params = load_dataset_params(dataset_params_path)

    days = max(1, min(int(days), FORECAST_HORIZON))

    current_price = float(closes[-1])
    roll_mean = float(np.mean(closes[-30:]))
    roll_std = float(np.std(closes[-30:])) or 1.0

    full_df = _build_inference_df(closes, volumes, timestamps, symbol)

    # Reconstruct the TimeSeriesDataSet from the training-time parameters so the
    # fitted encoders/normalisers are reused for this symbol.
    from pytorch_forecasting import TimeSeriesDataSet

    pred_ds = TimeSeriesDataSet.from_parameters(
        ds_params,
        full_df,
        predict=True,  # one sample per group, taken from the end of the data
        stop_randomization=True,
        min_encoder_length=ENCODER_LENGTH,
        max_encoder_length=ENCODER_LENGTH,
        min_prediction_length=FORECAST_HORIZON,
        max_prediction_length=FORECAST_HORIZON,
        min_prediction_idx=None,
    )
    pred_dl = pred_ds.to_dataloader(train=False, batch_size=1, num_workers=0)

    with torch.no_grad():
        raw = model.predict(pred_dl, mode="quantiles", return_x=False)

    preds = _to_quantile_array(raw)[:days]

    q10 = _denormalize(preds[:, 0], roll_mean, roll_std)
    q50 = _denormalize(preds[:, 1], roll_mean, roll_std)
    q90 = _denormalize(preds[:, 2], roll_mean, roll_std)

    # Quantile heads can cross; clamp so that q10 <= q50 <= q90 at every step.
    q10 = [min(lo, mid) for lo, mid in zip(q10, q50)]
    q90 = [max(hi, mid) for hi, mid in zip(q90, q50)]

    final_price = q50[-1]
    if final_price > current_price * 1.005:
        trend = "bullish"
    elif final_price < current_price * 0.995:
        trend = "bearish"
    else:
        trend = "sideways"
    change_pct = (final_price - current_price) / current_price * 100 if current_price else 0.0

    return {
        "method": "tft",
        "predictions": q50,
        "lower_bound": q10,
        "upper_bound": q90,
        "target_price": final_price,
        "trend": trend,
        "change_pct": round(change_pct, 2),
        "confidence": 72,
        "support": round(min(q10), 2),
        "resistance": round(max(q90), 2),
        "feature_importance": {},  # TFT attention weights available via interpret_output() if needed
    }