vmjn committed on
Commit
3e07fe0
·
verified ·
1 Parent(s): c4c3637

add Chronos, TimesFM, FinBERT endpoints — 4-model inference Space

Browse files
Files changed (2) hide show
  1. Dockerfile +4 -1
  2. app.py +256 -77
Dockerfile CHANGED
@@ -26,7 +26,10 @@ RUN pip install --user --no-cache-dir \
26
  "websockets>=13.0" \
27
  "einops>=0.7" \
28
  "safetensors>=0.4" \
29
- "tqdm>=4.66"
 
 
 
30
 
31
  COPY --chown=user . /home/user/app
32
 
 
26
  "websockets>=13.0" \
27
  "einops>=0.7" \
28
  "safetensors>=0.4" \
29
+ "tqdm>=4.66" \
30
+ "transformers>=4.40,<5.0" \
31
+ "chronos-forecasting>=1.5.2" \
32
+ "timesfm[torch]>=1.3.0"
33
 
34
  COPY --chown=user . /home/user/app
35
 
app.py CHANGED
@@ -1,99 +1,98 @@
1
- """Kronos-small financial forecast minimal MCP-friendly wrapper."""
2
- import os
3
- import pandas as pd
 
 
 
 
 
 
 
 
 
4
  import numpy as np
 
5
  import torch
6
  import yfinance as yf
7
  import gradio as gr
8
- from model import Kronos, KronosTokenizer, KronosPredictor
9
 
10
- # Kronos-small: 24.7M params, fits free CPU Space comfortably.
11
- TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base"
12
- MODEL_ID = "NeoQuasar/Kronos-small"
13
- DEVICE = "cpu"
14
- MAX_CONTEXT = 512
15
 
16
- _predictor = None
 
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- def get_predictor():
20
- global _predictor
21
- if _predictor is None:
22
- tok = KronosTokenizer.from_pretrained(TOKENIZER_ID)
23
- mdl = Kronos.from_pretrained(MODEL_ID)
24
- _predictor = KronosPredictor(mdl, tok, device=DEVICE, max_context=MAX_CONTEXT)
25
- return _predictor
26
 
 
 
27
 
28
- def _infer_freq(symbol: str) -> str:
29
- # Indian MFs and some indices only have daily data on yfinance.
30
- return "1d"
31
 
 
 
 
32
 
33
- def forecast(symbol: str, lookback_days: int = 180, pred_days: int = 30) -> dict:
34
- """Run Kronos forecast for a symbol and return direction + predicted % change.
35
 
36
- Args:
37
- symbol: yfinance ticker (e.g. 'RELIANCE.NS', 'VOO', 'GOLDIETF.NS').
38
- lookback_days: historical window (default 180, max 500).
39
- pred_days: forecast horizon in days (default 30).
 
 
 
 
 
40
 
41
- Returns:
42
- dict with direction (+1/-1/0), pct_change (%), last_close, predicted_close,
43
- n_lookback, model, status.
44
- """
 
 
 
 
 
 
 
45
  symbol = (symbol or "").strip().upper()
46
  if not symbol:
47
  return {"status": "error", "error": "empty symbol"}
48
-
49
- lookback_days = int(max(32, min(lookback_days or 180, 500)))
50
- pred_days = int(max(1, min(pred_days or 30, 90)))
51
-
52
  try:
53
- df = yf.download(symbol, period=f"{lookback_days + 10}d",
54
- interval="1d", progress=False, auto_adjust=False)
55
- if df is None or df.empty or len(df) < 32:
56
- return {"status": "error", "error": f"no data for {symbol}", "n_lookback": 0}
57
- # flatten multiindex columns if present
58
- if isinstance(df.columns, pd.MultiIndex):
59
- df.columns = df.columns.get_level_values(0)
60
- df = df.reset_index()
61
- df.columns = [str(c).lower() for c in df.columns]
62
- df["timestamps"] = pd.to_datetime(df["date"])
63
- kdf = df[["timestamps", "open", "high", "low", "close", "volume"]].copy().tail(lookback_days)
64
- kdf = kdf.dropna().reset_index(drop=True)
65
  if len(kdf) < 32:
66
- return {"status": "error", "error": "insufficient clean data", "n_lookback": len(kdf)}
67
-
68
  x_df = kdf[["open", "high", "low", "close", "volume"]].copy()
69
  x_df["amount"] = x_df["close"] * x_df["volume"]
