""" TFT-ASRO Inference Pipeline. Produces live multi-quantile predictions by: 1. Assembling the latest feature vector from all data sources 2. Running through the trained TFT model 3. Formatting the output as a structured prediction dict Designed to run in parallel with the existing XGBoost inference pipeline for A/B comparison. """ from __future__ import annotations import json import logging from datetime import datetime, timedelta, timezone from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional import os import warnings import numpy as np import pandas as pd warnings.filterwarnings( "ignore", message="X does not have valid feature names", category=UserWarning, module="sklearn", ) from deep_learning.config import TFTASROConfig, get_tft_config logger = logging.getLogger(__name__) class TFTPredictor: """ Stateful predictor that holds the loaded model and PCA transformer. Thread-safe for read-only inference (no internal mutation after init). """ def __init__( self, checkpoint_path: Optional[str] = None, cfg: Optional[TFTASROConfig] = None, ): self.cfg = cfg or get_tft_config() self._checkpoint_path = checkpoint_path or self.cfg.training.best_model_path self._model = None self._pca = None self._hub_checked = False def _ensure_local_artifacts(self) -> None: """Download checkpoint from HF Hub if not present locally.""" if self._hub_checked: return self._hub_checked = True if Path(self._checkpoint_path).exists(): return try: from deep_learning.models.hub import download_tft_artifacts tft_dir = Path(self._checkpoint_path).parent downloaded = download_tft_artifacts( local_dir=tft_dir, repo_id=self.cfg.training.hf_model_repo, ) if downloaded: logger.info("TFT checkpoint downloaded from HF Hub") else: logger.warning("TFT checkpoint not available on HF Hub") except Exception as exc: logger.warning("HF Hub download attempt failed: %s", exc) @property def model(self): if self._model is None: self._ensure_local_artifacts() if not Path(self._checkpoint_path).exists(): raise FileNotFoundError( f"TFT checkpoint not found: {self._checkpoint_path}" ) from deep_learning.models.tft_copper import load_tft_model self._model = load_tft_model(self._checkpoint_path) return self._model @property def pca(self): if self._pca is None: self._ensure_local_artifacts() pca_path = self.cfg.embedding.pca_model_path if Path(pca_path).exists(): from deep_learning.data.embeddings import load_pca self._pca = load_pca(pca_path) return self._pca def predict(self, session, symbol: str = "HG=F") -> Dict[str, Any]: """ Generate a TFT-ASRO prediction for the given symbol. Returns a dict with: - predicted_return_median, q10, q90 - predicted_price_median, q10, q90 - confidence_band_96 - volatility_estimate - quantiles (all 7) - model_info """ from deep_learning.data.feature_store import build_tft_dataframe from deep_learning.data.dataset import build_datasets, create_dataloaders from deep_learning.models.tft_copper import format_prediction from pytorch_forecasting import TimeSeriesDataSet master_df, tv_unknown, tv_known, target_cols = build_tft_dataframe(session, self.cfg) last_known_price = self._get_last_price(session, symbol) encoder_length = self.cfg.model.max_encoder_length prediction_length = self.cfg.model.max_prediction_length recent = master_df.tail(encoder_length + prediction_length).copy() if len(recent) < encoder_length + 1: return {"error": f"Insufficient data: {len(recent)} rows, need {encoder_length + 1}"} recent["time_idx"] = np.arange(len(recent)) target = target_cols[0] if target_cols else "target" try: ds = TimeSeriesDataSet( recent, time_idx="time_idx", target=target, group_ids=["group_id"], max_encoder_length=encoder_length, max_prediction_length=prediction_length, time_varying_unknown_reals=tv_unknown, time_varying_known_reals=tv_known, static_categoricals=["group_id"], add_relative_time_idx=True, add_target_scales=True, add_encoder_length=True, allow_missing_timesteps=True, ) except Exception as exc: logger.error("Failed to create inference dataset: %s", exc) return {"error": str(exc)} _nw = 0 if os.name == "nt" else 2 dl = ds.to_dataloader(train=False, batch_size=1, num_workers=_nw) try: import torch # mode="quantiles" returns a plain Tensor (n_samples, pred_len, n_quantiles) # Avoids the inhomogeneous-shape error from mode="raw" which returns a # NamedTuple; np.array() cannot convert that to a uniform array. pred_tensor = self.model.predict(dl, mode="quantiles") if isinstance(pred_tensor, torch.Tensor): pred_np = pred_tensor.cpu().numpy() else: pred_np = np.array(pred_tensor) # Take first sample: (pred_len, n_quantiles) if pred_np.ndim == 3: pred_for_format = pred_np[0] elif pred_np.ndim == 2: pred_for_format = pred_np else: pred_for_format = pred_np.reshape(1, -1) except Exception as exc: logger.error("TFT prediction failed: %s", exc) return {"error": str(exc)} result = format_prediction( pred_for_format, quantiles=self.cfg.model.quantiles, baseline_price=last_known_price, ) result["model_info"] = { "type": "TFT-ASRO", "checkpoint": self._checkpoint_path, "encoder_length": encoder_length, "prediction_length": prediction_length, "n_features_unknown": len(tv_unknown), "n_features_known": len(tv_known), } result["generated_at"] = datetime.now(timezone.utc).isoformat() result["symbol"] = symbol return result def _get_last_price(self, session, symbol: str) -> float: """Fetch the latest close price from the database.""" from app.models import PriceBar row = ( session.query(PriceBar.close) .filter(PriceBar.symbol == symbol) .order_by(PriceBar.date.desc()) .first() ) return float(row.close) if row else 1.0 def get_model_metadata(self, session) -> Optional[Dict]: """Load persisted TFT model metadata from DB.""" from app.models import TFTModelMetadata meta = ( session.query(TFTModelMetadata) .filter(TFTModelMetadata.symbol == self.cfg.feature_store.target_symbol) .first() ) if meta is None: return None return { "symbol": meta.symbol, "trained_at": meta.trained_at.isoformat() if meta.trained_at else None, "checkpoint_path": meta.checkpoint_path, "config": json.loads(meta.config_json) if meta.config_json else {}, "metrics": json.loads(meta.metrics_json) if meta.metrics_json else {}, } # --------------------------------------------------------------------------- # Module-level convenience # --------------------------------------------------------------------------- _predictor: Optional[TFTPredictor] = None def get_tft_predictor(cfg: Optional[TFTASROConfig] = None) -> TFTPredictor: """Singleton-style access to the TFT predictor.""" global _predictor if _predictor is None: _predictor = TFTPredictor(cfg=cfg) return _predictor def generate_tft_analysis(session, symbol: str = "HG=F") -> Dict[str, Any]: """ High-level API for generating a TFT-ASRO analysis report. Designed to mirror the interface of the existing ``app.inference.generate_analysis_report``. """ predictor = get_tft_predictor() prediction = predictor.predict(session, symbol) if "error" in prediction: return prediction metadata = predictor.get_model_metadata(session) # Direction based on T+1 (most reliable signal) median_ret = prediction.get("predicted_return_median", 0) if median_ret > 0.005: direction = "BULLISH" elif median_ret < -0.005: direction = "BEARISH" else: direction = "NEUTRAL" # Weekly trend based on T+5 (end-of-horizon) weekly_ret = prediction.get("weekly_return", median_ret) if weekly_ret > 0.005: weekly_trend = "BULLISH" elif weekly_ret < -0.005: weekly_trend = "BEARISH" else: weekly_trend = "NEUTRAL" vol = prediction.get("volatility_estimate", 0) if vol > 0.02: risk_level = "HIGH" elif vol > 0.01: risk_level = "MEDIUM" else: risk_level = "LOW" import math def _sanitize_floats(obj: Any) -> Any: if isinstance(obj, float): if math.isnan(obj) or math.isinf(obj): return None return obj elif isinstance(obj, dict): return {k: _sanitize_floats(v) for k, v in obj.items()} elif isinstance(obj, (list, tuple)): return type(obj)(_sanitize_floats(v) for v in obj) return obj raw_result = { "symbol": symbol, "model_type": "TFT-ASRO", "direction": direction, "weekly_trend": weekly_trend, "risk_level": risk_level, "prediction": prediction, "model_metadata": metadata, "generated_at": datetime.now(timezone.utc).isoformat(), } return _sanitize_floats(raw_result)