fix: track app/models/ package, ignore only top-level /models/
- .gitignore +1 -1
- app/models/__init__.py +0 -0
- app/models/ddg_da.py +191 -0
- app/models/embeddings.py +114 -0
- app/models/tft_predictor.py +249 -0
.gitignore
CHANGED
@@ -2,7 +2,7 @@ venv/
 __pycache__/
 *.pyc
 .env
-models/
+/models/
 *.ckpt
 *.pt
 checkpoints/
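Note on the pattern change: a gitignore pattern without a leading slash, such as models/, matches a directory of that name at any depth, so it was also hiding the tracked app/models/ package. Anchoring it as /models/ restricts the ignore rule to the top-level models/ directory (local checkpoints), while app/models/ stays under version control.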
app/models/__init__.py
ADDED
(empty file)
app/models/ddg_da.py
ADDED
@@ -0,0 +1,191 @@
"""
DDG-DA: Data Distribution Generation for Predictable Concept Drift Adaptation.

Implements:
1. DriftPredictorMLP – small MLP that predicts the next distribution snapshot
2. DDGDAPredictor – orchestrator: drift detection + drift score reporting

Note: DDG-DA head fine-tuning is model-architecture-specific and is handled
separately. This module provides drift detection that works with any TFT backend.

Reference: "DDG-DA: Data Distribution Generation for Predictable Concept Drift
Adaptation" (Data-Centric AI workshop).
"""
from __future__ import annotations

import os
import numpy as np
import torch
import torch.nn as nn
from typing import Optional

from app.services.concept_drift import (
    SNAPSHOT_DIM,
    K_HISTORY,
    SNAPSHOT_WINDOW,
    DriftState,
    compute_snapshot,
    extract_snapshots_from_series,
    compute_drift_score,
)


# ── 1. Drift Predictor MLP ────────────────────────────────────────────────────

class DriftPredictorMLP(nn.Module):
    """
    Predicts the NEXT distribution snapshot from the last K snapshots.

    Input:  (batch, K * SNAPSHOT_DIM) = (batch, 8 * 44 = 352)
    Output: (batch, SNAPSHOT_DIM) = (batch, 44)

    ~51K parameters, ~200KB model file on disk.
    """

    def __init__(self, k_history: int = K_HISTORY, snapshot_dim: int = SNAPSHOT_DIM, hidden: int = 128):
        super().__init__()
        self.input_dim = k_history * snapshot_dim
        self.net = nn.Sequential(
            nn.Linear(self.input_dim, hidden),
            nn.ELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, snapshot_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def predict_next(self, history_snapshots: np.ndarray) -> np.ndarray:
        """
        Predict next snapshot from (K, 44) history → (44,) prediction.
        Pads with zeros if fewer than K history snapshots are available.
        """
        k = history_snapshots.shape[0]
        required = self.input_dim // SNAPSHOT_DIM
        if k < required:
            padding = np.zeros((required - k, SNAPSHOT_DIM), dtype=np.float32)
            history_snapshots = np.concatenate([padding, history_snapshots], axis=0)
        x = torch.tensor(history_snapshots.flatten()[None], dtype=torch.float32)  # (1, 352)
        with torch.no_grad():
            out = self.net(x).squeeze(0).numpy()  # (44,)
        return out


# ── 2. DDG-DA Predictor (Orchestrator) ───────────────────────────────────────


class DDGDAPredictor:
    """
    Orchestrates concept drift detection.

    - Measures current feature distribution (snapshot)
    - Scores drift vs. historical reference
    - Reports drift_score and drift_detected for confidence adjustment
    """

    def __init__(
        self,
        model_path: str,
        k_history: int = K_HISTORY,
    ):
        self.mlp = DriftPredictorMLP(k_history=k_history)
        self.k_history = k_history
        self._load_mlp(model_path)

    def _load_mlp(self, model_path: str) -> None:
        if os.path.exists(model_path):
            try:
                state = torch.load(model_path, map_location="cpu", weights_only=True)
                self.mlp.load_state_dict(state)
                self.mlp.eval()
            except Exception as e:
                print(f"[ddg_da] Could not load MLP weights: {e}")

    # ── Public API ────────────────────────────────────────────────────────────

    def adapt(self, features: np.ndarray) -> DriftState:
        """
        Assess concept drift from the current feature distribution.

        Args:
            features: (T, N_FEATURES) feature matrix for the current symbol.
                      Should cover at least SNAPSHOT_WINDOW rows.
        Returns:
            DriftState with drift_score, drift_detected, and predicted_next_snapshot.
        """
        if len(features) < SNAPSHOT_WINDOW:
            return DriftState(
                snapshot=np.zeros(SNAPSHOT_DIM, dtype=np.float32),
                drift_score=0.0,
                drift_detected=False,
            )

        current_snap = compute_snapshot(features, window=SNAPSHOT_WINDOW)
        all_snaps = extract_snapshots_from_series(features, window=SNAPSHOT_WINDOW)
        if len(all_snaps) < 2:
            return DriftState(snapshot=current_snap, drift_score=0.0, drift_detected=False)

        drift_score, drift_detected = compute_drift_score(current_snap, all_snaps[:-1])

        history_for_mlp = all_snaps[-self.k_history:]
        predicted_next = self.mlp.predict_next(history_for_mlp)

        return DriftState(
            snapshot=current_snap,
            drift_score=drift_score,
            drift_detected=drift_detected,
            predicted_next_snapshot=predicted_next,
        )


# ── Module-level loader with caching ─────────────────────────────────────────

_ddg_da: Optional[DDGDAPredictor] = None
_ddg_da_path_cached: Optional[str] = None


def _maybe_download_ddg_da(model_path: str) -> bool:
    """Download ddg_da.pt from HF Hub if not present locally."""
    if os.path.exists(model_path):
        return True
    import app.config as cfg
    if not cfg.MODEL_REPO or not cfg.HF_TOKEN:
        return False
    try:
        from huggingface_hub import hf_hub_download
        local = hf_hub_download(
            repo_id=cfg.MODEL_REPO,
            filename="ddg_da.pt",
            token=cfg.HF_TOKEN,
            local_dir=os.path.dirname(model_path),
        )
        if local != model_path:
            import shutil
            shutil.copy2(local, model_path)
        return os.path.exists(model_path)
    except Exception as e:
        print(f"[ddg_da] Could not download from HF Hub: {e}")
        return False


def load_ddg_da(model_path: str) -> Optional[DDGDAPredictor]:
    """
    Load (and cache) the DDG-DA drift predictor.
    Returns None gracefully if ddg_da.pt is absent – base TFT still works.
    """
    global _ddg_da, _ddg_da_path_cached
    if _ddg_da is not None and _ddg_da_path_cached == model_path:
        return _ddg_da

    _maybe_download_ddg_da(model_path)
    if not os.path.exists(model_path):
        return None

    try:
        predictor = DDGDAPredictor(model_path=model_path)
        _ddg_da = predictor
        _ddg_da_path_cached = model_path
        return _ddg_da
    except Exception as e:
        print(f"[ddg_da] Failed to initialize DDGDAPredictor: {e}")
        return None
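Usage note: a minimal sketch of how the loader and drift check are expected to be called. The model path, the feature-matrix shape, and the random data are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch; path and feature data are placeholders
import numpy as np
from app.models.ddg_da import load_ddg_da

predictor = load_ddg_da("models/ddg_da.pt")              # returns None if weights are unavailable
features = np.random.randn(120, 11).astype(np.float32)   # stand-in for engineered features, shape (T, N_FEATURES)
if predictor is not None:
    state = predictor.adapt(features)
    print(state.drift_score, state.drift_detected)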
app/models/embeddings.py
ADDED
@@ -0,0 +1,114 @@
"""
Stock similarity via 16-dimensional feature embeddings.
Stored in Supabase as JSONB float arrays (no pgvector required).
Cosine similarity computed here for the /similar endpoint.
"""
import numpy as np
from typing import Optional
import pandas as pd


EMBED_DIM = 16


def _safe_norm(v: np.ndarray) -> float:
    return float(np.std(v)) or 1.0


def compute_embedding(
    closes: np.ndarray,
    volumes: np.ndarray,
    sector_id: int = 0,
) -> np.ndarray:
    """
    16-dim embedding per stock:
      [0]     sector (normalised 0-1)
      [1]     annualised return (capped ±100%)
      [2]     annualised volatility
      [3-6]   return autocorrelations, lags 1-4
      [7]     beta proxy (corr with own 20-day trend, approx)
      [8]     volume trend (avg recent 20 vs. avg older 20)
      [9]     max drawdown
      [10]    skewness of returns
      [11]    kurtosis of returns
      [12-15] return quantiles (q25, q50, q75, q90), scaled by vol
    """
    vec = np.zeros(EMBED_DIM, dtype=np.float32)

    if len(closes) < 30:
        return vec

    ret = np.diff(np.log(closes + 1e-9))
    T = len(ret)

    # [0] sector normalised 0-1 (max ~12 sectors on IDX)
    vec[0] = min(sector_id / 12.0, 1.0)

    # [1] annualised return, capped
    ann_ret = np.mean(ret) * 252
    vec[1] = float(np.clip(ann_ret, -1.0, 1.0))

    # [2] annualised volatility
    vec[2] = float(np.std(ret) * np.sqrt(252))

    # [3-6] autocorrelations, lags 1-4
    for lag in range(1, 5):
        if T > lag + 1:
            corr = float(np.corrcoef(ret[:-lag], ret[lag:])[0, 1])
            vec[2 + lag] = corr if not np.isnan(corr) else 0.0

    # [7] beta proxy: correlation with its own 20-day rolling avg (smoother = lower beta)
    s = pd.Series(closes)
    trend = s.rolling(20, min_periods=5).mean().dropna().values
    if len(trend) > 10:
        corr = float(np.corrcoef(closes[-len(trend):], trend)[0, 1])
        vec[7] = corr if not np.isnan(corr) else 0.0

    # [8] volume trend: mean of recent 20 vs. older 20
    if len(volumes) >= 40:
        recent = float(np.mean(volumes[-20:]))
        older = float(np.mean(volumes[-40:-20])) or 1.0
        vec[8] = float(np.clip((recent / older) - 1, -1, 1))

    # [9] max drawdown
    cum = np.cumprod(1 + ret)
    running_max = np.maximum.accumulate(cum)
    drawdowns = (cum - running_max) / (running_max + 1e-9)
    vec[9] = float(np.min(drawdowns))

    # [10] skewness
    mu, sigma = np.mean(ret), np.std(ret) or 1
    vec[10] = float(np.clip(np.mean(((ret - mu) / sigma) ** 3), -3, 3))

    # [11] excess kurtosis
    vec[11] = float(np.clip(np.mean(((ret - mu) / sigma) ** 4) - 3, -3, 3))

    # [12-15] return quantiles normalised by vol
    if sigma > 0:
        qs = np.quantile(ret, [0.25, 0.5, 0.75, 0.90])
        vec[12:16] = np.clip(qs / sigma, -3, 3).astype(np.float32)

    return vec


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    denom = (np.linalg.norm(a) * np.linalg.norm(b)) + 1e-9
    return float(np.dot(a, b) / denom)


def find_similar(
    target_embedding: list[float],
    all_embeddings: dict[str, list[float]],
    top_n: int = 6,
    exclude_symbol: Optional[str] = None,
) -> list[dict]:
    """Return the top_n most similar stocks by cosine similarity."""
    target = np.array(target_embedding, dtype=np.float32)
    scores = []
    for symbol, emb in all_embeddings.items():
        if exclude_symbol and symbol.upper() == exclude_symbol.upper():
            continue
        sim = cosine_similarity(target, np.array(emb, dtype=np.float32))
        scores.append({"symbol": symbol, "similarity": round(sim, 4)})
    scores.sort(key=lambda x: x["similarity"], reverse=True)
    return scores[:top_n]
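Usage note: a small end-to-end sketch of compute_embedding and find_similar with synthetic prices and made-up symbols; illustrative only, since real embeddings are stored in and fetched from Supabase.

# Illustrative only: synthetic data, hypothetical symbols
import numpy as np
from app.models.embeddings import compute_embedding, find_similar

rng = np.random.default_rng(0)
closes = 1000.0 * np.cumprod(1 + rng.normal(0, 0.01, 250))   # synthetic daily closes
volumes = rng.integers(1_000, 100_000, 250).astype(float)    # synthetic daily volumes

emb = compute_embedding(closes, volumes, sector_id=3)             # (16,) float32 vector
universe = {"AAAA": emb.tolist(), "BBBB": (emb * 0.9).tolist()}   # normally loaded from Supabase
print(find_similar(emb.tolist(), universe, top_n=1, exclude_symbol="AAAA"))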
app/models/tft_predictor.py
ADDED
@@ -0,0 +1,249 @@
"""
pytorch-forecasting TFT inference for IDX stock price prediction.

Loads from Lightning checkpoint (.ckpt) produced by train_colab.py.
Uses pytorch-forecasting's TimeSeriesDataSet + TemporalFusionTransformer.
"""
import os
import numpy as np
import pandas as pd
import torch
from datetime import datetime, timedelta
from typing import Optional

from app.services.feature_engineer import SEQUENCE_LEN, FEATURE_COLS, build_features

FORECAST_HORIZON = 30
ENCODER_LENGTH = SEQUENCE_LEN  # 60
QUANTILES = [0.1, 0.5, 0.9]
N_QUANTILES = len(QUANTILES)

TARGET = "close_norm"
KNOWN_REALS = ["day_sin", "day_cos", "month_sin", "month_cos"]
UNKNOWN_REALS = ["close_norm", "volume_norm", "rsi", "macd_norm", "bb_width", "atr_norm", "obv_norm"]

# Column index lookup for build_features() output
_FEAT_IDX = {col: i for i, col in enumerate(FEATURE_COLS)}

# ── Model / params caching ────────────────────────────────────────────────────

_model = None
_model_path_cached: Optional[str] = None
_ds_params: Optional[dict] = None
_ds_params_path_cached: Optional[str] = None


def _maybe_download(filename: str, local_path: str) -> bool:
    """Download a file from HF Hub if not present locally."""
    if os.path.exists(local_path):
        return True
    import app.config as cfg
    if not cfg.MODEL_REPO or not cfg.HF_TOKEN:
        return False
    try:
        from huggingface_hub import hf_hub_download
        local = hf_hub_download(
            repo_id=cfg.MODEL_REPO,
            filename=filename,
            token=cfg.HF_TOKEN,
            local_dir=os.path.dirname(local_path),
        )
        if local != local_path:
            import shutil
            shutil.copy2(local, local_path)
        return os.path.exists(local_path)
    except Exception as e:
        print(f"[tft] Could not download {filename} from HF Hub: {e}")
        return False


def load_model(model_path: str):
    """Load and cache the pytorch-forecasting TFT from a Lightning checkpoint."""
    global _model, _model_path_cached
    if _model is not None and _model_path_cached == model_path:
        return _model

    _maybe_download("tft_stock.ckpt", model_path)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

    from pytorch_forecasting import TemporalFusionTransformer
    model = TemporalFusionTransformer.load_from_checkpoint(model_path, map_location="cpu")
    model.eval()
    _model = model
    _model_path_cached = model_path
    print(f"[tft] Loaded pytorch-forecasting TFT from {model_path}")
    return model


def load_dataset_params(params_path: str) -> dict:
    """Load and cache the TimeSeriesDataSet parameters saved during Colab training."""
    global _ds_params, _ds_params_path_cached
    if _ds_params is not None and _ds_params_path_cached == params_path:
        return _ds_params

    _maybe_download("dataset_params.pt", params_path)
    if not os.path.exists(params_path):
        raise FileNotFoundError(f"Dataset params not found: {params_path}")

    params = torch.load(params_path, map_location="cpu", weights_only=False)
    _ds_params = params
    _ds_params_path_cached = params_path
    print(f"[tft] Loaded dataset params from {params_path}")
    return params


# ── Inference DataFrame builder ───────────────────────────────────────────────

def _build_inference_df(
    closes: np.ndarray,
    volumes: np.ndarray,
    timestamps: np.ndarray,
    symbol: str,
) -> pd.DataFrame:
    """
    Build a DataFrame with ENCODER_LENGTH encoder rows + FORECAST_HORIZON future rows.
    The encoder rows contain real feature values; future rows have only known reals
    (day/month cyclicals) – the decoder does not use unknown future reals.
    """
    features = build_features(closes, volumes, timestamps)  # (T, 11)
    if len(features) < ENCODER_LENGTH:
        raise ValueError(f"Need at least {ENCODER_LENGTH} candles, got {len(features)}")

    features = features[-ENCODER_LENGTH:]
    ts_slice = timestamps[-len(features):]

    # Timestamps → Python datetimes
    dates = [datetime.utcfromtimestamp(int(ts)) for ts in ts_slice]

    # Build encoder rows
    rows = []
    for i, (feat_row, dt) in enumerate(zip(features, dates)):
        row: dict = {
            "ticker": symbol,
            "time_idx": i,
            "date": dt,
        }
        for col in UNKNOWN_REALS + KNOWN_REALS:
            row[col] = float(feat_row[_FEAT_IDX[col]])
        rows.append(row)

    encoder_df = pd.DataFrame(rows)

    # Build future decoder rows (known reals computed from calendar)
    last_date = dates[-1]
    future_rows = []
    for i in range(1, FORECAST_HORIZON + 1):
        future_date = last_date + timedelta(days=i)
        future_rows.append({
            "ticker": symbol,
            "time_idx": ENCODER_LENGTH + i - 1,
            "date": future_date,
            # Unknown reals: placeholder values (not used in decoder future steps)
            TARGET: 0.0,
            "volume_norm": 0.0,
            "rsi": 0.5,
            "macd_norm": 0.0,
            "bb_width": 0.0,
            "atr_norm": 0.0,
            "obv_norm": 0.0,
            # Known reals: actual calendar features
            "day_sin": float(np.sin(2 * np.pi * future_date.weekday() / 5)),
            "day_cos": float(np.cos(2 * np.pi * future_date.weekday() / 5)),
            "month_sin": float(np.sin(2 * np.pi * future_date.month / 12)),
            "month_cos": float(np.cos(2 * np.pi * future_date.month / 12)),
        })

    return pd.concat([encoder_df, pd.DataFrame(future_rows)], ignore_index=True)


# ── Inference ─────────────────────────────────────────────────────────────────

def predict_quantiles(
    closes: np.ndarray,
    volumes: np.ndarray,
    timestamps: np.ndarray,
    days: int,
    model_path: str,
    dataset_params_path: Optional[str] = None,
    symbol: str = "UNKNOWN",
) -> dict:
    """
    Run pytorch-forecasting TFT inference for a `days` forecast horizon.
    Returns quantile predictions as price levels (denormalized).
    """
    if dataset_params_path is None:
        dataset_params_path = model_path.replace("tft_stock.ckpt", "dataset_params.pt")

    model = load_model(model_path)
    ds_params = load_dataset_params(dataset_params_path)

    days = max(1, min(days, FORECAST_HORIZON))
    current_price = float(closes[-1])
    roll_mean = float(np.mean(closes[-30:]))
    roll_std = float(np.std(closes[-30:])) or 1.0

    # Build inference DataFrame
    full_df = _build_inference_df(closes, volumes, timestamps, symbol)

    # Reconstruct TimeSeriesDataSet from training-time parameters.
    # from_parameters() reuses the fitted categorical encoder (ticker → int),
    # so unknown tickers fall back to the UNK embedding gracefully.
    from pytorch_forecasting import TimeSeriesDataSet

    pred_ds = TimeSeriesDataSet.from_parameters(
        ds_params,
        full_df,
        predict=True,  # one sample per group, from the end of data
        stop_randomization=True,
        min_encoder_length=ENCODER_LENGTH,
        max_encoder_length=ENCODER_LENGTH,
        min_prediction_length=FORECAST_HORIZON,
        max_prediction_length=FORECAST_HORIZON,
        min_prediction_idx=None,
    )
    pred_dl = pred_ds.to_dataloader(train=False, batch_size=1, num_workers=0)

    # Predict → returns tensor of shape (1, FORECAST_HORIZON, N_QUANTILES)
    with torch.no_grad():
        raw = model.predict(pred_dl, mode="quantiles", return_x=False)

    # Handle both tensor and list returns
    if isinstance(raw, torch.Tensor):
        preds = raw.squeeze(0).cpu().numpy()  # (FORECAST_HORIZON, 3)
    else:
        preds = np.array(raw[0])  # (FORECAST_HORIZON, 3)

    preds = preds[:days]  # slice to requested horizon

    # Denormalize rolling z-score → price levels
    q10 = [max(0.0, round(float(z * roll_std + roll_mean), 2)) for z in preds[:, 0]]
    q50 = [max(0.0, round(float(z * roll_std + roll_mean), 2)) for z in preds[:, 1]]
    q90 = [max(0.0, round(float(z * roll_std + roll_mean), 2)) for z in preds[:, 2]]

    # Enforce monotonic bounds (q10 ≤ q50 ≤ q90)
    for i in range(days):
        q10[i] = min(q10[i], q50[i])
        q90[i] = max(q90[i], q50[i])

    final_price = q50[-1]
    trend = (
        "bullish" if final_price > current_price * 1.005
        else "bearish" if final_price < current_price * 0.995
        else "sideways"
    )
    change_pct = (final_price - current_price) / current_price * 100

    return {
        "method": "tft",
        "predictions": q50,
        "lower_bound": q10,
        "upper_bound": q90,
        "target_price": final_price,
        "trend": trend,
        "change_pct": round(change_pct, 2),
        "confidence": 72,
        "support": round(min(q10), 2),
        "resistance": round(max(q90), 2),
        "feature_importance": {},  # TFT attention weights available via interpret_output() if needed
    }
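Usage note: a hedged call sketch for predict_quantiles. The synthetic OHLCV arrays, the local model path, and the example ticker are assumptions; the call only succeeds when tft_stock.ckpt and dataset_params.pt exist locally or can be fetched from the configured HF Hub repo.

# Illustrative call; real closes/volumes/timestamps come from the app's data layer
import numpy as np
from app.models.tft_predictor import predict_quantiles

rng = np.random.default_rng(0)
closes = 4500.0 * np.cumprod(1 + rng.normal(0, 0.01, 120))      # >= ENCODER_LENGTH (60) daily closes
volumes = rng.integers(1_000_000, 5_000_000, 120).astype(float)
timestamps = 1_600_000_000 + np.arange(120) * 86_400            # daily unix timestamps (seconds)

result = predict_quantiles(
    closes, volumes, timestamps,
    days=14,
    model_path="models/tft_stock.ckpt",   # hypothetical local path
    symbol="BBCA",                        # example IDX ticker
)
print(result["trend"], result["target_price"], result["change_pct"])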