Spaces:
Sleeping
Sleeping
| """ | |
| pytorch-forecasting TFT inference for IDX stock price prediction. | |
| Loads from Lightning checkpoint (.ckpt) produced by train_colab.py. | |
| Uses pytorch-forecasting's TimeSeriesDataSet + TemporalFusionTransformer. | |
| """ | |
| import os | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from datetime import datetime, timedelta | |
| from typing import Optional | |
| from app.services.feature_engineer import SEQUENCE_LEN, FEATURE_COLS, build_features | |
# Window sizes: number of future steps requested from the model and number of
# historical steps fed to the encoder.
FORECAST_HORIZON = 30
ENCODER_LENGTH = SEQUENCE_LEN  # 60

# Quantile levels, in the order they are read from the model output.
QUANTILES = [0.1, 0.5, 0.9]
N_QUANTILES = len(QUANTILES)

# Column roles used when assembling the inference DataFrame.
TARGET = "close_norm"
KNOWN_REALS = [
    "day_sin",
    "day_cos",
    "month_sin",
    "month_cos",
]
UNKNOWN_REALS = [
    "close_norm",
    "volume_norm",
    "rsi",
    "macd_norm",
    "bb_width",
    "atr_norm",
    "obv_norm",
]

# Column index lookup for build_features() output
_FEAT_IDX = {name: position for position, name in enumerate(FEATURE_COLS)}
# ── Model / params caching ────────────────────────────────────────────────────
# Module-level caches, filled lazily by load_model() and load_dataset_params().
# Each "*_path_cached" value records the path the cached object was loaded from
# so that a call with a different path triggers a reload.
_model = None
_model_path_cached: Optional[str] = None

_ds_params: Optional[dict] = None
_ds_params_path_cached: Optional[str] = None
def _maybe_download(filename: str, local_path: str) -> bool:
    """Ensure ``local_path`` exists, fetching ``filename`` from HF Hub if needed.

    Args:
        filename: Name of the file inside the Hub repository ``cfg.MODEL_REPO``.
        local_path: Location on disk where the file is expected.

    Returns:
        True if ``local_path`` exists when the function returns. False when the
        Hub is not configured (``MODEL_REPO`` or ``HF_TOKEN`` is empty) or the
        download failed; failures are printed, never raised.
    """
    if os.path.exists(local_path):
        return True

    import app.config as cfg

    if not cfg.MODEL_REPO or not cfg.HF_TOKEN:
        return False

    try:
        import shutil

        from huggingface_hub import hf_hub_download

        # dirname() is "" for a bare filename; use the current directory then,
        # and make sure the directory exists before anything is written to it.
        target_dir = os.path.dirname(local_path) or "."
        os.makedirs(target_dir, exist_ok=True)

        downloaded = hf_hub_download(
            repo_id=cfg.MODEL_REPO,
            filename=filename,
            token=cfg.HF_TOKEN,
            local_dir=target_dir,
        )

        # Compare normalized paths: a relative local_path and an absolute
        # return value can name the same file, and copying a file onto itself
        # raises shutil.SameFileError (which would be reported as a failure).
        if os.path.abspath(downloaded) != os.path.abspath(local_path):
            shutil.copy2(downloaded, local_path)
        return os.path.exists(local_path)
    except Exception as e:  # deliberate best-effort: callers re-check the path
        print(f"[tft] Could not download {filename} from HF Hub: {e}")
        return False