70
  x_timestamp = kdf["timestamps"]
71
- # build future timestamps: business days
72
  last = x_timestamp.iloc[-1]
73
  y_timestamp = pd.Series(pd.bdate_range(start=last + pd.Timedelta(days=1), periods=pred_days))
74
-
75
- predictor = get_predictor()
76
- pred_df = predictor.predict(
77
  df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
78
  pred_len=pred_days, T=1.0, top_p=0.9, sample_count=1, verbose=False,
79
  )
80
-
81
  last_close = float(kdf["close"].iloc[-1])
82
  pred_close = float(pred_df["close"].iloc[-1])
83
  pct = (pred_close - last_close) / last_close * 100.0
84
- direction = 1 if pct > 0.5 else (-1 if pct < -0.5 else 0)
85
-
86
  return {
87
- "status": "ok",
88
- "symbol": symbol,
89
- "model": MODEL_ID,
90
- "last_close": round(last_close, 4),
91
- "predicted_close": round(pred_close, 4),
92
- "pct_change": round(pct, 3),
93
- "direction": direction,
94
- "n_lookback": int(len(kdf)),
95
- "pred_days": pred_days,
96
- "pred_first_close": round(float(pred_df["close"].iloc[0]), 4),
97
  "pred_mean_close": round(float(pred_df["close"].mean()), 4),
98
  "pred_min_close": round(float(pred_df["close"].min()), 4),
99
  "pred_max_close": round(float(pred_df["close"].max()), 4),
@@ -102,18 +101,198 @@ def forecast(symbol: str, lookback_days: int = 180, pred_days: int = 30) -> dict
102
  return {"status": "error", "error": f"{type(e).__name__}: {e}", "symbol": symbol}
103
 
104
 
105
- demo = gr.Interface(
106
- fn=forecast,
107
- inputs=[
108
- gr.Textbox(label="Symbol (yfinance)", value="RELIANCE.NS"),
109
- gr.Slider(32, 500, value=180, step=1, label="Lookback days"),
110
- gr.Slider(1, 90, value=30, step=1, label="Prediction days"),
111
- ],
112
- outputs=gr.JSON(label="Kronos forecast"),
113
- title="Kronos-small forecast",
114
- description="Direction + predicted % change from Kronos (finance-native foundation model).",
115
- api_name="forecast",
116
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  if __name__ == "__main__":
119
  demo.launch(server_name="0.0.0.0", server_port=7860, mcp_server=True)
 
1
+ """Multi-model Investment OS inference Space.
2
+
3
+ Endpoints (Gradio + MCP):
4
+ - /forecast — Kronos-small (finance-native candlestick foundation model)
5
+ - /forecast_chronos — amazon/chronos-bolt-tiny (generic TSFM, CPU-fast)
6
+ - /forecast_timesfm — google/timesfm-2.5-200m-pytorch (Google TSFM)
7
+ - /score_sentiment — ProsusAI/finbert (financial sentiment)
8
+
9
+ All models lazy-loaded on first call. CPU-only.
10
+ """
11
+ from __future__ import annotations
12
+
13
  import numpy as np
14
+ import pandas as pd
15
  import torch
16
  import yfinance as yf
17
  import gradio as gr
 
18
 
 
 
 
 
 
19
 
20
+ # -----------------------------------------------------------------------------
21
+ # Shared: yfinance OHLC loader
22
+ # -----------------------------------------------------------------------------
23
 
24
def _load_ohlc(symbol: str, lookback_days: int) -> pd.DataFrame:
    """Download daily OHLCV bars for ``symbol`` from yfinance.

    Args:
        symbol: yfinance ticker.
        lookback_days: maximum number of most-recent daily rows to return.

    Returns:
        A frame with the columns ``timestamps``, ``open``, ``high``, ``low``,
        ``close`` and ``volume`` (NaN rows dropped, index reset), holding at
        most ``lookback_days`` rows. An empty frame is returned when yfinance
        yields no data, no ``date`` column, or any required column is missing,
        so callers can rely on a single ``len(df) < minimum`` check.
    """
    # yfinance periods are calendar days, while rows are trading days
    # (roughly 5 per 7 calendar days, fewer around holidays). Requesting only
    # ``lookback_days + 10`` calendar days can never return ``lookback_days``
    # rows, and makes small lookbacks (e.g. 32) fail the callers' 32-row
    # minimum. Over-request by 1.5x and trim with ``tail`` below.
    calendar_days = int(lookback_days * 1.5) + 10
    df = yf.download(symbol, period=f"{calendar_days}d",
                     interval="1d", progress=False, auto_adjust=False)
    if df is None or df.empty:
        return pd.DataFrame()
    # Flatten the (field, ticker) MultiIndex down to the field level.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(0)
    df = df.reset_index()
    df.columns = [str(c).lower() for c in df.columns]
    if "date" not in df.columns:
        return pd.DataFrame()
    df["timestamps"] = pd.to_datetime(df["date"])
    required = ["timestamps", "open", "high", "low", "close", "volume"]
    # A partial frame would only fail later with a KeyError in the caller;
    # report it as "no usable data" instead.
    if any(col not in df.columns for col in required):
        return pd.DataFrame()
    return df[required].dropna().tail(lookback_days).reset_index(drop=True)
39
 
 
 
 
 
 
 
 
40
 
41
+ def _direction(pct: float) -> int:
42
+ return 1 if pct > 0.5 else (-1 if pct < -0.5 else 0)
43
 
 
 
 
44
 
45
+ def _clamp(lb, pd_, min_lb=32, max_lb=500, max_pred=90):
46
+ return (int(max(min_lb, min(int(lb or 180), max_lb))),
47
+ int(max(1, min(int(pd_ or 30), max_pred))))
48
 
 
 
49
 
50
+ # -----------------------------------------------------------------------------
51
+ # Kronos NeoQuasar/Kronos-small
52
+ # -----------------------------------------------------------------------------
53
+ from model import Kronos, KronosTokenizer, KronosPredictor
54
+
55
+ KRONOS_MODEL_ID = "NeoQuasar/Kronos-small"
56
+ KRONOS_TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base"
57
+ _kronos = None
58
+
59
 
60
def _get_kronos():
    """Return the shared Kronos predictor, building it on first use.

    The tokenizer and model are loaded from ``KRONOS_TOKENIZER_ID`` and
    ``KRONOS_MODEL_ID`` and wrapped in a CPU ``KronosPredictor`` with a
    512-step context; the instance is cached in the module-level ``_kronos``.
    """
    global _kronos
    if _kronos is not None:
        return _kronos
    tokenizer = KronosTokenizer.from_pretrained(KRONOS_TOKENIZER_ID)
    model = Kronos.from_pretrained(KRONOS_MODEL_ID)
    _kronos = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
    return _kronos
67
+
68
+
69
+ def forecast(symbol: str, lookback_days: int = 180, pred_days: int = 30) -> dict:
70
+ """Kronos-small (finance-native) forecast. Returns direction + % change."""
71
  symbol = (symbol or "").strip().upper()
72
  if not symbol:
73
  return {"status": "error", "error": "empty symbol"}
74
+ lookback_days, pred_days = _clamp(lookback_days, pred_days)
 
 
 
75
  try:
76
+ kdf = _load_ohlc(symbol, lookback_days)
 
 
 
 
 
 
 
 
 
 
 
77
  if len(kdf) < 32:
78
+ return {"status": "error", "error": f"insufficient data for {symbol}", "n_lookback": len(kdf)}
 
79
  x_df = kdf[["open", "high", "low", "close", "volume"]].copy()
80
  x_df["amount"] = x_df["close"] * x_df["volume"]
81
  x_timestamp = kdf["timestamps"]
 
82
  last = x_timestamp.iloc[-1]
83
  y_timestamp = pd.Series(pd.bdate_range(start=last + pd.Timedelta(days=1), periods=pred_days))
84
+ pred_df = _get_kronos().predict(
 
 
85
  df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
86
  pred_len=pred_days, T=1.0, top_p=0.9, sample_count=1, verbose=False,
87
  )
 
88
  last_close = float(kdf["close"].iloc[-1])
89
  pred_close = float(pred_df["close"].iloc[-1])
90
  pct = (pred_close - last_close) / last_close * 100.0
 
 
91
  return {
92
+ "status": "ok", "symbol": symbol, "model": KRONOS_MODEL_ID,
93
+ "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
94
+ "pct_change": round(pct, 3), "direction": _direction(pct),
95
+ "n_lookback": int(len(kdf)), "pred_days": pred_days,
 
 
 
 
 
 
96
  "pred_mean_close": round(float(pred_df["close"].mean()), 4),
97
  "pred_min_close": round(float(pred_df["close"].min()), 4),
98
  "pred_max_close": round(float(pred_df["close"].max()), 4),
 
101
  return {"status": "error", "error": f"{type(e).__name__}: {e}", "symbol": symbol}
102
 
103
 
104
+ # -----------------------------------------------------------------------------
105
+ # Chronos-bolt-tiny
106
+ # -----------------------------------------------------------------------------
107
+ CHRONOS_MODEL_ID = "amazon/chronos-bolt-tiny"
108
+ _chronos = None
109
+
110
+
111
def _get_chronos():
    """Return the shared Chronos pipeline, loading it on first use.

    ``chronos`` is imported lazily so the package is only required once this
    endpoint is actually called. The pipeline is loaded from
    ``CHRONOS_MODEL_ID`` on CPU in float32 and cached in ``_chronos``.
    """
    global _chronos
    if _chronos is not None:
        return _chronos
    from chronos import BaseChronosPipeline
    load_kwargs = {"device_map": "cpu", "torch_dtype": torch.float32}
    _chronos = BaseChronosPipeline.from_pretrained(CHRONOS_MODEL_ID, **load_kwargs)
    return _chronos
119
+
120
+
121
def forecast_chronos(symbol: str, lookback_days: int = 180, pred_days: int = 30) -> dict:
    """Forecast daily close prices with Chronos-bolt-tiny.

    Args:
        symbol: yfinance ticker; surrounding whitespace is stripped and the
            value is upper-cased.
        lookback_days: history window, clamped by ``_clamp``.
        pred_days: forecast horizon in steps, clamped by ``_clamp``.

    Returns:
        On success a dict with ``status == "ok"``, the last observed close,
        the median forecast at the final step (``predicted_close``), the
        resulting ``pct_change`` / ``direction``, the mean of the model's mean
        path, and the 10% / 90% quantiles at the final step. On any failure a
        dict with ``status == "error"`` and an ``error`` message; this
        function does not raise.
    """
    symbol = (symbol or "").strip().upper()
    if not symbol:
        return {"status": "error", "error": "empty symbol"}
    try:
        # Clamp inside the try so malformed numeric input is reported as an
        # error dict like every other failure, instead of escaping as an
        # unhandled exception.
        lookback_days, pred_days = _clamp(lookback_days, pred_days)
        kdf = _load_ohlc(symbol, lookback_days)
        if len(kdf) < 32:
            return {"status": "error", "error": f"insufficient data for {symbol}", "n_lookback": len(kdf)}
        context = torch.tensor(kdf["close"].values, dtype=torch.float32)
        # The series is passed positionally so the call does not depend on the
        # keyword name the installed chronos release uses for it.
        quantiles, mean = _get_chronos().predict_quantiles(
            context, prediction_length=pred_days, quantile_levels=[0.1, 0.5, 0.9],
        )
        # Last axis follows quantile_levels: index 0 -> 0.1, 1 -> 0.5, 2 -> 0.9.
        low = quantiles[0, :, 0].cpu().numpy()
        median = quantiles[0, :, 1].cpu().numpy()
        high = quantiles[0, :, 2].cpu().numpy()
        mean_np = mean[0].cpu().numpy()
        last_close = float(kdf["close"].iloc[-1])
        pred_close = float(median[-1])
        pct = (pred_close - last_close) / last_close * 100.0
        return {
            "status": "ok", "symbol": symbol, "model": CHRONOS_MODEL_ID,
            "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
            "pct_change": round(pct, 3), "direction": _direction(pct),
            "n_lookback": int(len(kdf)), "pred_days": pred_days,
            "pred_mean_close": round(float(np.mean(mean_np)), 4),
            "pred_low_close": round(float(low[-1]), 4),
            "pred_high_close": round(float(high[-1]), 4),
        }
    except Exception as e:
        return {"status": "error", "error": f"{type(e).__name__}: {e}", "symbol": symbol}
153
+
154
+
155
+ # -----------------------------------------------------------------------------
156
+ # TimesFM 2.5 (200M PyTorch)
157
+ # -----------------------------------------------------------------------------
158
+ TIMESFM_MODEL_ID = "google/timesfm-2.5-200m-pytorch"
159
+ _timesfm = None
160
+
161
+
162
def _get_timesfm():
    """Return the shared TimesFM 2.5 model, loading and compiling it on first use.

    ``timesfm`` is imported lazily. The 2.5 API requires ``compile()`` with a
    ``ForecastConfig`` before ``forecast()`` can be called; the limits chosen
    here cover the largest lookback (500) and horizon (90) that ``_clamp``
    lets through. The global is only assigned once compilation succeeded, so a
    failed first call is retried instead of caching an unusable model.
    """
    global _timesfm
    if _timesfm is None:
        import timesfm
        model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(TIMESFM_MODEL_ID)
        model.compile(
            timesfm.ForecastConfig(
                max_context=512,
                max_horizon=128,
                normalize_inputs=True,
                use_continuous_quantile_head=True,
                infer_is_positive=True,
                fix_quantile_crossing=True,
            )
        )
        _timesfm = model
    return _timesfm


def forecast_timesfm(symbol: str, lookback_days: int = 180, pred_days: int = 30) -> dict:
    """Forecast daily close prices with TimesFM 2.5 (200M).

    Args:
        symbol: yfinance ticker; surrounding whitespace is stripped and the
            value is upper-cased.
        lookback_days: history window, clamped by ``_clamp``.
        pred_days: forecast horizon in steps, clamped by ``_clamp``.

    Returns:
        On success a dict with ``status == "ok"``, the last observed close,
        the point forecast at the final step (``predicted_close``), the
        resulting ``pct_change`` / ``direction`` and the mean / min / max of
        the point-forecast path. On any failure a dict with
        ``status == "error"`` and an ``error`` message; this function does
        not raise.
    """
    symbol = (symbol or "").strip().upper()
    if not symbol:
        return {"status": "error", "error": "empty symbol"}
    try:
        # Clamp inside the try so malformed numeric input is reported as an
        # error dict instead of escaping as an unhandled exception.
        lookback_days, pred_days = _clamp(lookback_days, pred_days)
        kdf = _load_ohlc(symbol, lookback_days)
        if len(kdf) < 32:
            return {"status": "error", "error": f"insufficient data for {symbol}", "n_lookback": len(kdf)}
        closes = kdf["close"].values.astype(np.float32)
        # TimesFM 2.5 takes only ``horizon`` and ``inputs``; the ``freq``
        # argument belongs to the older 1.x / 2.0 API.
        point, _quantiles = _get_timesfm().forecast(horizon=pred_days, inputs=[closes])
        pred = np.asarray(point[0])
        last_close = float(kdf["close"].iloc[-1])
        pred_close = float(pred[-1])
        pct = (pred_close - last_close) / last_close * 100.0
        return {
            "status": "ok", "symbol": symbol, "model": TIMESFM_MODEL_ID,
            "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
            "pct_change": round(pct, 3), "direction": _direction(pct),
            "n_lookback": int(len(kdf)), "pred_days": pred_days,
            "pred_mean_close": round(float(np.mean(pred)), 4),
            "pred_min_close": round(float(np.min(pred)), 4),
            "pred_max_close": round(float(np.max(pred)), 4),
        }
    except Exception as e:
        return {"status": "error", "error": f"{type(e).__name__}: {e}", "symbol": symbol}
200
+
201
+
202
+ # -----------------------------------------------------------------------------
203
+ # FinBERT — ProsusAI/finbert
204
+ # -----------------------------------------------------------------------------
205
+ FINBERT_MODEL_ID = "ProsusAI/finbert"
206
+ _finbert = None
207
+
208
+
209
def _get_finbert():
    """Return the shared FinBERT text-classification pipeline, loading it on first use.

    ``transformers`` is imported lazily. The pipeline runs on CPU, truncates
    over-long inputs and, with ``top_k=None``, is asked for the scores of all
    labels rather than only the best one. The instance is cached in
    ``_finbert``.
    """
    global _finbert
    if _finbert is not None:
        return _finbert
    from transformers import pipeline
    options = {"model": FINBERT_MODEL_ID, "device": "cpu", "top_k": None, "truncation": True}
    _finbert = pipeline("text-classification", **options)
    return _finbert
218
+
219
+
220
+ def score_sentiment(texts_json: str) -> dict:
221
+ """FinBERT scoring. Input: JSON array of strings (or newline-separated)."""
222
+ import json as _json
223
+ if not texts_json or not str(texts_json).strip():
224
+ return {"status": "error", "error": "empty input"}
225
+ try:
226
+ texts = _json.loads(texts_json)
227
+ if isinstance(texts, str):
228
+ texts = [texts]
229
+ if not isinstance(texts, list):
230
+ texts = [str(texts)]
231
+ except Exception:
232
+ texts = [t.strip() for t in str(texts_json).split("\n") if t.strip()]
233
+ texts = [str(t) for t in texts][:50]
234
+ if not texts:
235
+ return {"status": "error", "error": "no non-empty texts"}
236
+ try:
237
+ raw = _get_finbert()(texts)
238
+ pos_sum = neg_sum = neu_sum = 0.0
239
+ per = []
240
+ for item in raw:
241
+ entries = item if isinstance(item, list) else [item]
242
+ p = n = u = 0.0
243
+ for e in entries:
244
+ lbl = str(e.get("label", "")).lower()
245
+ sc = float(e.get("score", 0))
246
+ if lbl.startswith("pos"): p = sc
247
+ elif lbl.startswith("neg"): n = sc
248
+ elif lbl.startswith("neu"): u = sc
249
+ pos_sum += p; neg_sum += n; neu_sum += u
250
+ per.append({"pos": round(p, 4), "neg": round(n, 4), "neu": round(u, 4)})
251
+ n = len(texts)
252
+ return {
253
+ "status": "ok", "model": FINBERT_MODEL_ID, "n": n,
254
+ "net": round((pos_sum - neg_sum) / n, 4),
255
+ "pos": round(pos_sum / n, 4),
256
+ "neg": round(neg_sum / n, 4),
257
+ "neu": round(neu_sum / n, 4),
258
+ "per_text": per,
259
+ }
260
+ except Exception as e:
261
+ return {"status": "error", "error": f"{type(e).__name__}: {e}"}
262
+
263
+
264
+ # -----------------------------------------------------------------------------
265
+ # Gradio UI — 4 tabs, each exposes a named API for MCP discovery
266
+ # -----------------------------------------------------------------------------
267
with gr.Blocks(title="Investment OS inference") as demo:
    gr.Markdown("# Investment OS — 3 TSFMs + FinBERT\nCPU-only, all lazy-loaded. Endpoints: `/forecast`, `/forecast_chronos`, `/forecast_timesfm`, `/score_sentiment`.")

    # Each tab wires one handler to a button; ``api_name`` is the endpoint
    # name under which that handler is exposed.

    # --- Kronos ---------------------------------------------------------
    with gr.Tab("Kronos"):
        with gr.Row():
            s1 = gr.Textbox(label="Symbol", value="VOO")
            lb1 = gr.Slider(32, 500, value=180, step=1, label="Lookback days")
            pd1 = gr.Slider(1, 90, value=30, step=1, label="Pred days")
        kronos_btn = gr.Button("Forecast")
        kronos_btn.click(fn=forecast, inputs=[s1, lb1, pd1], outputs=gr.JSON(),
                         api_name="forecast")

    # --- Chronos --------------------------------------------------------
    with gr.Tab("Chronos-bolt-tiny"):
        with gr.Row():
            s2 = gr.Textbox(label="Symbol", value="VOO")
            lb2 = gr.Slider(32, 500, value=180, step=1, label="Lookback days")
            pd2 = gr.Slider(1, 90, value=30, step=1, label="Pred days")
        chronos_btn = gr.Button("Forecast")
        chronos_btn.click(fn=forecast_chronos, inputs=[s2, lb2, pd2], outputs=gr.JSON(),
                          api_name="forecast_chronos")

    # --- TimesFM --------------------------------------------------------
    with gr.Tab("TimesFM-2.5"):
        with gr.Row():
            s3 = gr.Textbox(label="Symbol", value="VOO")
            lb3 = gr.Slider(32, 500, value=180, step=1, label="Lookback days")
            pd3 = gr.Slider(1, 90, value=30, step=1, label="Pred days")
        timesfm_btn = gr.Button("Forecast")
        timesfm_btn.click(fn=forecast_timesfm, inputs=[s3, lb3, pd3], outputs=gr.JSON(),
                          api_name="forecast_timesfm")

    # --- FinBERT --------------------------------------------------------
    with gr.Tab("FinBERT"):
        t4 = gr.Textbox(
            label="Texts (JSON array or newline-separated)",
            value='["Strong Q4 beats expectations","Margin pressure ahead"]',
            lines=6,
        )
        finbert_btn = gr.Button("Score")
        finbert_btn.click(fn=score_sentiment, inputs=[t4], outputs=gr.JSON(),
                          api_name="score_sentiment")
295
+
296
 
297
  if __name__ == "__main__":
298
  demo.launch(server_name="0.0.0.0", server_port=7860, mcp_server=True)