will702 committed on
Commit
7af21f9
·
1 Parent(s): 65a23c3

fix: track app/models/ package, ignore only top-level /models/

Browse files
.gitignore CHANGED
@@ -2,7 +2,7 @@ venv/
2
  __pycache__/
3
  *.pyc
4
  .env
5
- models/
6
  *.ckpt
7
  *.pt
8
  checkpoints/
 
2
  __pycache__/
3
  *.pyc
4
  .env
5
+ /models/
6
  *.ckpt
7
  *.pt
8
  checkpoints/
app/models/__init__.py ADDED
File without changes
app/models/ddg_da.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DDG-DA: Data Distribution Generation for Predictable Concept Drift Adaptation.
3
+
4
+ Implements:
5
+ 1. DriftPredictorMLP — small MLP that predicts the next distribution snapshot
6
+ 2. DDGDAPredictor — orchestrator: drift detection + drift score reporting
7
+
8
+ Note: DDG-DA head fine-tuning is model-architecture-specific and is handled
9
+ separately. This module provides drift detection that works with any TFT backend.
10
+
11
+ Reference: "DDG-DA: Data Distribution Generation for Predictable Concept Drift
12
+ Adaptation" (Data-Centric AI workshop).
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import os
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ from typing import Optional
21
+
22
+ from app.services.concept_drift import (
23
+ SNAPSHOT_DIM,
24
+ K_HISTORY,
25
+ SNAPSHOT_WINDOW,
26
+ DriftState,
27
+ compute_snapshot,
28
+ extract_snapshots_from_series,
29
+ compute_drift_score,
30
+ )
31
+
32
+
33
+ # ── 1. Drift Predictor MLP ────────────────────────────────────────────────────
34
+
35
class DriftPredictorMLP(nn.Module):
    """
    Predicts the NEXT distribution snapshot from the last K snapshots.

    Input:  (batch, k_history * snapshot_dim)
    Output: (batch, snapshot_dim)

    The network is a single hidden layer (Linear -> ELU -> Dropout -> Linear).
    """

    def __init__(self, k_history: int = K_HISTORY, snapshot_dim: int = SNAPSHOT_DIM, hidden: int = 128):
        super().__init__()
        # Keep the construction-time sizes so predict_next() validates against
        # this instance rather than against the module-level defaults.
        self.k_history = k_history
        self.snapshot_dim = snapshot_dim
        self.input_dim = k_history * snapshot_dim
        self.net = nn.Sequential(
            nn.Linear(self.input_dim, hidden),
            nn.ELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, snapshot_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a flattened history batch to the predicted next snapshot."""
        return self.net(x)

    def predict_next(self, history_snapshots: np.ndarray) -> np.ndarray:
        """
        Predict the next snapshot from a (k, snapshot_dim) history.

        Only the most recent ``k_history`` rows are used. If fewer rows are
        available the history is left-padded with zeros.

        Args:
            history_snapshots: 2-D array with one snapshot per row, oldest first.
        Returns:
            1-D float32 array of length ``snapshot_dim``.
        Raises:
            ValueError: if the input is not 2-D with ``snapshot_dim`` columns.
        """
        history = np.asarray(history_snapshots, dtype=np.float32)
        if history.ndim != 2 or history.shape[1] != self.snapshot_dim:
            raise ValueError(
                f"history_snapshots must have shape (k, {self.snapshot_dim}), got {history.shape}"
            )

        # Keep only the newest K rows so an over-long history cannot break the
        # fixed input width of the first Linear layer.
        if history.shape[0] > self.k_history:
            history = history[history.shape[0] - self.k_history:]

        missing = self.k_history - history.shape[0]
        if missing > 0:
            padding = np.zeros((missing, self.snapshot_dim), dtype=np.float32)
            history = np.concatenate([padding, history], axis=0)

        x = torch.tensor(history.flatten()[None], dtype=torch.float32)

        # Inference must be deterministic: disable dropout for the call and
        # restore whatever mode the module was in afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                out = self.net(x).squeeze(0).numpy()
        finally:
            self.train(was_training)
        return out
72
+
73
+
74
+ # ── 2. DDG-DA Predictor (Orchestrator) ───────────────────────────────────────
75
+
76
+
77
class DDGDAPredictor:
    """
    Orchestrates concept drift detection.

    - Measures current feature distribution (snapshot)
    - Scores drift vs. historical reference
    - Reports drift_score and drift_detected for confidence adjustment
    """

    def __init__(
        self,
        model_path: str,
        k_history: int = K_HISTORY,
    ):
        self.mlp = DriftPredictorMLP(k_history=k_history)
        self.k_history = k_history
        # True only once trained weights were loaded; adapt() uses this to
        # avoid reporting predictions from a randomly initialised network.
        self.weights_loaded = False
        self._load_mlp(model_path)

    def _load_mlp(self, model_path: str) -> None:
        """Load MLP weights from ``model_path`` if present; failures are logged, not raised."""
        if os.path.exists(model_path):
            try:
                state = torch.load(model_path, map_location="cpu", weights_only=True)
                self.mlp.load_state_dict(state)
                self.weights_loaded = True
            except Exception as e:
                print(f"[ddg_da] Could not load MLP weights: {e}")
        # Always switch to inference mode, including when loading failed, so
        # dropout can never make later forward passes non-deterministic.
        self.mlp.eval()

    # ── Public API ────────────────────────────────────────────────────────────

    def adapt(self, features: np.ndarray) -> DriftState:
        """
        Assess concept drift from the current feature distribution.

        Args:
            features: (T, N_FEATURES) feature matrix for the current symbol.
                      Should cover at least SNAPSHOT_WINDOW rows.
        Returns:
            DriftState with snapshot, drift_score and drift_detected.
            predicted_next_snapshot is only supplied when trained MLP weights
            were loaded; otherwise it is left at DriftState's default, as in
            the early-return paths.
        """
        if len(features) < SNAPSHOT_WINDOW:
            # Not enough rows for even one snapshot: report "no drift".
            return DriftState(
                snapshot=np.zeros(SNAPSHOT_DIM, dtype=np.float32),
                drift_score=0.0,
                drift_detected=False,
            )

        current_snap = compute_snapshot(features, window=SNAPSHOT_WINDOW)
        all_snaps = extract_snapshots_from_series(features, window=SNAPSHOT_WINDOW)
        if len(all_snaps) < 2:
            # No historical reference to compare the current snapshot with.
            return DriftState(snapshot=current_snap, drift_score=0.0, drift_detected=False)

        # Reference history excludes the last extracted snapshot.
        drift_score, drift_detected = compute_drift_score(current_snap, all_snaps[:-1])

        state_kwargs = {
            "snapshot": current_snap,
            "drift_score": drift_score,
            "drift_detected": drift_detected,
        }
        if self.weights_loaded:
            history_for_mlp = all_snaps[-self.k_history:]
            state_kwargs["predicted_next_snapshot"] = self.mlp.predict_next(history_for_mlp)

        return DriftState(**state_kwargs)
139
+
140
+
141
+ # ── Module-level loader with caching ─────────────────────────────────────────
142
+
143
# Process-wide cache used by load_ddg_da(): the predictor instance and the
# model path it was built from. A call with a different path rebuilds it.
_ddg_da: Optional[DDGDAPredictor] = None
_ddg_da_path_cached: Optional[str] = None
145
+
146
+
147
def _maybe_download_ddg_da(model_path: str) -> bool:
    """
    Download ddg_da.pt from HF Hub if not present locally.

    Args:
        model_path: Local path where the weights file is expected.
    Returns:
        True if the file exists at ``model_path`` when the call finishes,
        False if the repo/token are not configured or the download failed.
    """
    if os.path.exists(model_path):
        return True
    import app.config as cfg
    if not cfg.MODEL_REPO or not cfg.HF_TOKEN:
        return False
    try:
        import shutil
        from huggingface_hub import hf_hub_download

        local = hf_hub_download(
            repo_id=cfg.MODEL_REPO,
            filename="ddg_da.pt",
            token=cfg.HF_TOKEN,
            # A bare filename has an empty dirname; use the current directory.
            local_dir=os.path.dirname(model_path) or ".",
        )
        # Compare by file identity, not by string: the hub may return a
        # differently spelled path (absolute vs. relative) to the very same
        # file, and copying a file onto itself raises shutil.SameFileError.
        already_in_place = os.path.exists(model_path) and os.path.samefile(local, model_path)
        if not already_in_place:
            shutil.copy2(local, model_path)
        return os.path.exists(model_path)
    except Exception as e:
        print(f"[ddg_da] Could not download from HF Hub: {e}")
        return False
169
+
170
+
171
def load_ddg_da(model_path: str) -> Optional[DDGDAPredictor]:
    """
    Return the DDG-DA drift predictor for ``model_path``, building it on first use.

    The instance is cached at module level and reused while the same path is
    requested. Returns None (without raising) when ddg_da.pt cannot be found
    or the predictor cannot be constructed — the base TFT still works.
    """
    global _ddg_da, _ddg_da_path_cached

    cache_hit = _ddg_da is not None and _ddg_da_path_cached == model_path
    if cache_hit:
        return _ddg_da

    # Best-effort fetch; the existence check below decides what happens next.
    _maybe_download_ddg_da(model_path)
    if not os.path.exists(model_path):
        return None

    try:
        predictor = DDGDAPredictor(model_path=model_path)
    except Exception as e:
        print(f"[ddg_da] Failed to initialize DDGDAPredictor: {e}")
        return None

    _ddg_da, _ddg_da_path_cached = predictor, model_path
    return predictor
app/models/embeddings.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Stock similarity via 16-dimensional feature embeddings.
3
+ Stored in Supabase as JSONB float arrays (no pgvector required).
4
+ Cosine similarity computed here for the /similar endpoint.
5
+ """
6
+ import numpy as np
7
+ from typing import Optional
8
+ import pandas as pd
9
+
10
+
11
EMBED_DIM = 16


def _safe_norm(v: np.ndarray) -> float:
    """Return the standard deviation of *v*, or 1.0 if it is zero, undefined or not finite."""
    if np.size(v) == 0:
        return 1.0
    spread = float(np.std(v))
    if spread == 0.0 or not np.isfinite(spread):
        return 1.0
    return spread


def compute_embedding(
    closes: np.ndarray,
    volumes: np.ndarray,
    sector_id: int = 0,
) -> np.ndarray:
    """
    16-dim embedding per stock:
      [0]     sector (normalised and clamped to 0-1)
      [1]     annualised return (capped ±100%)
      [2]     annualised volatility
      [3-6]   return autocorrelations lag 1-4
      [7]     smoothness proxy: correlation of price with its own 20-day rolling mean
      [8]     volume trend: mean of last 20 bars vs. the 20 before, as a ratio minus 1 (capped ±1)
      [9]     max drawdown
      [10]    skewness of returns
      [11]    excess kurtosis of returns
      [12-15] return quantiles (q25, q50, q75, q90) divided by volatility (capped ±3)

    Args:
        closes: Close prices, oldest first (array or list).
        volumes: Traded volumes aligned with ``closes`` (array or list).
        sector_id: Integer sector code; divided by 12 and clamped to [0, 1].
    Returns:
        float32 vector of length EMBED_DIM. All zeros when fewer than 30
        closes are given. Never contains NaN or inf: undefined components are
        stored as 0 so the vector is safe to persist and to compare.
    """
    vec = np.zeros(EMBED_DIM, dtype=np.float32)

    # Accept plain lists as well as arrays (list + float would raise).
    closes = np.asarray(closes)
    volumes = np.asarray(volumes)

    if len(closes) < 30:
        return vec

    # Constant or invalid series make several statistics undefined (0/0);
    # those are handled explicitly below, so silence the numeric warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        ret = np.diff(np.log(closes + 1e-9))
        T = len(ret)

        # [0] sector normalised 0-1 (max ~12 sectors on IDX)
        vec[0] = float(np.clip(sector_id / 12.0, 0.0, 1.0))

        # [1] annualised return, capped
        ann_ret = np.mean(ret) * 252
        vec[1] = float(np.clip(ann_ret, -1.0, 1.0))

        # [2] annualised volatility
        vec[2] = float(np.std(ret) * np.sqrt(252))

        # [3-6] autocorrelations lag 1-4
        for lag in range(1, 5):
            if T > lag + 1:
                corr = float(np.corrcoef(ret[:-lag], ret[lag:])[0, 1])
                vec[2 + lag] = corr if not np.isnan(corr) else 0.0

        # [7] correlation of price with its own 20-day rolling average
        trend = pd.Series(closes).rolling(20, min_periods=5).mean().dropna().values
        if len(trend) > 10:
            corr = float(np.corrcoef(closes[-len(trend):], trend)[0, 1])
            vec[7] = corr if not np.isnan(corr) else 0.0

        # [8] volume trend: mean of recent 20 vs older 20
        if len(volumes) >= 40:
            recent = float(np.mean(volumes[-20:]))
            older = float(np.mean(volumes[-40:-20])) or 1.0
            vec[8] = float(np.clip((recent / older) - 1, -1, 1))

        # [9] max drawdown
        # NOTE(review): `ret` holds log returns but is compounded here as if
        # it were simple returns (an approximation). Left unchanged so newly
        # computed vectors stay comparable with ones already stored.
        cum = np.cumprod(1 + ret)
        running_max = np.maximum.accumulate(cum)
        drawdowns = (cum - running_max) / (running_max + 1e-9)
        vec[9] = float(np.min(drawdowns))

        # [10] skewness
        mu = np.mean(ret)
        sigma = _safe_norm(ret)
        standardised = (ret - mu) / sigma
        vec[10] = float(np.clip(np.mean(standardised ** 3), -3, 3))

        # [11] excess kurtosis
        vec[11] = float(np.clip(np.mean(standardised ** 4) - 3, -3, 3))

        # [12-15] return quantiles normalised by vol (sigma is always > 0 here)
        qs = np.quantile(ret, [0.25, 0.5, 0.75, 0.90])
        vec[12:16] = np.clip(qs / sigma, -3, 3).astype(np.float32)

    # Bad ticks (NaN, non-positive prices) must not leak NaN/inf into the
    # stored embedding, where they would poison every cosine similarity.
    return np.nan_to_num(vec, nan=0.0, posinf=0.0, neginf=0.0)
92
+
93
+
94
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine of the angle between *a* and *b*.

    A small epsilon is added to the product of the norms, so a zero vector
    yields 0.0 rather than a division by zero.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return float(np.dot(a, b) / (norm_product + 1e-9))
97
+
98
+
99
def find_similar(
    target_embedding: list[float],
    all_embeddings: dict[str, list[float]],
    top_n: int = 6,
    exclude_symbol: Optional[str] = None,
) -> list[dict]:
    """
    Return the top_n most similar stocks by cosine similarity.

    Args:
        target_embedding: Embedding of the reference stock.
        all_embeddings: Mapping of symbol -> embedding for the candidates.
        top_n: Maximum number of results; values <= 0 give an empty list.
        exclude_symbol: Symbol to leave out (compared case-insensitively),
            typically the reference stock itself.
    Returns:
        List of ``{"symbol", "similarity"}`` dicts, most similar first, with
        similarity rounded to 4 decimals. Candidates whose embedding is
        malformed (non-numeric, wrong length) or yields a non-finite score are
        skipped so a single bad row cannot break the whole ranking.
    """
    target = np.asarray(target_embedding, dtype=np.float32)
    excluded = exclude_symbol.upper() if exclude_symbol else None

    scores = []
    for symbol, emb in all_embeddings.items():
        if excluded is not None and symbol.upper() == excluded:
            continue
        try:
            candidate = np.asarray(emb, dtype=np.float32)
        except (TypeError, ValueError):
            # Non-numeric or ragged data coming from storage.
            continue
        if candidate.shape != target.shape:
            continue
        sim = cosine_similarity(target, candidate)
        if not np.isfinite(sim):
            # NaN compares false with everything and would corrupt the sort.
            continue
        scores.append({"symbol": symbol, "similarity": round(sim, 4)})

    scores.sort(key=lambda x: x["similarity"], reverse=True)
    # A negative top_n would otherwise slice from the wrong end of the list.
    return scores[:max(top_n, 0)]
app/models/tft_predictor.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ pytorch-forecasting TFT inference for IDX stock price prediction.
3
+
4
+ Loads from Lightning checkpoint (.ckpt) produced by train_colab.py.
5
+ Uses pytorch-forecasting's TimeSeriesDataSet + TemporalFusionTransformer.
6
+ """
7
+ import os
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ from datetime import datetime, timedelta
12
+ from typing import Optional
13
+
14
+ from app.services.feature_engineer import SEQUENCE_LEN, FEATURE_COLS, build_features
15
+
16
# Number of future steps requested from the model; predict_quantiles() clamps
# its `days` argument to this value.
FORECAST_HORIZON = 30
ENCODER_LENGTH = SEQUENCE_LEN  # 60
# Quantile levels; predict_quantiles() reads the three output columns as
# lower bound, median and upper bound in this order.
QUANTILES = [0.1, 0.5, 0.9]
N_QUANTILES = len(QUANTILES)

TARGET = "close_norm"
# Known reals are filled in for future rows too; unknown reals only get
# placeholder values there (see _build_inference_df).
KNOWN_REALS = ["day_sin", "day_cos", "month_sin", "month_cos"]
UNKNOWN_REALS = ["close_norm", "volume_norm", "rsi", "macd_norm", "bb_width", "atr_norm", "obv_norm"]

# Column index lookup for build_features() output
_FEAT_IDX = {col: i for i, col in enumerate(FEATURE_COLS)}

# ── Model / params caching ────────────────────────────────────────────────────

# Each cached object is stored together with the path it was loaded from;
# the loaders reload when called with a different path.
_model = None
_model_path_cached: Optional[str] = None
_ds_params: Optional[dict] = None
_ds_params_path_cached: Optional[str] = None
34
+
35
+
36
def _maybe_download(filename: str, local_path: str) -> bool:
    """
    Download a file from HF Hub if not present locally.

    Args:
        filename: Name of the file inside the model repository.
        local_path: Local path where the file is expected.
    Returns:
        True if the file exists at ``local_path`` when the call finishes,
        False if the repo/token are not configured or the download failed.
    """
    if os.path.exists(local_path):
        return True
    import app.config as cfg
    if not cfg.MODEL_REPO or not cfg.HF_TOKEN:
        return False
    try:
        import shutil
        from huggingface_hub import hf_hub_download

        local = hf_hub_download(
            repo_id=cfg.MODEL_REPO,
            filename=filename,
            token=cfg.HF_TOKEN,
            # A bare filename has an empty dirname; use the current directory.
            local_dir=os.path.dirname(local_path) or ".",
        )
        # Compare by file identity, not by string: the hub may return a
        # differently spelled path (absolute vs. relative) to the very same
        # file, and copying a file onto itself raises shutil.SameFileError.
        already_in_place = os.path.exists(local_path) and os.path.samefile(local, local_path)
        if not already_in_place:
            shutil.copy2(local, local_path)
        return os.path.exists(local_path)
    except Exception as e:
        print(f"[tft] Could not download {filename} from HF Hub: {e}")
        return False
58
+
59
+
60
def load_model(model_path: str):
    """
    Return the pytorch-forecasting TFT stored at ``model_path``, loading it on first use.

    The model is loaded from a Lightning checkpoint onto the CPU, put in eval
    mode and cached at module level for subsequent calls with the same path.

    Raises:
        FileNotFoundError: if the checkpoint is missing after the download attempt.
    """
    global _model, _model_path_cached

    already_loaded = _model is not None and _model_path_cached == model_path
    if already_loaded:
        return _model

    # Best-effort fetch; the existence check below is what decides.
    _maybe_download("tft_stock.ckpt", model_path)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

    # Imported lazily so the module can be imported without pytorch-forecasting.
    from pytorch_forecasting import TemporalFusionTransformer

    loaded = TemporalFusionTransformer.load_from_checkpoint(model_path, map_location="cpu")
    loaded.eval()

    _model, _model_path_cached = loaded, model_path
    print(f"[tft] Loaded pytorch-forecasting TFT from {model_path}")
    return loaded
77
+
78
+
79
def load_dataset_params(params_path: str) -> dict:
    """
    Return the TimeSeriesDataSet parameters saved during training, loading them on first use.

    The result is cached at module level for subsequent calls with the same path.

    Raises:
        FileNotFoundError: if the file is missing after the download attempt.
    """
    global _ds_params, _ds_params_path_cached

    already_loaded = _ds_params is not None and _ds_params_path_cached == params_path
    if already_loaded:
        return _ds_params

    _maybe_download("dataset_params.pt", params_path)
    if not os.path.exists(params_path):
        raise FileNotFoundError(f"Dataset params not found: {params_path}")

    # NOTE(review): weights_only=False unpickles arbitrary Python objects from
    # a file that may just have been downloaded. Only safe while the model
    # repository is fully trusted.
    loaded = torch.load(params_path, map_location="cpu", weights_only=False)

    _ds_params, _ds_params_path_cached = loaded, params_path
    print(f"[tft] Loaded dataset params from {params_path}")
    return loaded
94
+
95
+
96
+ # ── Inference DataFrame builder ───────────────────────────────────────────────
97
+
98
def _build_inference_df(
    closes: np.ndarray,
    volumes: np.ndarray,
    timestamps: np.ndarray,
    symbol: str,
) -> pd.DataFrame:
    """
    Build a DataFrame with ENCODER_LENGTH encoder rows + FORECAST_HORIZON future rows.

    The encoder rows contain real feature values; future rows carry real
    calendar features (known reals) and placeholder values for the unknown reals.

    Args:
        closes, volumes, timestamps: Aligned series passed to build_features();
            timestamps are converted with int() and read as Unix seconds (UTC).
        symbol: Value written to the "ticker" column of every row.
    Returns:
        DataFrame with columns ticker, time_idx, date plus all known and
        unknown reals; time_idx runs from 0 to ENCODER_LENGTH + FORECAST_HORIZON - 1.
    Raises:
        ValueError: if fewer than ENCODER_LENGTH feature rows or timestamps
            are available.
    """
    # Local import: the module header only imports datetime and timedelta.
    from datetime import timezone

    features = build_features(closes, volumes, timestamps)  # (T, 11)
    if len(features) < ENCODER_LENGTH:
        raise ValueError(f"Need at least {ENCODER_LENGTH} candles, got {len(features)}")

    features = features[-ENCODER_LENGTH:]
    ts_slice = timestamps[-len(features):]
    if len(ts_slice) != len(features):
        # zip() below would otherwise silently drop encoder rows.
        raise ValueError(
            f"Need at least {ENCODER_LENGTH} timestamps, got {len(ts_slice)}"
        )

    # Timestamps → naive UTC datetimes (same values utcfromtimestamp() gave,
    # without relying on that deprecated function).
    dates = [
        datetime.fromtimestamp(int(ts), tz=timezone.utc).replace(tzinfo=None)
        for ts in ts_slice
    ]

    # Encoder rows: real feature values looked up by column name.
    feature_cols = UNKNOWN_REALS + KNOWN_REALS
    encoder_rows = [
        {
            "ticker": symbol,
            "time_idx": i,
            "date": dt,
            **{col: float(feat_row[_FEAT_IDX[col]]) for col in feature_cols},
        }
        for i, (feat_row, dt) in enumerate(zip(features, dates))
    ]
    encoder_df = pd.DataFrame(encoder_rows)

    # Future decoder rows (known reals computed from the calendar).
    # NOTE(review): future dates advance by calendar days, so weekends are
    # included while the weekday cycle has period 5 — confirm this matches
    # how build_features() and the training data encode days.
    last_date = dates[-1]
    future_rows = []
    for step in range(1, FORECAST_HORIZON + 1):
        future_date = last_date + timedelta(days=step)
        day_angle = 2 * np.pi * future_date.weekday() / 5
        month_angle = 2 * np.pi * future_date.month / 12
        future_rows.append({
            "ticker": symbol,
            "time_idx": ENCODER_LENGTH + step - 1,
            "date": future_date,
            # Unknown reals: placeholder values for future steps
            TARGET: 0.0,
            "volume_norm": 0.0,
            "rsi": 0.5,
            "macd_norm": 0.0,
            "bb_width": 0.0,
            "atr_norm": 0.0,
            "obv_norm": 0.0,
            # Known reals: actual calendar features
            "day_sin": float(np.sin(day_angle)),
            "day_cos": float(np.cos(day_angle)),
            "month_sin": float(np.sin(month_angle)),
            "month_cos": float(np.cos(month_angle)),
        })

    return pd.concat([encoder_df, pd.DataFrame(future_rows)], ignore_index=True)
158
+
159
+
160
+ # ── Inference ─────────────────────────────────────────────────────────────────
161
+
162
def predict_quantiles(
    closes: np.ndarray,
    volumes: np.ndarray,
    timestamps: np.ndarray,
    days: int,
    model_path: str,
    dataset_params_path: Optional[str] = None,
    symbol: str = "UNKNOWN",
) -> dict:
    """
    Run pytorch-forecasting TFT inference for a `days` forecast horizon.

    Args:
        closes, volumes, timestamps: Aligned price history, oldest first.
        days: Requested horizon; clamped to [1, FORECAST_HORIZON].
        model_path: Path to the Lightning checkpoint.
        dataset_params_path: Path to the saved TimeSeriesDataSet parameters.
            Defaults to ``dataset_params.pt`` next to ``model_path``.
        symbol: Ticker written into the inference frame.
    Returns:
        Dict with the median forecast ("predictions"), "lower_bound" /
        "upper_bound" quantile paths as price levels, plus target price,
        trend label, percentage change, support/resistance and a fixed
        confidence value.
    Raises:
        FileNotFoundError: if the checkpoint or dataset params are missing.
        ValueError: if there is no price history, too little history for the
            encoder, or the model returns no predictions.
    """
    if dataset_params_path is None:
        # Derive from the checkpoint's directory. A string replace on the
        # file name is a no-op for any other checkpoint name and would make
        # the params path equal to the checkpoint path.
        dataset_params_path = os.path.join(os.path.dirname(model_path), "dataset_params.pt")

    model = load_model(model_path)
    ds_params = load_dataset_params(dataset_params_path)

    if len(closes) == 0:
        raise ValueError("closes must contain at least one price")

    days = max(1, min(days, FORECAST_HORIZON))
    current_price = float(closes[-1])
    # Statistics of the last 30 closes are used to map the model output back
    # to price levels.
    roll_mean = float(np.mean(closes[-30:]))
    roll_std = float(np.std(closes[-30:])) or 1.0

    # Build inference DataFrame
    full_df = _build_inference_df(closes, volumes, timestamps, symbol)

    # Reconstruct TimeSeriesDataSet from training-time parameters so the
    # inference frame is encoded the same way as the training data.
    from pytorch_forecasting import TimeSeriesDataSet

    pred_ds = TimeSeriesDataSet.from_parameters(
        ds_params,
        full_df,
        predict=True,               # one sample per group, from the end of data
        stop_randomization=True,
        min_encoder_length=ENCODER_LENGTH,
        max_encoder_length=ENCODER_LENGTH,
        min_prediction_length=FORECAST_HORIZON,
        max_prediction_length=FORECAST_HORIZON,
        min_prediction_idx=None,
    )
    pred_dl = pred_ds.to_dataloader(train=False, batch_size=1, num_workers=0)

    # Predict — expected shape (1, FORECAST_HORIZON, N_QUANTILES)
    with torch.no_grad():
        raw = model.predict(pred_dl, mode="quantiles", return_x=False)

    # Handle both tensor and list returns; drop a leading batch dimension in
    # either case so `preds` is always (FORECAST_HORIZON, N_QUANTILES).
    if isinstance(raw, torch.Tensor):
        preds = raw.detach().cpu().numpy()
    else:
        preds = np.asarray(raw[0])
    if preds.ndim == 3:
        preds = preds[0]

    preds = preds[:days]   # slice to requested horizon
    if len(preds) == 0:
        raise ValueError("model returned no predictions")

    def _to_prices(column) -> list:
        # Denormalize z-score → non-negative price level, rounded to 2 decimals.
        return [max(0.0, round(float(z * roll_std + roll_mean), 2)) for z in column]

    q10 = _to_prices(preds[:, 0])
    q50 = _to_prices(preds[:, 1])
    q90 = _to_prices(preds[:, 2])

    # Enforce monotonic bounds (q10 ≤ q50 ≤ q90). Iterating over the values
    # themselves stays safe if the model returned fewer rows than `days`.
    q10 = [min(lo, mid) for lo, mid in zip(q10, q50)]
    q90 = [max(hi, mid) for hi, mid in zip(q90, q50)]

    final_price = q50[-1]
    if final_price > current_price * 1.005:
        trend = "bullish"
    elif final_price < current_price * 0.995:
        trend = "bearish"
    else:
        trend = "sideways"

    # A zero last close would otherwise raise ZeroDivisionError.
    if current_price:
        change_pct = (final_price - current_price) / current_price * 100
    else:
        change_pct = 0.0

    return {
        "method": "tft",
        "predictions": q50,
        "lower_bound": q10,
        "upper_bound": q90,
        "target_price": final_price,
        "trend": trend,
        "change_pct": round(change_pct, 2),
        "confidence": 72,
        "support": round(min(q10), 2),
        "resistance": round(max(q90), 2),
        "feature_importance": {},  # TFT attention weights available via interpret_output() if needed
    }