def _patch_torchmetrics_cpu():
    """Swap ``torchmetrics.Metric._apply`` for a variant that never touches CUDA.

    When a GPU-trained checkpoint is loaded on CPU-only hardware, the stock
    torchmetrics ``Metric._apply`` does ``fn(torch.zeros(1, device=self.device))``
    where ``self.device`` may still be "cuda:0" from the checkpoint. The
    replacement probes the destination device with a CPU tensor instead and
    then defers to ``nn.Module._apply``.

    Any failure (torchmetrics missing, attribute absent, ...) is printed and
    swallowed; the function always returns None.
    """
    try:
        import torchmetrics
        import torch.nn as nn

        metric_cls = torchmetrics.Metric
        # Not used afterwards; the attribute lookup fails (and the patch is
        # skipped via the except branch) if Metric has no _apply.
        _original_apply = metric_cls._apply

        def _cpu_probe_apply(self, fn):
            # Probe destination device via a CPU tensor — never touches CUDA.
            cpu_probe = torch.zeros(1, device="cpu")
            self._device = fn(cpu_probe).device
            return nn.Module._apply(self, fn)

        metric_cls._apply = _cpu_probe_apply
        print("[tft] torchmetrics._apply patched for CPU-only inference")
    except Exception as exc:
        print(f"[tft] torchmetrics patch skipped: {exc}")
def load_model(model_path: str):
    """Return the TFT restored from the Lightning checkpoint at ``model_path``.

    The model is cached at module level; a repeated call with the same path
    returns the cached instance without touching disk. If the checkpoint is
    missing locally, a download of "tft_stock.ckpt" from HF Hub is attempted.

    Raises:
        FileNotFoundError: if the checkpoint is still absent after the
            download attempt.
    """
    global _model, _model_path_cached

    cache_hit = _model is not None and _model_path_cached == model_path
    if cache_hit:
        return _model

    _maybe_download("tft_stock.ckpt", model_path)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

    # Patch torchmetrics BEFORE importing pytorch_forecasting so the patched
    # _apply is in place when Lightning restores metric state from the checkpoint.
    _patch_torchmetrics_cpu()
    from pytorch_forecasting import TemporalFusionTransformer

    def _storage_to_cpu(storage, loc):
        # Callable map_location: moves all tensors to CPU AND skips Lightning's
        # isinstance(map_location, (str, torch.device)) branch that would call
        # model.to(map_location) — which would re-trigger the CUDA error.
        return storage.cpu()

    tft = TemporalFusionTransformer.load_from_checkpoint(
        model_path,
        map_location=_storage_to_cpu,
    )
    tft.eval()

    _model, _model_path_cached = tft, model_path
    print(f"[tft] Loaded pytorch-forecasting TFT from {model_path}")
    return tft
def load_dataset_params(params_path: str) -> dict:
    """Return the TimeSeriesDataSet parameters saved during Colab training.

    The loaded object is cached at module level and reused for repeated calls
    with the same ``params_path``. If the file is missing locally, a download
    of "dataset_params.pt" from HF Hub is attempted first.

    Raises:
        FileNotFoundError: if the file is still absent after the download
            attempt.
    """
    global _ds_params, _ds_params_path_cached

    needs_load = _ds_params is None or _ds_params_path_cached != params_path
    if needs_load:
        _maybe_download("dataset_params.pt", params_path)
        if not os.path.exists(params_path):
            raise FileNotFoundError(f"Dataset params not found: {params_path}")

        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects, so this file must only ever come from a trusted source.
        loaded = torch.load(params_path, map_location="cpu", weights_only=False)

        _ds_params, _ds_params_path_cached = loaded, params_path
        print(f"[tft] Loaded dataset params from {params_path}")

    return _ds_params
