vmjn commited on
Commit
6deee5a
·
verified ·
1 Parent(s): e46f520

v3: TimesFM 2.5 + TiRex + MOMENT + GDELT + Reddit

Browse files
Files changed (3) hide show
  1. Dockerfile +50 -24
  2. README.md +17 -15
  3. app.py +564 -358
Dockerfile CHANGED
@@ -1,38 +1,64 @@
1
  FROM python:3.11-slim
2
 
3
- # Non-root user for HF Spaces
 
 
 
 
4
  RUN useradd -m -u 1000 user
5
  USER user
6
- ENV HOME=/home/user \
7
- PATH=/home/user/.local/bin:$PATH \
8
- PYTHONUNBUFFERED=1 \
9
- GRADIO_SERVER_NAME=0.0.0.0 \
10
- GRADIO_SERVER_PORT=7860 \
11
- HF_HOME=/home/user/.cache/huggingface
12
-
13
  WORKDIR /home/user/app
14
 
15
- # CPU-only torch from PyTorch index (skip CUDA wheels, ~200MB vs ~2GB)
16
- RUN pip install --user --no-cache-dir --index-url https://download.pytorch.org/whl/cpu \
17
- torch==2.4.1
18
 
19
- # App deps all pinned to known-compatible versions
20
- RUN pip install --user --no-cache-dir \
21
- "gradio[mcp]==5.30.0" \
 
 
 
 
22
  "huggingface_hub>=0.27.0,<1.0" \
23
- "numpy>=1.26,<2.3" \
24
- "pandas>=2.1" \
 
 
 
 
 
 
 
 
 
25
  "yfinance>=0.2.50" \
26
  "curl_cffi>=0.7" \
27
- "websockets>=13.0" \
28
- "einops>=0.7" \
29
- "safetensors>=0.4" \
30
- "tqdm>=4.66" \
31
- "transformers==4.46.3" \
32
- "chronos-forecasting>=1.5.2" \
33
- "timesfm[torch]>=1.3.0"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- COPY --chown=user . /home/user/app
 
36
 
37
  EXPOSE 7860
38
  CMD ["python", "app.py"]
 
1
  FROM python:3.11-slim
2
 