# ── Inference DataFrame builder ───────────────────────────────────────────────
def _build_inference_df(
    closes: np.ndarray,
    volumes: np.ndarray,
    timestamps: np.ndarray,
    symbol: str,
) -> pd.DataFrame:
    """Build the encoder + decoder frame consumed by the TimeSeriesDataSet.

    The result has ENCODER_LENGTH encoder rows followed by FORECAST_HORIZON
    future rows. Encoder rows carry the real feature values returned by
    build_features(); future rows carry calendar-derived known reals and
    fixed placeholder values for the unknown reals.

    Args:
        closes: Close prices, passed through to build_features().
        volumes: Volumes, passed through to build_features().
        timestamps: POSIX timestamps in seconds (each is converted with
            ``int(ts)``); the trailing entries are paired with the trailing
            feature rows.
        symbol: Value written to the "ticker" column of every row.

    Returns:
        A DataFrame with columns "ticker", "time_idx", "date", the
        UNKNOWN_REALS and the KNOWN_REALS; ``time_idx`` runs from 0 to
        ENCODER_LENGTH + FORECAST_HORIZON - 1.

    Raises:
        ValueError: if build_features() yields fewer than ENCODER_LENGTH rows.
    """
    # Local import: the module header only imports datetime and timedelta.
    from datetime import timezone

    features = build_features(closes, volumes, timestamps)
    if len(features) < ENCODER_LENGTH:
        raise ValueError(f"Need at least {ENCODER_LENGTH} candles, got {len(features)}")

    features = features[-ENCODER_LENGTH:]
    ts_slice = timestamps[-len(features):]

    # Timestamps → naive UTC datetimes. datetime.utcfromtimestamp() is
    # deprecated since Python 3.12; this produces the same naive values.
    dates = [
        datetime.fromtimestamp(int(ts), tz=timezone.utc).replace(tzinfo=None)
        for ts in ts_slice
    ]

    # Encoder rows: real feature values, looked up by column name.
    feature_cols = UNKNOWN_REALS + KNOWN_REALS
    encoder_rows = [
        {
            "ticker": symbol,
            "time_idx": i,
            "date": dt,
            **{col: float(feat_row[_FEAT_IDX[col]]) for col in feature_cols},
        }
        for i, (feat_row, dt) in enumerate(zip(features, dates))
    ]
    encoder_df = pd.DataFrame(encoder_rows)

    # Future decoder rows (known reals computed from the calendar).
    # NOTE(review): future dates advance by calendar day, so weekends are
    # included, while the weekday encoding below uses a period of 5. Confirm
    # this matches how the training data and build_features() encode days.
    last_date = dates[-1]
    future_rows = []
    for step in range(1, FORECAST_HORIZON + 1):
        future_date = last_date + timedelta(days=step)
        weekday = future_date.weekday()
        month = future_date.month
        future_rows.append({
            "ticker": symbol,
            "time_idx": ENCODER_LENGTH + step - 1,
            "date": future_date,
            # Unknown reals: placeholder values (not used in decoder future steps)
            TARGET: 0.0,
            "volume_norm": 0.0,
            "rsi": 0.5,
            "macd_norm": 0.0,
            "bb_width": 0.0,
            "atr_norm": 0.0,
            "obv_norm": 0.0,
            # Known reals: actual calendar features
            "day_sin": float(np.sin(2 * np.pi * weekday / 5)),
            "day_cos": float(np.cos(2 * np.pi * weekday / 5)),
            "month_sin": float(np.sin(2 * np.pi * month / 12)),
            "month_cos": float(np.cos(2 * np.pi * month / 12)),
        })

    return pd.concat([encoder_df, pd.DataFrame(future_rows)], ignore_index=True)