3
+ RUN apt-get update && apt-get install -y --no-install-recommends \
4
+ git build-essential curl wget \
5
+ && rm -rf /var/lib/apt/lists/*
6
+
7
+ # Create user matching HF Space conventions
8
  RUN useradd -m -u 1000 user
9
  USER user
10
+ ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
 
 
 
 
 
 
11
  WORKDIR /home/user/app
12
 
13
+ # Disable TiRex CUDA kernels (we're on CPU-only Space)
14
+ ENV TIREX_NO_CUDA=1 XLSTM_USE_CUDA_KERNELS=0
 
15
 
16
+ # Core: Torch CPU (keep 2.4.1 Kronos meta-tensor compatibility)
17
+ RUN pip install --no-cache-dir --upgrade pip && \
18
+ pip install --no-cache-dir torch==2.4.1 --index-url https://download.pytorch.org/whl/cpu
19
+
20
+ # Transformers bumped to 4.55.0 for TimesFm2_5 support
21
+ RUN pip install --no-cache-dir \
22
+ "transformers==4.55.0" \
23
  "huggingface_hub>=0.27.0,<1.0" \
24
+ "gradio[mcp]==5.30.0" \
25
+ "websockets>=13" \
26
+ "einops" \
27
+ "safetensors" \
28
+ "pandas" \
29
+ "numpy<2.3"
30
+
31
+ # Chronos + TimesFM (fallback) + yfinance + sentiment
32
+ RUN pip install --no-cache-dir \
33
+ "chronos-forecasting>=1.5.2" \
34
+ "timesfm[torch]>=1.3.0" \
35
  "yfinance>=0.2.50" \
36
  "curl_cffi>=0.7" \
37
+ "sentencepiece" \
38
+ "tokenizers"
39
+
40
+ # NEW: MOMENT (anomaly detection)
41
+ RUN pip install --no-cache-dir "momentfm>=0.1.4"
42
+
43
+ # NEW: GDELT (news streaming)
44
+ RUN pip install --no-cache-dir "gdeltdoc>=1.5.0"
45
+
46
+ # NEW: requests for Reddit
47
+ RUN pip install --no-cache-dir "requests>=2.31" "beautifulsoup4>=4.12"
48
+
49
+ # NEW: TiRex (xLSTM TSFM) — install from git, CPU experimental mode
50
+ # Use || true so build doesn't fail if tirex install hits issues — endpoint will error gracefully
51
+ RUN pip install --no-cache-dir git+https://github.com/NX-AI/tirex.git || \
52
+ echo "WARNING: TiRex install failed — /forecast_tirex will return error"
53
+
54
+ # Kronos model repo
55
+ USER user
56
+ RUN mkdir -p /home/user/app/model
57
+ COPY --chown=user ./model /home/user/app/model
58
+ COPY --chown=user ./app.py /home/user/app/app.py
59
 
60
+ # Pre-create HF cache dirs with right perms
61
+ RUN mkdir -p /home/user/.cache/huggingface
62
 
63
  EXPOSE 7860
64
  CMD ["python", "app.py"]
README.md CHANGED
@@ -1,22 +1,24 @@
1
  ---
2
- title: Kronos Forecast
3
- emoji: 📈
4
- colorFrom: indigo
5
- colorTo: purple
6
  sdk: docker
7
  app_port: 7860
8
- pinned: false
9
  license: mit
10
- tags:
11
- - mcp-server
12
- - finance
13
- - forecasting
14
- - kronos
15
  ---
16
 
17
- Kronos-small financial forecast wrapper exposing Gradio MCP endpoints for the Investment OS.
18
 
19
- - Model: `NeoQuasar/Kronos-small` (24.7M params, CPU-friendly)
20
- - Tokenizer: `NeoQuasar/Kronos-Tokenizer-base`
21
- - Data: yfinance OHLC
22
- - SDK: Docker (Python 3.11, pinned deps)
 
 
 
 
 
 
 
1
  ---
2
+ title: Kronos Forecast Multi-Model
3
+ emoji: 🦜
4
+ colorFrom: purple
5
+ colorTo: indigo
6
  sdk: docker
7
  app_port: 7860
8
+ pinned: true
9
  license: mit
10
+ short_description: TSFM + Sentiment + News + Reddit ensemble
 
 
 
 
11
  ---
12
 
13
+ # Investment OS Multi-Model Space
14
 
15
+ 9 endpoints exposed via Gradio + MCP:
16
+ - `/forecast` — Kronos (OHLCV finance TSFM)
17
+ - `/forecast_chronos` Chronos-bolt-tiny (generic TSFM)
18
+ - `/forecast_timesfm` TimesFM 2.5 (Google)
19
+ - `/forecast_tirex` — TiRex (NX-AI xLSTM TSFM)
20
+ - `/anomaly_moment` — MOMENT-1-large (anomaly detection)
21
+ - `/score_sentiment` — FinBERT text
22
+ - `/score_sentiment_for_symbol` — FinBERT × yfinance news
23
+ - `/news_gdelt_for_symbol` — FinBERT × GDELT global news
24
+ - `/reddit_sentiment_for_symbol` — FinBERT × Reddit retail sentiment
app.py CHANGED
@@ -1,414 +1,620 @@
1
- """Multi-model Investment OS inference Space.
 
2
 
3
- Endpoints (Gradio + MCP):
4
- - /forecast — Kronos-small (finance-native candlestick foundation model)
5
- - /forecast_chronos — amazon/chronos-bolt-tiny (generic TSFM, CPU-fast)
6
- - /forecast_timesfm — google/timesfm-2.5-200m-pytorch (Google TSFM)
7
- - /score_sentiment — ProsusAI/finbert (financial sentiment)
8
 
9
- All models lazy-loaded on first call. CPU-only.
10
- """
11
- from __future__ import annotations
 
 
12
 
13
  import numpy as np
14
  import pandas as pd
15
  import torch
16
- import yfinance as yf
17
  import gradio as gr
18
 
 
 
 
19
 
20
- # -----------------------------------------------------------------------------
21
- # Shared: yfinance OHLC loader
22
- # -----------------------------------------------------------------------------
23
-
24
- def _load_ohlc(symbol: str, lookback_days: int) -> pd.DataFrame:
25
- def _tidy(df):
26
- if df is None or df.empty:
27
- return pd.DataFrame()
28
- if isinstance(df.columns, pd.MultiIndex):
29
- df.columns = df.columns.get_level_values(0)
30
- df = df.reset_index()
31
- df.columns = [str(c).lower() for c in df.columns]
32
- if "date" not in df.columns and "datetime" in df.columns:
33
- df["date"] = df["datetime"]
34
- if "date" not in df.columns:
35
- return pd.DataFrame()
36
- df["timestamps"] = pd.to_datetime(df["date"])
37
- keep = ["timestamps", "open", "high", "low", "close", "volume"]
38
- df = df[[c for c in keep if c in df.columns]].dropna().tail(lookback_days).reset_index(drop=True)
39
- return df
40
-
41
- # Primary: yf.download. Fallback: yf.Ticker().history(). Uses curl_cffi chrome impersonation if available.
42
- session = None
43
  try:
44
- from curl_cffi import requests as cureq
45
- session = cureq.Session(impersonate="chrome")
46
  except Exception:
47
  session = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- period = f"{lookback_days + 10}d"
50
- try:
51
- df = yf.download(symbol, period=period, interval="1d",
52
- progress=False, auto_adjust=False,
53
- session=session) if session else \
54
- yf.download(symbol, period=period, interval="1d",
55
- progress=False, auto_adjust=False)
56
- out = _tidy(df)
57
- if len(out) >= 32:
58
- return out
59
- except Exception:
60
- pass
61
 
62
- # Fallback path
 
 
 
 
 
 
 
 
 
 
 
 
63
  try:
64
- t = yf.Ticker(symbol, session=session) if session else yf.Ticker(symbol)
65
- df = t.history(period=period, interval="1d", auto_adjust=False)
66
- return _tidy(df)
67
- except Exception:
68
- return pd.DataFrame()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
- def _direction(pct: float) -> int:
72
- return 1 if pct > 0.5 else (-1 if pct < -0.5 else 0)
 
 
73
 
74
 
75
- def _clamp(lb, pd_, min_lb=32, max_lb=500, max_pred=90):
76
- return (int(max(min_lb, min(int(lb or 180), max_lb))),
77
- int(max(1, min(int(pd_ or 30), max_pred))))
 
 
 
 
78
 
79
 
80
- # -----------------------------------------------------------------------------
81
- # Kronos — NeoQuasar/Kronos-small
82
- # -----------------------------------------------------------------------------
83
- from model import Kronos, KronosTokenizer, KronosPredictor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- KRONOS_MODEL_ID = "NeoQuasar/Kronos-small"
86
- KRONOS_TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base"
87
- _kronos = None
 
88
 
89
 
90
- def _get_kronos():
91
- global _kronos
92
- if _kronos is None:
93
- tok = KronosTokenizer.from_pretrained(KRONOS_TOKENIZER_ID)
94
- mdl = Kronos.from_pretrained(KRONOS_MODEL_ID)
95
- _kronos = KronosPredictor(mdl, tok, device="cpu", max_context=512)
96
- return _kronos
97
-
98
-
99
- def forecast(symbol: str, lookback_days: int = 180, pred_days: int = 30) -> dict:
100
- """Kronos-small (finance-native) forecast. Returns direction + % change."""
101
- symbol = (symbol or "").strip().upper()
102
- if not symbol:
103
- return {"status": "error", "error": "empty symbol"}
104
- lookback_days, pred_days = _clamp(lookback_days, pred_days)
 
 
 
 
 
 
 
105
  try:
106
- kdf = _load_ohlc(symbol, lookback_days)
107
- if len(kdf) < 32:
108
- return {"status": "error", "error": f"insufficient data for {symbol}", "n_lookback": len(kdf)}
109
- x_df = kdf[["open", "high", "low", "close", "volume"]].copy()
110
- x_df["amount"] = x_df["close"] * x_df["volume"]
111
- x_timestamp = kdf["timestamps"]
112
- last = x_timestamp.iloc[-1]
113
- y_timestamp = pd.Series(pd.bdate_range(start=last + pd.Timedelta(days=1), periods=pred_days))
114
- pred_df = _get_kronos().predict(
115
- df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
116
- pred_len=pred_days, T=1.0, top_p=0.9, sample_count=1, verbose=False,
117
- )
118
- last_close = float(kdf["close"].iloc[-1])
119
- pred_close = float(pred_df["close"].iloc[-1])
120
- pct = (pred_close - last_close) / last_close * 100.0
121
- return {
122
- "status": "ok", "symbol": symbol, "model": KRONOS_MODEL_ID,
123
- "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
124
- "pct_change": round(pct, 3), "direction": _direction(pct),
125
- "n_lookback": int(len(kdf)), "pred_days": pred_days,
126
- "pred_mean_close": round(float(pred_df["close"].mean()), 4),
127
- "pred_min_close": round(float(pred_df["close"].min()), 4),
128
- "pred_max_close": round(float(pred_df["close"].max()), 4),
129
- }
 
 
 
 
130
  except Exception as e:
131
- return {"status": "error", "error": f"{type(e).__name__}: {e}", "symbol": symbol}
132
 
133
 
134
- # -----------------------------------------------------------------------------
135
- # Chronos-bolt-tiny
136
- # -----------------------------------------------------------------------------
137
- CHRONOS_MODEL_ID = "amazon/chronos-bolt-tiny"
138
- _chronos = None
139
 
140
 
141
- def _get_chronos():
142
- global _chronos
143
- if _chronos is None:
144
- from chronos import BaseChronosPipeline
145
- _chronos = BaseChronosPipeline.from_pretrained(
146
- CHRONOS_MODEL_ID, device_map="cpu", torch_dtype=torch.float32,
147
- )
148
- return _chronos
149
-
150
-
151
- def forecast_chronos(symbol: str, lookback_days: int = 180, pred_days: int = 30) -> dict:
152
- """Chronos-bolt-tiny forecast on close prices."""
153
- symbol = (symbol or "").strip().upper()
154
- if not symbol:
155
- return {"status": "error", "error": "empty symbol"}
156
- lookback_days, pred_days = _clamp(lookback_days, pred_days)
157
  try:
158
- kdf = _load_ohlc(symbol, lookback_days)
159
- if len(kdf) < 32:
160
- return {"status": "error", "error": f"insufficient data for {symbol}", "n_lookback": len(kdf)}
161
- context = torch.tensor(kdf["close"].values, dtype=torch.float32)
162
- quantiles, mean = _get_chronos().predict_quantiles(
163
- inputs=context, prediction_length=pred_days, quantile_levels=[0.1, 0.5, 0.9],
164
- )
165
- median = quantiles[0, :, 1].cpu().numpy()
166
- low = quantiles[0, :, 0].cpu().numpy()
167
- high = quantiles[0, :, 2].cpu().numpy()
168
- mean_np = mean[0].cpu().numpy()
169
- last_close = float(kdf["close"].iloc[-1])
170
- pred_close = float(median[-1])
171
- pct = (pred_close - last_close) / last_close * 100.0
172
- return {
173
- "status": "ok", "symbol": symbol, "model": CHRONOS_MODEL_ID,
174
- "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
175
- "pct_change": round(pct, 3), "direction": _direction(pct),
176
- "n_lookback": int(len(kdf)), "pred_days": pred_days,
177
- "pred_mean_close": round(float(np.mean(mean_np)), 4),
178
- "pred_low_close": round(float(low[-1]), 4),
179
- "pred_high_close": round(float(high[-1]), 4),
180
- }
 
 
 
 
 
 
181
  except Exception as e:
182
- return {"status": "error", "error": f"{type(e).__name__}: {e}", "symbol": symbol}
183
 
184
 
185
- # -----------------------------------------------------------------------------
186
- # TimesFM 2.5 (200M PyTorch)
187
- # -----------------------------------------------------------------------------
188
- TIMESFM_MODEL_ID = "google/timesfm-2.0-500m-pytorch"
189
- _timesfm = None
190
 
191
 
192
- def _get_timesfm():
193
- global _timesfm
194
- if _timesfm is None:
195
- import timesfm
196
- _timesfm = timesfm.TimesFm(
197
- hparams=timesfm.TimesFmHparams(
198
- backend="torch",
199
- per_core_batch_size=1,
200
- horizon_len=128,
201
- num_layers=50,
202
- use_positional_embedding=False,
203
- context_len=2048,
204
- ),
205
- checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=TIMESFM_MODEL_ID),
206
- )
207
- return _timesfm
208
-
209
-
210
- def forecast_timesfm(symbol: str, lookback_days: int = 180, pred_days: int = 30) -> dict:
211
- """TimesFM 2.5 (200M) forecast on close prices."""
212
- symbol = (symbol or "").strip().upper()
213
- if not symbol:
214
- return {"status": "error", "error": "empty symbol"}
215
- lookback_days, pred_days = _clamp(lookback_days, pred_days)
216
  try:
217
- kdf = _load_ohlc(symbol, lookback_days)
218
- if len(kdf) < 32:
219
- return {"status": "error", "error": f"insufficient data for {symbol}", "n_lookback": len(kdf)}
220
- model = _get_timesfm()
221
- point, _q = model.forecast(
222
- inputs=[kdf["close"].values.astype(np.float32)],
223
- freq=[0],
224
- )
225
- pred = np.asarray(point[0])[:pred_days]
226
- last_close = float(kdf["close"].iloc[-1])
227
- pred_close = float(pred[-1])
228
- pct = (pred_close - last_close) / last_close * 100.0
229
- return {
230
- "status": "ok", "symbol": symbol, "model": TIMESFM_MODEL_ID,
231
- "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
232
- "pct_change": round(pct, 3), "direction": _direction(pct),
233
- "n_lookback": int(len(kdf)), "pred_days": pred_days,
234
- "pred_mean_close": round(float(np.mean(pred)), 4),
235
- "pred_min_close": round(float(np.min(pred)), 4),
236
- "pred_max_close": round(float(np.max(pred)), 4),
237
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  except Exception as e:
239
- return {"status": "error", "error": f"{type(e).__name__}: {e}", "symbol": symbol}
240
 
241
 
242
- # -----------------------------------------------------------------------------
243
- # FinBERT ProsusAI/finbert
244
- # -----------------------------------------------------------------------------
245
- FINBERT_MODEL_ID = "peejm/finbert-financial-sentiment"
246
- _finbert = None
247
 
248
 
249
  def _get_finbert():
250
- global _finbert
251
- if _finbert is None:
252
- from transformers import pipeline
253
- _finbert = pipeline(
254
- "text-classification", model=FINBERT_MODEL_ID,
255
- device="cpu", top_k=None, truncation=True,
256
- )
257
- return _finbert
258
-
259
-
260
- def score_sentiment(texts_json: str) -> dict:
261
- """FinBERT scoring. Input: JSON array of strings (or newline-separated)."""
262
- import json as _json
263
- if not texts_json or not str(texts_json).strip():
264
- return {"status": "error", "error": "empty input"}
265
- try:
266
- texts = _json.loads(texts_json)
267
- if isinstance(texts, str):
268
- texts = [texts]
269
- if not isinstance(texts, list):
270
- texts = [str(texts)]
271
- except Exception:
272
- texts = [t.strip() for t in str(texts_json).split("\n") if t.strip()]
273
- texts = [str(t) for t in texts][:50]
274
  if not texts:
275
- return {"status": "error", "error": "no non-empty texts"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  try:
277
- raw = _get_finbert()(texts)
278
- pos_sum = neg_sum = neu_sum = 0.0
279
- per = []
280
- for item in raw:
281
- entries = item if isinstance(item, list) else [item]
282
- p = n = u = 0.0
283
- for e in entries:
284
- lbl = str(e.get("label", "")).lower()
285
- sc = float(e.get("score", 0))
286
- if lbl.startswith("pos"): p = sc
287
- elif lbl.startswith("neg"): n = sc
288
- elif lbl.startswith("neu"): u = sc
289
- pos_sum += p; neg_sum += n; neu_sum += u
290
- per.append({"pos": round(p, 4), "neg": round(n, 4), "neu": round(u, 4)})
291
- n = len(texts)
292
- return {
293
- "status": "ok", "model": FINBERT_MODEL_ID, "n": n,
294
- "net": round((pos_sum - neg_sum) / n, 4),
295
- "pos": round(pos_sum / n, 4),
296
- "neg": round(neg_sum / n, 4),
297
- "neu": round(neu_sum / n, 4),
298
- "per_text": per,
299
- }
300
  except Exception as e:
301
- return {"status": "error", "error": f"{type(e).__name__}: {e}"}
302
 
303
 
304
- def score_sentiment_for_symbol(symbol: str, max_items: int = 20) -> dict:
305
- """Fetch recent news for a symbol via yfinance (free, no API key) and score via FinBERT.
306
-
307
- Returns aggregated sentiment + direction signal (+1/−1/0) suitable for ensemble voting.
308
- """
309
- symbol = (symbol or "").strip().upper()
310
- if not symbol:
311
- return {"status": "error", "error": "empty symbol"}
312
- max_items = int(max(1, min(int(max_items or 20), 50)))
313
- # yfinance news via curl_cffi session if available
314
- try:
315
- from curl_cffi import requests as cureq
316
- session = cureq.Session(impersonate="chrome")
317
- except Exception:
318
- session = None
319
  try:
 
 
 
 
 
 
320
  t = yf.Ticker(symbol, session=session) if session else yf.Ticker(symbol)
321
- news = t.news or []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  except Exception as e:
323
- return {"status": "error", "error": f"yfinance news: {e}", "symbol": symbol}
324
- if not news:
325
- return {"status": "error", "error": f"no news for {symbol}", "symbol": symbol, "direction": 0}
326
-
327
- # yfinance news items have {title, link, publisher, providerPublishTime, ...}
328
- # or newer shape {content: {title, summary, ...}}
329
- titles = []
330
- for item in news[:max_items]:
331
- title = item.get("title") or ""
332
- summary = item.get("summary") or ""
333
- content = item.get("content") or {}
334
- if isinstance(content, dict):
335
- title = title or content.get("title", "")
336
- summary = summary or content.get("summary", "")
337
- text = (title + ". " + summary).strip(". ").strip()
338
- if text:
339
- titles.append(text[:512])
340
- if not titles:
341
- return {"status": "error", "error": "no parseable news titles", "symbol": symbol, "direction": 0}
342
-
343
- # Score via FinBERT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  try:
345
- raw = _get_finbert()(titles)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  except Exception as e:
347
- return {"status": "error", "error": f"finbert: {e}", "symbol": symbol}
348
- pos_sum = neg_sum = neu_sum = 0.0
349
- for item in raw:
350
- entries = item if isinstance(item, list) else [item]
351
- p = n = u = 0.0
352
- for e in entries:
353
- lbl = str(e.get("label", "")).lower()
354
- sc = float(e.get("score", 0))
355
- if lbl.startswith("pos"): p = sc
356
- elif lbl.startswith("neg"): n = sc
357
- elif lbl.startswith("neu"): u = sc
358
- pos_sum += p; neg_sum += n; neu_sum += u
359
- n = len(titles)
360
- net = (pos_sum - neg_sum) / n
361
- # Direction threshold: 0.15 net sentiment ~= meaningful tilt
362
- direction = 1 if net > 0.15 else (-1 if net < -0.15 else 0)
363
- return {
364
- "status": "ok", "symbol": symbol, "model": FINBERT_MODEL_ID, "n": n,
365
- "net": round(net, 4),
366
- "pos": round(pos_sum / n, 4),
367
- "neg": round(neg_sum / n, 4),
368
- "neu": round(neu_sum / n, 4),
369
- "direction": direction,
370
- }
371
-
372
-
373
- # -----------------------------------------------------------------------------
374
- # Gradio UI — 4 tabs, each exposes a named API for MCP discovery
375
- # -----------------------------------------------------------------------------
376
- with gr.Blocks(title="Investment OS inference") as demo:
377
- gr.Markdown("# Investment OS — 3 TSFMs + FinBERT\nCPU-only, all lazy-loaded. Endpoints: `/forecast`, `/forecast_chronos`, `/forecast_timesfm`, `/score_sentiment`.")
378
-
379
- with gr.Tab("Kronos"):
380
- with gr.Row():
381
- s1 = gr.Textbox(label="Symbol", value="VOO")
382
- lb1 = gr.Slider(32, 500, value=180, step=1, label="Lookback days")
383
- pd1 = gr.Slider(1, 90, value=30, step=1, label="Pred days")
384
- gr.Button("Forecast").click(forecast, [s1, lb1, pd1], gr.JSON(), api_name="forecast")
385
-
386
- with gr.Tab("Chronos-bolt-tiny"):
387
- with gr.Row():
388
- s2 = gr.Textbox(label="Symbol", value="VOO")
389
- lb2 = gr.Slider(32, 500, value=180, step=1, label="Lookback days")
390
- pd2 = gr.Slider(1, 90, value=30, step=1, label="Pred days")
391
- gr.Button("Forecast").click(forecast_chronos, [s2, lb2, pd2], gr.JSON(), api_name="forecast_chronos")
392
-
393
- with gr.Tab("TimesFM-2.5"):
394
- with gr.Row():
395
- s3 = gr.Textbox(label="Symbol", value="VOO")
396
- lb3 = gr.Slider(32, 500, value=180, step=1, label="Lookback days")
397
- pd3 = gr.Slider(1, 90, value=30, step=1, label="Pred days")
398
- gr.Button("Forecast").click(forecast_timesfm, [s3, lb3, pd3], gr.JSON(), api_name="forecast_timesfm")
399
-
400
- with gr.Tab("FinBERT (text)"):
401
- t4 = gr.Textbox(label="Texts (JSON array or newline-separated)",
402
- value='["Strong Q4 beats expectations","Margin pressure ahead"]', lines=6)
403
- gr.Button("Score").click(score_sentiment, [t4], gr.JSON(), api_name="score_sentiment")
404
-
405
- with gr.Tab("FinBERT (by symbol)"):
406
- with gr.Row():
407
- s5 = gr.Textbox(label="Symbol", value="AAPL")
408
- n5 = gr.Slider(1, 50, value=20, step=1, label="Max news items")
409
- gr.Button("Fetch news + Score").click(score_sentiment_for_symbol, [s5, n5], gr.JSON(),
410
- api_name="score_sentiment_for_symbol")
 
 
 
 
 
 
 
 
411
 
412
 
413
  if __name__ == "__main__":
414
- demo.launch(server_name="0.0.0.0", server_port=7860, mcp_server=True)
 
1
+ """Kronos + Chronos + TimesFM + TiRex + MOMENT + FinBERT + GDELT + Reddit — Investment OS Space."""
2
+ from __future__ import annotations
3
 
4
+ import os, sys, time, json, traceback, threading, warnings
5
+ from typing import List, Optional, Tuple, Dict, Any
 
 
 
6
 
7
+ warnings.filterwarnings("ignore")
8
+ os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
9
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
10
+ # Disable TiRex custom CUDA kernels (we're on CPU)
11
+ os.environ.setdefault("TIREX_NO_CUDA", "1")
12
 
13
  import numpy as np
14
  import pandas as pd
15
  import torch
 
16
  import gradio as gr
17
 
18
+ # ============================================================
19
+ # Shared yfinance OHLC loader (new session each call to avoid race)
20
+ # ============================================================
21
 
22
+ def _load_ohlc(symbol: str, lookback: int = 180) -> pd.DataFrame:
23
+ import yfinance as yf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  try:
25
+ from curl_cffi import requests as cffi_requests
26
+ session = cffi_requests.Session(impersonate="chrome")
27
  except Exception:
28
  session = None
29
+ end = pd.Timestamp.utcnow().tz_localize(None)
30
+ start = end - pd.Timedelta(days=int(lookback * 2.2)) # account for weekends/holidays
31
+ kwargs = dict(start=start.strftime("%Y-%m-%d"), end=(end + pd.Timedelta(days=1)).strftime("%Y-%m-%d"),
32
+ interval="1d", progress=False, auto_adjust=False, actions=False, threads=False)
33
+ if session is not None:
34
+ kwargs["session"] = session
35
+ df = yf.download(symbol, **kwargs)
36
+ if df is None or len(df) == 0:
37
+ raise RuntimeError(f"No data for {symbol}")
38
+ if isinstance(df.columns, pd.MultiIndex):
39
+ df.columns = df.columns.get_level_values(0)
40
+ df = df.dropna().tail(lookback).reset_index()
41
+ need = {"Open", "High", "Low", "Close", "Volume"}
42
+ if not need.issubset(set(df.columns)):
43
+ raise RuntimeError(f"Missing columns for {symbol}: got {list(df.columns)}")
44
+ return df
45
+
46
+
47
+ # ============================================================
48
+ # Model 1: Kronos (finance-native OHLCV foundation model)
49
+ # ============================================================
50
+ _kronos_cache = {"model": None, "tok": None, "pred": None, "lock": threading.Lock()}
51
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ def _get_kronos():
54
+ with _kronos_cache["lock"]:
55
+ if _kronos_cache["pred"] is None:
56
+ from model import Kronos, KronosTokenizer, KronosPredictor
57
+ tok = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
58
+ mdl = Kronos.from_pretrained("NeoQuasar/Kronos-small")
59
+ _kronos_cache["tok"] = tok
60
+ _kronos_cache["model"] = mdl
61
+ _kronos_cache["pred"] = KronosPredictor(model=mdl, tokenizer=tok, device="cpu", max_context=512)
62
+ return _kronos_cache["pred"]
63
+
64
+
65
+ def forecast(symbol: str, lookback: int = 180, pred_days: int = 30) -> dict:
66
  try:
67
+ df = _load_ohlc(symbol, lookback)
68
+ pred = _get_kronos()
69
+ x_df = df[["Open", "High", "Low", "Close", "Volume"]].copy()
70
+ x_ts = pd.to_datetime(df["Date"])
71
+ last = x_ts.iloc[-1]
72
+ y_ts = pd.date_range(start=last + pd.Timedelta(days=1), periods=pred_days, freq="B")
73
+ out = pred.predict(df=x_df, x_timestamp=x_ts, y_timestamp=y_ts, pred_len=pred_days, T=1.0, top_p=0.9, sample_count=1, verbose=False)
74
+ last_close = float(x_df["Close"].iloc[-1])
75
+ pred_close = float(out["close"].iloc[-1])
76
+ mean_close = float(out["close"].mean())
77
+ min_close = float(out["close"].min())
78
+ max_close = float(out["close"].max())
79
+ pct = (pred_close - last_close) / last_close * 100
80
+ direction = 1 if pct > 2 else (-1 if pct < -2 else 0)
81
+ return {"status": "ok", "symbol": symbol, "model": "NeoQuasar/Kronos-small",
82
+ "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
83
+ "pct_change": round(pct, 3), "direction": direction,
84
+ "n_lookback": int(len(x_df)), "pred_days": pred_days,
85
+ "pred_mean_close": round(mean_close, 4), "pred_min_close": round(min_close, 4),
86
+ "pred_max_close": round(max_close, 4)}
87
+ except Exception as e:
88
+ return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
89
 
90
 
91
+ # ============================================================
92
+ # Model 2: Chronos-bolt-tiny (generic TSFM)
93
+ # ============================================================
94
+ _chronos_cache = {"pipe": None, "lock": threading.Lock()}
95
 
96
 
97
+ def _get_chronos():
98
+ with _chronos_cache["lock"]:
99
+ if _chronos_cache["pipe"] is None:
100
+ from chronos import BaseChronosPipeline
101
+ _chronos_cache["pipe"] = BaseChronosPipeline.from_pretrained(
102
+ "amazon/chronos-bolt-tiny", device_map="cpu", torch_dtype=torch.float32)
103
+ return _chronos_cache["pipe"]
104
 
105
 
106
+ def forecast_chronos(symbol: str, lookback: int = 180, pred_days: int = 30) -> dict:
107
+ try:
108
+ df = _load_ohlc(symbol, lookback)
109
+ closes = df["Close"].values.astype(np.float32)
110
+ pipe = _get_chronos()
111
+ ctx = torch.tensor(closes, dtype=torch.float32)
112
+ quantiles, mean = pipe.predict_quantiles(context=ctx, prediction_length=int(pred_days),
113
+ quantile_levels=[0.1, 0.5, 0.9])
114
+ mean_pred = mean[0].numpy()
115
+ low_pred = quantiles[0, :, 0].numpy()
116
+ high_pred = quantiles[0, :, 2].numpy()
117
+ last_close = float(closes[-1])
118
+ pred_close = float(mean_pred[-1])
119
+ pct = (pred_close - last_close) / last_close * 100
120
+ direction = 1 if pct > 2 else (-1 if pct < -2 else 0)
121
+ return {"status": "ok", "symbol": symbol, "model": "amazon/chronos-bolt-tiny",
122
+ "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
123
+ "pct_change": round(pct, 3), "direction": direction,
124
+ "n_lookback": int(len(closes)), "pred_days": int(pred_days),
125
+ "pred_mean_close": round(float(mean_pred.mean()), 4),
126
+ "pred_low_close": round(float(low_pred.min()), 4),
127
+ "pred_high_close": round(float(high_pred.max()), 4)}
128
+ except Exception as e:
129
+ return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
130
+
131
 
132
+ # ============================================================
133
+ # Model 3: TimesFM 2.5 via transformers (UPGRADED from 2.0)
134
+ # ============================================================
135
+ _timesfm_cache = {"model": None, "lock": threading.Lock()}
136
 
137
 
138
+ def _get_timesfm():
139
+ with _timesfm_cache["lock"]:
140
+ if _timesfm_cache["model"] is None:
141
+ try:
142
+ from transformers import TimesFm2_5ModelForPrediction
143
+ m = TimesFm2_5ModelForPrediction.from_pretrained(
144
+ "google/timesfm-2.5-200m-transformers")
145
+ m = m.to(torch.float32).eval()
146
+ _timesfm_cache["model"] = m
147
+ _timesfm_cache["version"] = "2.5"
148
+ except Exception:
149
+ # Fallback to 2.0 if 2.5 unavailable in transformers version
150
+ from transformers import TimesFmModelForPrediction
151
+ m = TimesFmModelForPrediction.from_pretrained(
152
+ "google/timesfm-2.0-500m-pytorch")
153
+ m = m.to(torch.float32).eval()
154
+ _timesfm_cache["model"] = m
155
+ _timesfm_cache["version"] = "2.0"
156
+ return _timesfm_cache["model"], _timesfm_cache["version"]
157
+
158
+
159
+ def forecast_timesfm(symbol: str, lookback: int = 180, pred_days: int = 30) -> dict:
160
  try:
161
+ df = _load_ohlc(symbol, lookback)
162
+ closes = df["Close"].values.astype(np.float32)
163
+ model, ver = _get_timesfm()
164
+ past = [torch.tensor(closes, dtype=torch.float32)]
165
+ with torch.no_grad():
166
+ if ver == "2.5":
167
+ outputs = model(past_values=past, forecast_context_len=1024)
168
+ mean_pred = outputs.mean_predictions[0].float().cpu().numpy()
169
+ else:
170
+ # v2.0 transformers API
171
+ freq = torch.tensor([0], dtype=torch.long)
172
+ outputs = model(past_values=past, freq=freq, return_dict=True)
173
+ mean_pred = outputs.mean_predictions[0].float().cpu().numpy()
174
+ # Slice to pred_days
175
+ horizon = min(int(pred_days), len(mean_pred))
176
+ mean_pred = mean_pred[:horizon]
177
+ last_close = float(closes[-1])
178
+ pred_close = float(mean_pred[-1])
179
+ pct = (pred_close - last_close) / last_close * 100
180
+ direction = 1 if pct > 2 else (-1 if pct < -2 else 0)
181
+ return {"status": "ok", "symbol": symbol,
182
+ "model": f"google/timesfm-{ver}",
183
+ "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
184
+ "pct_change": round(pct, 3), "direction": direction,
185
+ "n_lookback": int(len(closes)), "pred_days": horizon,
186
+ "pred_mean_close": round(float(mean_pred.mean()), 4),
187
+ "pred_min_close": round(float(mean_pred.min()), 4),
188
+ "pred_max_close": round(float(mean_pred.max()), 4)}
189
  except Exception as e:
190
+ return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
191
 
192
 
193
+ # ============================================================
194
+ # Model 4 (NEW): TiRex (35M xLSTM TSFM, CPU experimental)
195
+ # ============================================================
196
+ _tirex_cache = {"model": None, "lock": threading.Lock()}
 
197
 
198
 
199
+ def _get_tirex():
200
+ with _tirex_cache["lock"]:
201
+ if _tirex_cache["model"] is None:
202
+ from tirex import load_model
203
+ _tirex_cache["model"] = load_model("NX-AI/TiRex")
204
+ return _tirex_cache["model"]
205
+
206
+
207
+ def forecast_tirex(symbol: str, lookback: int = 180, pred_days: int = 30) -> dict:
 
 
 
 
 
 
 
208
  try:
209
+ df = _load_ohlc(symbol, lookback)
210
+ closes = df["Close"].values.astype(np.float32)
211
+ model = _get_tirex()
212
+ # TiRex expects (batch, seq_len)
213
+ ctx = torch.tensor(closes, dtype=torch.float32).unsqueeze(0)
214
+ with torch.no_grad():
215
+ result = model.forecast(context=ctx, prediction_length=int(pred_days))
216
+ # TiRex returns (quantiles, mean) tuple in newer versions
217
+ if isinstance(result, tuple) and len(result) == 2:
218
+ _, mean_pred = result
219
+ else:
220
+ mean_pred = result
221
+ mean_arr = mean_pred[0].float().cpu().numpy() if hasattr(mean_pred, "cpu") else np.asarray(mean_pred)[0]
222
+ # Check for NaN (TiRex CPU may degrade)
223
+ if np.isnan(mean_arr).any():
224
+ return {"status": "error", "symbol": symbol,
225
+ "error": "TiRex returned NaN (CPU mode is experimental)",
226
+ "model": "NX-AI/TiRex"}
227
+ last_close = float(closes[-1])
228
+ pred_close = float(mean_arr[-1])
229
+ pct = (pred_close - last_close) / last_close * 100
230
+ direction = 1 if pct > 2 else (-1 if pct < -2 else 0)
231
+ return {"status": "ok", "symbol": symbol, "model": "NX-AI/TiRex",
232
+ "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
233
+ "pct_change": round(pct, 3), "direction": direction,
234
+ "n_lookback": int(len(closes)), "pred_days": int(pred_days),
235
+ "pred_mean_close": round(float(mean_arr.mean()), 4),
236
+ "pred_min_close": round(float(mean_arr.min()), 4),
237
+ "pred_max_close": round(float(mean_arr.max()), 4)}
238
  except Exception as e:
239
+ return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
240
 
241
 
242
+ # ============================================================
243
+ # Model 5 (NEW): MOMENT-1-large as ANOMALY DETECTOR
244
+ # (MOMENT forecasting needs training — anomaly/reconstruction is zero-shot)
245
+ # ============================================================
246
+ _moment_cache = {"model": None, "lock": threading.Lock()}
247
 
248
 
249
+ def _get_moment():
250
+ with _moment_cache["lock"]:
251
+ if _moment_cache["model"] is None:
252
+ from momentfm import MOMENTPipeline
253
+ m = MOMENTPipeline.from_pretrained(
254
+ "AutonLab/MOMENT-1-large",
255
+ model_kwargs={"task_name": "reconstruction"},
256
+ )
257
+ m.init()
258
+ m.eval()
259
+ _moment_cache["model"] = m
260
+ return _moment_cache["model"]
261
+
262
+
263
+ def anomaly_moment(symbol: str, lookback: int = 512) -> dict:
264
+ """Detects anomalies in recent price action via reconstruction error.
265
+ Returns anomaly score (higher = more anomalous) and regime flag."""
 
 
 
 
 
 
 
266
  try:
267
+ # MOMENT requires exactly 512 timesteps
268
+ df = _load_ohlc(symbol, max(lookback, 512))
269
+ closes = df["Close"].values.astype(np.float32)[-512:]
270
+ if len(closes) < 512:
271
+ # Pad
272
+ padded = np.zeros(512, dtype=np.float32)
273
+ padded[-len(closes):] = closes
274
+ closes = padded
275
+ model = _get_moment()
276
+ # Normalize
277
+ mean_, std_ = closes.mean(), closes.std() or 1.0
278
+ norm = (closes - mean_) / std_
279
+ # MOMENT expects (batch, n_channels, seq_len)
280
+ x = torch.tensor(norm, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
281
+ mask = torch.ones_like(x[:, 0, :], dtype=torch.long)
282
+ with torch.no_grad():
283
+ output = model(x_enc=x, input_mask=mask)
284
+ recon = output.reconstruction[0, 0].cpu().numpy()
285
+ # Anomaly score per timestep = squared error, normalized
286
+ err = (norm - recon) ** 2
287
+ recent_err = float(err[-30:].mean()) # last 30 days
288
+ baseline_err = float(err[:-30].mean()) if len(err) > 30 else recent_err
289
+ ratio = recent_err / max(baseline_err, 1e-6)
290
+ # Regime flag: 1=normal, 2=elevated, 3=anomaly
291
+ if ratio > 2.5:
292
+ regime = "anomaly"
293
+ elif ratio > 1.5:
294
+ regime = "elevated"
295
+ else:
296
+ regime = "normal"
297
+ # Peak anomaly in last 30d
298
+ peak_idx = int(np.argmax(err[-30:]))
299
+ return {"status": "ok", "symbol": symbol, "model": "AutonLab/MOMENT-1-large",
300
+ "recent_err": round(recent_err, 4),
301
+ "baseline_err": round(baseline_err, 4),
302
+ "err_ratio": round(ratio, 3),
303
+ "regime": regime,
304
+ "peak_anomaly_days_ago": 30 - peak_idx,
305
+ "n_context": int(len(closes))}
306
  except Exception as e:
307
+ return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
308
 
309
 
310
+ # ============================================================
311
+ # Model 6: FinBERT sentiment (news via yfinance)
312
+ # ============================================================
313
+ _finbert_cache = {"pipe": None, "lock": threading.Lock()}
 
314
 
315
 
316
  def _get_finbert():
317
+ with _finbert_cache["lock"]:
318
+ if _finbert_cache["pipe"] is None:
319
+ from transformers import pipeline
320
+ _finbert_cache["pipe"] = pipeline("text-classification",
321
+ model="peejm/finbert-financial-sentiment",
322
+ device=-1, top_k=None)
323
+ return _finbert_cache["pipe"]
324
+
325
+
326
+ def _score_texts_finbert(texts: List[str]) -> Dict[str, Any]:
327
+ """Run FinBERT over a list of texts, return aggregate sentiment metrics."""
328
+ if not texts:
329
+ return {"n": 0, "sentiment_net": 0.0, "direction": 0, "pos": 0, "neg": 0, "neu": 0}
330
+ pipe = _get_finbert()
331
+ texts = [t[:512] for t in texts if t and t.strip()]
 
 
 
 
 
 
 
 
 
332
  if not texts:
333
+ return {"n": 0, "sentiment_net": 0.0, "direction": 0, "pos": 0, "neg": 0, "neu": 0}
334
+ results = pipe(texts, batch_size=8, truncation=True)
335
+ pos = neg = neu = 0
336
+ net = 0.0
337
+ for r in results:
338
+ # Result is list of {label, score} — take top
339
+ top = r[0] if isinstance(r, list) else r
340
+ label = str(top["label"]).lower()
341
+ score = float(top["score"])
342
+ if "pos" in label:
343
+ pos += 1
344
+ net += score
345
+ elif "neg" in label:
346
+ neg += 1
347
+ net -= score
348
+ else:
349
+ neu += 1
350
+ n = len(results)
351
+ mean_net = net / n if n > 0 else 0.0
352
+ direction = 1 if mean_net > 0.15 else (-1 if mean_net < -0.15 else 0)
353
+ return {"n": n, "sentiment_net": round(mean_net, 4), "direction": direction,
354
+ "pos": pos, "neg": neg, "neu": neu}
355
+
356
+
357
+ def score_sentiment(text: str) -> dict:
358
+ """Score single piece of text."""
359
  try:
360
+ res = _score_texts_finbert([text])
361
+ return {"status": "ok", **res}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  except Exception as e:
363
+ return {"status": "error", "error": str(e)}
364
 
365
 
366
+ def score_sentiment_for_symbol(symbol: str, max_articles: int = 20) -> dict:
367
+ """Fetch yfinance news and score via FinBERT."""
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  try:
369
+ import yfinance as yf
370
+ try:
371
+ from curl_cffi import requests as cffi_requests
372
+ session = cffi_requests.Session(impersonate="chrome")
373
+ except Exception:
374
+ session = None
375
  t = yf.Ticker(symbol, session=session) if session else yf.Ticker(symbol)
376
+ news = []
377
+ try:
378
+ news = t.news or []
379
+ except Exception as e:
380
+ return {"status": "error", "symbol": symbol,
381
+ "error": f"yfinance news fetch failed: {e}"}
382
+ titles = []
383
+ for item in news[:max_articles]:
384
+ # yfinance news can have content nested under "content" key
385
+ if "content" in item and isinstance(item["content"], dict):
386
+ title = item["content"].get("title") or ""
387
+ desc = item["content"].get("description") or ""
388
+ else:
389
+ title = item.get("title", "")
390
+ desc = item.get("summary", "")
391
+ txt = f"{title}. {desc}".strip().strip(".")
392
+ if txt:
393
+ titles.append(txt)
394
+ res = _score_texts_finbert(titles)
395
+ return {"status": "ok", "symbol": symbol, "source": "yfinance_news",
396
+ "n_articles": res["n"], "sentiment_net": res["sentiment_net"],
397
+ "direction": res["direction"],
398
+ "pos": res["pos"], "neg": res["neg"], "neu": res["neu"]}
399
  except Exception as e:
400
+ return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
401
+
402
+
403
+ # ============================================================
404
+ # NEW: GDELT news sentiment (global macro/event stream)
405
+ # ============================================================
406
+
407
+ def news_gdelt_for_symbol(symbol: str, company_name: str = "", days: int = 3,
408
+ max_articles: int = 30) -> dict:
409
+ """Fetch GDELT articles matching symbol/company, score sentiment.
410
+ Free, no key, 15-min refresh, 100+ languages (filtered to English)."""
411
+ try:
412
+ from gdeltdoc import GdeltDoc, Filters
413
+ # Query construction
414
+ # If company_name given, use it; else just symbol
415
+ keyword = company_name.strip() if company_name.strip() else symbol
416
+ timespan_map = {1: "1d", 2: "2d", 3: "3d", 7: "1w"}
417
+ timespan = timespan_map.get(int(days), f"{int(days)}d")
418
+ f = Filters(keyword=keyword, language="eng",
419
+ timespan=timespan, num_records=int(max_articles))
420
+ gd = GdeltDoc()
421
+ articles = gd.article_search(f)
422
+ if articles is None or len(articles) == 0:
423
+ return {"status": "ok", "symbol": symbol, "source": "gdelt",
424
+ "n_articles": 0, "sentiment_net": 0.0, "direction": 0,
425
+ "pos": 0, "neg": 0, "neu": 0, "top_domains": []}
426
+ titles = [t for t in articles["title"].tolist() if isinstance(t, str) and t]
427
+ # Deduplicate
428
+ seen = set()
429
+ deduped = []
430
+ for t in titles:
431
+ key = t[:120].lower()
432
+ if key not in seen:
433
+ seen.add(key)
434
+ deduped.append(t)
435
+ res = _score_texts_finbert(deduped)
436
+ # Top source domains
437
+ if "domain" in articles.columns:
438
+ top_domains = articles["domain"].value_counts().head(5).to_dict()
439
+ else:
440
+ top_domains = {}
441
+ return {"status": "ok", "symbol": symbol, "source": "gdelt",
442
+ "n_articles": res["n"], "sentiment_net": res["sentiment_net"],
443
+ "direction": res["direction"],
444
+ "pos": res["pos"], "neg": res["neg"], "neu": res["neu"],
445
+ "top_domains": top_domains,
446
+ "keyword_used": keyword, "timespan": timespan}
447
+ except Exception as e:
448
+ return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
449
+
450
+
451
+ # ============================================================
452
+ # NEW: Reddit retail sentiment (WSB + ISB + stocks + investing)
453
+ # ============================================================
454
+
455
+ _DEFAULT_SUBS = ["wallstreetbets", "stocks", "investing", "IndianStreetBets",
456
+ "DalalStreetTalks", "IndiaInvestments"]
457
+
458
+
459
+ def _fetch_reddit_posts(sub: str, query: str, time_filter: str = "week",
460
+ limit: int = 25) -> list:
461
+ """Fetch posts from Reddit public JSON API — no auth needed."""
462
+ import requests
463
+ url = f"https://www.reddit.com/r/{sub}/search.json"
464
+ params = {"q": query, "restrict_sr": "1", "sort": "top",
465
+ "t": time_filter, "limit": str(min(limit, 100))}
466
+ headers = {"User-Agent": "InvestmentOS/1.0 (ensemble analysis)"}
467
+ try:
468
+ r = requests.get(url, params=params, headers=headers, timeout=15)
469
+ if r.status_code != 200:
470
+ return []
471
+ data = r.json()
472
+ posts = []
473
+ for child in data.get("data", {}).get("children", []):
474
+ d = child.get("data", {})
475
+ posts.append({
476
+ "title": d.get("title", ""),
477
+ "selftext": d.get("selftext", "")[:1000],
478
+ "score": d.get("score", 0),
479
+ "num_comments": d.get("num_comments", 0),
480
+ "sub": sub,
481
+ "url": f"https://www.reddit.com{d.get('permalink', '')}",
482
+ })
483
+ return posts
484
+ except Exception:
485
+ return []
486
+
487
+
488
+ def reddit_sentiment_for_symbol(symbol: str, subs_csv: str = "",
489
+ max_posts_per_sub: int = 20,
490
+ time_filter: str = "week") -> dict:
491
+ """Search multiple subreddits for symbol mentions and score sentiment."""
492
  try:
493
+ import concurrent.futures
494
+ subs = [s.strip() for s in (subs_csv or "").split(",") if s.strip()]
495
+ if not subs:
496
+ subs = _DEFAULT_SUBS
497
+ # Query: symbol with optional $ prefix to catch ticker mentions
498
+ query = f'"{symbol}" OR "${symbol}"'
499
+
500
+ with concurrent.futures.ThreadPoolExecutor(max_workers=6) as ex:
501
+ futs = {ex.submit(_fetch_reddit_posts, s, query, time_filter, max_posts_per_sub): s
502
+ for s in subs}
503
+ all_posts = []
504
+ by_sub_count = {}
505
+ for fut in concurrent.futures.as_completed(futs):
506
+ sub = futs[fut]
507
+ try:
508
+ posts = fut.result()
509
+ except Exception:
510
+ posts = []
511
+ by_sub_count[sub] = len(posts)
512
+ all_posts.extend(posts)
513
+
514
+ # Build texts: weight higher-score posts by including selftext too
515
+ texts = []
516
+ for p in all_posts:
517
+ txt = p["title"]
518
+ if p["selftext"]:
519
+ txt = f"{p['title']}. {p['selftext'][:400]}"
520
+ if txt.strip():
521
+ texts.append(txt[:512])
522
+
523
+ if not texts:
524
+ return {"status": "ok", "symbol": symbol, "source": "reddit",
525
+ "n_mentions": 0, "sentiment_net": 0.0, "direction": 0,
526
+ "pos": 0, "neg": 0, "neu": 0, "by_sub": by_sub_count,
527
+ "subs_searched": subs}
528
+
529
+ res = _score_texts_finbert(texts)
530
+ # Attention metric: weighted score
531
+ total_score = sum(p["score"] for p in all_posts)
532
+ total_comments = sum(p["num_comments"] for p in all_posts)
533
+
534
+ return {"status": "ok", "symbol": symbol, "source": "reddit",
535
+ "n_mentions": res["n"],
536
+ "sentiment_net": res["sentiment_net"],
537
+ "direction": res["direction"],
538
+ "pos": res["pos"], "neg": res["neg"], "neu": res["neu"],
539
+ "by_sub": by_sub_count,
540
+ "total_upvotes": int(total_score),
541
+ "total_comments": int(total_comments),
542
+ "subs_searched": subs,
543
+ "query": query, "time_filter": time_filter}
544
  except Exception as e:
545
+ return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
546
+
547
+
548
+ # ============================================================
549
+ # Gradio Blocks with MCP exposure
550
+ # ============================================================
551
+
552
+ with gr.Blocks(title="Investment OS Multi-Model Space") as demo:
553
+ gr.Markdown("# Investment OS: Kronos + Chronos + TimesFM + TiRex + MOMENT + FinBERT + GDELT + Reddit")
554
+
555
+ with gr.Tab("Kronos (OHLCV TSFM)"):
556
+ sym = gr.Textbox(label="Symbol", value="AAPL")
557
+ lb = gr.Number(label="Lookback", value=180)
558
+ pd_ = gr.Number(label="Pred days", value=30)
559
+ out = gr.JSON(label="Forecast")
560
+ gr.Button("Forecast").click(forecast, [sym, lb, pd_], out, api_name="forecast")
561
+
562
+ with gr.Tab("Chronos (generic TSFM)"):
563
+ s2 = gr.Textbox(label="Symbol", value="AAPL")
564
+ l2 = gr.Number(label="Lookback", value=180)
565
+ p2 = gr.Number(label="Pred days", value=30)
566
+ o2 = gr.JSON(label="Forecast")
567
+ gr.Button("Forecast").click(forecast_chronos, [s2, l2, p2], o2, api_name="forecast_chronos")
568
+
569
+ with gr.Tab("TimesFM 2.5 (transformers)"):
570
+ s3 = gr.Textbox(label="Symbol", value="AAPL")
571
+ l3 = gr.Number(label="Lookback", value=180)
572
+ p3 = gr.Number(label="Pred days", value=30)
573
+ o3 = gr.JSON(label="Forecast")
574
+ gr.Button("Forecast").click(forecast_timesfm, [s3, l3, p3], o3, api_name="forecast_timesfm")
575
+
576
+ with gr.Tab("TiRex (xLSTM TSFM) NEW"):
577
+ s4 = gr.Textbox(label="Symbol", value="AAPL")
578
+ l4 = gr.Number(label="Lookback", value=180)
579
+ p4 = gr.Number(label="Pred days", value=30)
580
+ o4 = gr.JSON(label="Forecast")
581
+ gr.Button("Forecast").click(forecast_tirex, [s4, l4, p4], o4, api_name="forecast_tirex")
582
+
583
+ with gr.Tab("MOMENT Anomaly NEW"):
584
+ s5 = gr.Textbox(label="Symbol", value="AAPL")
585
+ l5 = gr.Number(label="Lookback (min 512)", value=512)
586
+ o5 = gr.JSON(label="Anomaly analysis")
587
+ gr.Button("Detect").click(anomaly_moment, [s5, l5], o5, api_name="anomaly_moment")
588
+
589
+ with gr.Tab("FinBERT text"):
590
+ t6 = gr.Textbox(label="Text", value="The company reported record earnings.")
591
+ o6 = gr.JSON(label="Sentiment")
592
+ gr.Button("Score").click(score_sentiment, t6, o6, api_name="score_sentiment")
593
+
594
+ with gr.Tab("FinBERT yfinance news"):
595
+ s7 = gr.Textbox(label="Symbol", value="AAPL")
596
+ m7 = gr.Number(label="Max articles", value=20)
597
+ o7 = gr.JSON(label="Sentiment")
598
+ gr.Button("Score").click(score_sentiment_for_symbol, [s7, m7], o7, api_name="score_sentiment_for_symbol")
599
+
600
+ with gr.Tab("GDELT news NEW"):
601
+ s8 = gr.Textbox(label="Symbol", value="AAPL")
602
+ c8 = gr.Textbox(label="Company name (optional)", value="Apple")
603
+ d8 = gr.Number(label="Days", value=3)
604
+ m8 = gr.Number(label="Max articles", value=30)
605
+ o8 = gr.JSON(label="GDELT sentiment")
606
+ gr.Button("Fetch").click(news_gdelt_for_symbol, [s8, c8, d8, m8], o8, api_name="news_gdelt_for_symbol")
607
+
608
+ with gr.Tab("Reddit sentiment NEW"):
609
+ s9 = gr.Textbox(label="Symbol", value="AAPL")
610
+ sub9 = gr.Textbox(label="Subs CSV (blank = defaults)",
611
+ value="wallstreetbets,stocks,investing,IndianStreetBets,DalalStreetTalks,IndiaInvestments")
612
+ m9 = gr.Number(label="Max posts per sub", value=20)
613
+ t9 = gr.Textbox(label="Time filter", value="week")
614
+ o9 = gr.JSON(label="Reddit sentiment")
615
+ gr.Button("Fetch").click(reddit_sentiment_for_symbol, [s9, sub9, m9, t9], o9,
616
+ api_name="reddit_sentiment_for_symbol")
617
 
618
 
619
  if __name__ == "__main__":
620
+ demo.launch(mcp_server=True, server_name="0.0.0.0", server_port=7860)