# ── Inference ─────────────────────────────────────────────────────────────────
def predict_quantiles(
    closes: np.ndarray,
    volumes: np.ndarray,
    timestamps: np.ndarray,
    days: int,
    model_path: str,
    dataset_params_path: Optional[str] = None,
    symbol: str = "UNKNOWN",
) -> dict:
    """Run pytorch-forecasting TFT inference and return price-level quantiles.

    Args:
        closes: Historical close prices; the last value is the current price
            and the last 30 values define the denormalization mean/std.
        volumes: Historical volumes, forwarded to the feature builder.
        timestamps: POSIX timestamps matching the candles.
        days: Requested horizon; clamped to the range 1..FORECAST_HORIZON.
        model_path: Path of the Lightning checkpoint.
        dataset_params_path: Path of the saved dataset parameters. Defaults to
            a "dataset_params.pt" derived from ``model_path``.
        symbol: Ticker written into the inference frame.

    Returns:
        A dict with keys "method", "predictions" (q50), "lower_bound" (q10),
        "upper_bound" (q90), "target_price", "trend", "change_pct",
        "confidence", "support", "resistance" and "feature_importance".
        Prices are rounded to 2 decimals and floored at 0.0.

    Raises:
        FileNotFoundError: if the checkpoint or dataset params are unavailable.
        ValueError: if there are too few candles for the encoder window.
    """
    if dataset_params_path is None:
        dataset_params_path = model_path.replace("tft_stock.ckpt", "dataset_params.pt")
        if dataset_params_path == model_path:
            # The checkpoint has a different filename, so replace() changed
            # nothing; look for the params file next to the checkpoint instead
            # of loading the checkpoint itself as dataset params.
            dataset_params_path = os.path.join(
                os.path.dirname(model_path), "dataset_params.pt"
            )

    model = load_model(model_path)
    ds_params = load_dataset_params(dataset_params_path)

    days = max(1, min(days, FORECAST_HORIZON))
    current_price = float(closes[-1])
    recent_closes = closes[-30:]
    roll_mean = float(np.mean(recent_closes))
    roll_std = float(np.std(recent_closes)) or 1.0  # avoid a zero scale

    # Build inference DataFrame
    full_df = _build_inference_df(closes, volumes, timestamps, symbol)

    # Reconstruct TimeSeriesDataSet from training-time parameters.
    # Per the original author's note, from_parameters() reuses the fitted
    # categorical encoder (ticker → int), so unknown tickers fall back to the
    # UNK embedding.
    from pytorch_forecasting import TimeSeriesDataSet

    pred_ds = TimeSeriesDataSet.from_parameters(
        ds_params,
        full_df,
        predict=True,  # one sample per group, from the end of data
        stop_randomization=True,
        min_encoder_length=ENCODER_LENGTH,
        max_encoder_length=ENCODER_LENGTH,
        min_prediction_length=FORECAST_HORIZON,
        max_prediction_length=FORECAST_HORIZON,
        min_prediction_idx=None,
    )
    pred_dl = pred_ds.to_dataloader(train=False, batch_size=1, num_workers=0)

    # Predict — expected output shape is (1, FORECAST_HORIZON, N_QUANTILES).
    with torch.no_grad():
        raw = model.predict(pred_dl, mode="quantiles", return_x=False)

    # Handle both tensor and list returns
    if isinstance(raw, torch.Tensor):
        preds = raw.squeeze(0).cpu().numpy()  # (FORECAST_HORIZON, 3)
    else:
        preds = np.array(raw[0])  # (FORECAST_HORIZON, 3)
    preds = preds[:days]  # slice to requested horizon

    def _to_prices(z_scores) -> list:
        # Denormalize rolling z-score → price levels, floored at zero.
        return [max(0.0, round(float(z * roll_std + roll_mean), 2)) for z in z_scores]

    q10 = _to_prices(preds[:, 0])
    q50 = _to_prices(preds[:, 1])
    q90 = _to_prices(preds[:, 2])

    # Enforce monotonic bounds (q10 ≤ q50 ≤ q90). Pairing the lists directly
    # follows the actual prediction length rather than assuming `days` rows.
    q10 = [min(low, mid) for low, mid in zip(q10, q50)]
    q90 = [max(high, mid) for high, mid in zip(q90, q50)]

    final_price = q50[-1]
    if final_price > current_price * 1.005:
        trend = "bullish"
    elif final_price < current_price * 0.995:
        trend = "bearish"
    else:
        trend = "sideways"

    # Guard against a zero last close, which would raise ZeroDivisionError.
    if current_price:
        change_pct = (final_price - current_price) / current_price * 100
    else:
        change_pct = 0.0

    return {
        "method": "tft",
        "predictions": q50,
        "lower_bound": q10,
        "upper_bound": q90,
        "target_price": final_price,
        "trend": trend,
        "change_pct": round(change_pct, 2),
        "confidence": 72,  # fixed value, not derived from the model output
        "support": round(min(q10), 2),
        "resistance": round(max(q90), 2),
        "feature_importance": {},  # TFT attention weights available via interpret_output() if needed
    }