Spaces:
Sleeping
Sleeping
v3: TimesFM 2.5 + TiRex + MOMENT + GDELT + Reddit
Browse files- Dockerfile +50 -24
- README.md +17 -15
- app.py +564 -358
Dockerfile
CHANGED
|
@@ -1,38 +1,64 @@
|
|
| 1 |
FROM python:3.11-slim
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
RUN useradd -m -u 1000 user
|
| 5 |
USER user
|
| 6 |
-
ENV HOME=/home/user
|
| 7 |
-
PATH=/home/user/.local/bin:$PATH \
|
| 8 |
-
PYTHONUNBUFFERED=1 \
|
| 9 |
-
GRADIO_SERVER_NAME=0.0.0.0 \
|
| 10 |
-
GRADIO_SERVER_PORT=7860 \
|
| 11 |
-
HF_HOME=/home/user/.cache/huggingface
|
| 12 |
-
|
| 13 |
WORKDIR /home/user/app
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
|
| 17 |
-
torch==2.4.1
|
| 18 |
|
| 19 |
-
#
|
| 20 |
-
RUN pip install --
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
"huggingface_hub>=0.27.0,<1.0" \
|
| 23 |
-
"
|
| 24 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
"yfinance>=0.2.50" \
|
| 26 |
"curl_cffi>=0.7" \
|
| 27 |
-
"
|
| 28 |
-
"
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
-
|
|
|
|
| 36 |
|
| 37 |
EXPOSE 7860
|
| 38 |
CMD ["python", "app.py"]
|
|
|
|
| 1 |
FROM python:3.11-slim
|
| 2 |
|
| 3 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 4 |
+
git build-essential curl wget \
|
| 5 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 6 |
+
|
| 7 |
+
# Create user matching HF Space conventions
|
| 8 |
RUN useradd -m -u 1000 user
|
| 9 |
USER user
|
| 10 |
+
ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
WORKDIR /home/user/app
|
| 12 |
|
| 13 |
+
# Disable TiRex CUDA kernels (we're on CPU-only Space)
|
| 14 |
+
ENV TIREX_NO_CUDA=1 XLSTM_USE_CUDA_KERNELS=0
|
|
|
|
| 15 |
|
| 16 |
+
# Core: Torch CPU (keep 2.4.1 — Kronos meta-tensor compatibility)
|
| 17 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 18 |
+
pip install --no-cache-dir torch==2.4.1 --index-url https://download.pytorch.org/whl/cpu
|
| 19 |
+
|
| 20 |
+
# Transformers bumped to 4.55.0 for TimesFm2_5 support
|
| 21 |
+
RUN pip install --no-cache-dir \
|
| 22 |
+
"transformers==4.55.0" \
|
| 23 |
"huggingface_hub>=0.27.0,<1.0" \
|
| 24 |
+
"gradio[mcp]==5.30.0" \
|
| 25 |
+
"websockets>=13" \
|
| 26 |
+
"einops" \
|
| 27 |
+
"safetensors" \
|
| 28 |
+
"pandas" \
|
| 29 |
+
"numpy<2.3"
|
| 30 |
+
|
| 31 |
+
# Chronos + TimesFM (fallback) + yfinance + sentiment
|
| 32 |
+
RUN pip install --no-cache-dir \
|
| 33 |
+
"chronos-forecasting>=1.5.2" \
|
| 34 |
+
"timesfm[torch]>=1.3.0" \
|
| 35 |
"yfinance>=0.2.50" \
|
| 36 |
"curl_cffi>=0.7" \
|
| 37 |
+
"sentencepiece" \
|
| 38 |
+
"tokenizers"
|
| 39 |
+
|
| 40 |
+
# NEW: MOMENT (anomaly detection)
|
| 41 |
+
RUN pip install --no-cache-dir "momentfm>=0.1.4"
|
| 42 |
+
|
| 43 |
+
# NEW: GDELT (news streaming)
|
| 44 |
+
RUN pip install --no-cache-dir "gdeltdoc>=1.5.0"
|
| 45 |
+
|
| 46 |
+
# NEW: requests for Reddit
|
| 47 |
+
RUN pip install --no-cache-dir "requests>=2.31" "beautifulsoup4>=4.12"
|
| 48 |
+
|
| 49 |
+
# NEW: TiRex (xLSTM TSFM) — install from git, CPU experimental mode
|
| 50 |
+
# Fall back to `|| echo ...` so the build doesn't fail if the TiRex install hits issues — /forecast_tirex will then return an error at runtime
|
| 51 |
+
RUN pip install --no-cache-dir git+https://github.com/NX-AI/tirex.git || \
|
| 52 |
+
echo "WARNING: TiRex install failed — /forecast_tirex will return error"
|
| 53 |
+
|
| 54 |
+
# Kronos model repo
|
| 55 |
+
USER user
|
| 56 |
+
RUN mkdir -p /home/user/app/model
|
| 57 |
+
COPY --chown=user ./model /home/user/app/model
|
| 58 |
+
COPY --chown=user ./app.py /home/user/app/app.py
|
| 59 |
|
| 60 |
+
# Pre-create HF cache dirs with right perms
|
| 61 |
+
RUN mkdir -p /home/user/.cache/huggingface
|
| 62 |
|
| 63 |
EXPOSE 7860
|
| 64 |
CMD ["python", "app.py"]
|
README.md
CHANGED
|
@@ -1,22 +1,24 @@
|
|
| 1 |
---
|
| 2 |
-
title: Kronos Forecast
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
app_port: 7860
|
| 8 |
-
pinned:
|
| 9 |
license: mit
|
| 10 |
-
|
| 11 |
-
- mcp-server
|
| 12 |
-
- finance
|
| 13 |
-
- forecasting
|
| 14 |
-
- kronos
|
| 15 |
---
|
| 16 |
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
| 20 |
-
-
|
| 21 |
-
-
|
| 22 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Kronos Forecast Multi-Model
|
| 3 |
+
emoji: 🦜
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
app_port: 7860
|
| 8 |
+
pinned: true
|
| 9 |
license: mit
|
| 10 |
+
short_description: TSFM + Sentiment + News + Reddit ensemble
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# Investment OS Multi-Model Space
|
| 14 |
|
| 15 |
+
9 endpoints exposed via Gradio + MCP:
|
| 16 |
+
- `/forecast` — Kronos (OHLCV finance TSFM)
|
| 17 |
+
- `/forecast_chronos` — Chronos-bolt-tiny (generic TSFM)
|
| 18 |
+
- `/forecast_timesfm` — TimesFM 2.5 (Google)
|
| 19 |
+
- `/forecast_tirex` — TiRex (NX-AI xLSTM TSFM)
|
| 20 |
+
- `/anomaly_moment` — MOMENT-1-large (anomaly detection)
|
| 21 |
+
- `/score_sentiment` — FinBERT text
|
| 22 |
+
- `/score_sentiment_for_symbol` — FinBERT × yfinance news
|
| 23 |
+
- `/news_gdelt_for_symbol` — FinBERT × GDELT global news
|
| 24 |
+
- `/reddit_sentiment_for_symbol` — FinBERT × Reddit retail sentiment
|
app.py
CHANGED
|
@@ -1,414 +1,620 @@
|
|
| 1 |
-
"""
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
- /forecast_chronos — amazon/chronos-bolt-tiny (generic TSFM, CPU-fast)
|
| 6 |
-
- /forecast_timesfm — google/timesfm-2.5-200m-pytorch (Google TSFM)
|
| 7 |
-
- /score_sentiment — ProsusAI/finbert (financial sentiment)
|
| 8 |
|
| 9 |
-
|
| 10 |
-
"""
|
| 11 |
-
|
|
|
|
|
|
|
| 12 |
|
| 13 |
import numpy as np
|
| 14 |
import pandas as pd
|
| 15 |
import torch
|
| 16 |
-
import yfinance as yf
|
| 17 |
import gradio as gr
|
| 18 |
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
# -----------------------------------------------------------------------------
|
| 23 |
-
|
| 24 |
-
def _load_ohlc(symbol: str, lookback_days: int) -> pd.DataFrame:
|
| 25 |
-
def _tidy(df):
|
| 26 |
-
if df is None or df.empty:
|
| 27 |
-
return pd.DataFrame()
|
| 28 |
-
if isinstance(df.columns, pd.MultiIndex):
|
| 29 |
-
df.columns = df.columns.get_level_values(0)
|
| 30 |
-
df = df.reset_index()
|
| 31 |
-
df.columns = [str(c).lower() for c in df.columns]
|
| 32 |
-
if "date" not in df.columns and "datetime" in df.columns:
|
| 33 |
-
df["date"] = df["datetime"]
|
| 34 |
-
if "date" not in df.columns:
|
| 35 |
-
return pd.DataFrame()
|
| 36 |
-
df["timestamps"] = pd.to_datetime(df["date"])
|
| 37 |
-
keep = ["timestamps", "open", "high", "low", "close", "volume"]
|
| 38 |
-
df = df[[c for c in keep if c in df.columns]].dropna().tail(lookback_days).reset_index(drop=True)
|
| 39 |
-
return df
|
| 40 |
-
|
| 41 |
-
# Primary: yf.download. Fallback: yf.Ticker().history(). Uses curl_cffi chrome impersonation if available.
|
| 42 |
-
session = None
|
| 43 |
try:
|
| 44 |
-
from curl_cffi import requests as
|
| 45 |
-
session =
|
| 46 |
except Exception:
|
| 47 |
session = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
period = f"{lookback_days + 10}d"
|
| 50 |
-
try:
|
| 51 |
-
df = yf.download(symbol, period=period, interval="1d",
|
| 52 |
-
progress=False, auto_adjust=False,
|
| 53 |
-
session=session) if session else \
|
| 54 |
-
yf.download(symbol, period=period, interval="1d",
|
| 55 |
-
progress=False, auto_adjust=False)
|
| 56 |
-
out = _tidy(df)
|
| 57 |
-
if len(out) >= 32:
|
| 58 |
-
return out
|
| 59 |
-
except Exception:
|
| 60 |
-
pass
|
| 61 |
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
try:
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
| 73 |
|
| 74 |
|
| 75 |
-
def
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
| 88 |
|
| 89 |
|
| 90 |
-
def
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
try:
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
except Exception as e:
|
| 131 |
-
return {"status": "error", "
|
| 132 |
|
| 133 |
|
| 134 |
-
#
|
| 135 |
-
#
|
| 136 |
-
#
|
| 137 |
-
|
| 138 |
-
_chronos = None
|
| 139 |
|
| 140 |
|
| 141 |
-
def
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
def forecast_chronos(symbol: str, lookback_days: int = 180, pred_days: int = 30) -> dict:
|
| 152 |
-
"""Chronos-bolt-tiny forecast on close prices."""
|
| 153 |
-
symbol = (symbol or "").strip().upper()
|
| 154 |
-
if not symbol:
|
| 155 |
-
return {"status": "error", "error": "empty symbol"}
|
| 156 |
-
lookback_days, pred_days = _clamp(lookback_days, pred_days)
|
| 157 |
try:
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
"status": "
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
except Exception as e:
|
| 182 |
-
return {"status": "error", "
|
| 183 |
|
| 184 |
|
| 185 |
-
#
|
| 186 |
-
#
|
| 187 |
-
# -
|
| 188 |
-
|
| 189 |
-
|
| 190 |
|
| 191 |
|
| 192 |
-
def
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
def forecast_timesfm(symbol: str, lookback_days: int = 180, pred_days: int = 30) -> dict:
|
| 211 |
-
"""TimesFM 2.5 (200M) forecast on close prices."""
|
| 212 |
-
symbol = (symbol or "").strip().upper()
|
| 213 |
-
if not symbol:
|
| 214 |
-
return {"status": "error", "error": "empty symbol"}
|
| 215 |
-
lookback_days, pred_days = _clamp(lookback_days, pred_days)
|
| 216 |
try:
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
except Exception as e:
|
| 239 |
-
return {"status": "error", "
|
| 240 |
|
| 241 |
|
| 242 |
-
#
|
| 243 |
-
# FinBERT
|
| 244 |
-
#
|
| 245 |
-
|
| 246 |
-
_finbert = None
|
| 247 |
|
| 248 |
|
| 249 |
def _get_finbert():
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
try:
|
| 266 |
-
texts = _json.loads(texts_json)
|
| 267 |
-
if isinstance(texts, str):
|
| 268 |
-
texts = [texts]
|
| 269 |
-
if not isinstance(texts, list):
|
| 270 |
-
texts = [str(texts)]
|
| 271 |
-
except Exception:
|
| 272 |
-
texts = [t.strip() for t in str(texts_json).split("\n") if t.strip()]
|
| 273 |
-
texts = [str(t) for t in texts][:50]
|
| 274 |
if not texts:
|
| 275 |
-
return {"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
try:
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
per = []
|
| 280 |
-
for item in raw:
|
| 281 |
-
entries = item if isinstance(item, list) else [item]
|
| 282 |
-
p = n = u = 0.0
|
| 283 |
-
for e in entries:
|
| 284 |
-
lbl = str(e.get("label", "")).lower()
|
| 285 |
-
sc = float(e.get("score", 0))
|
| 286 |
-
if lbl.startswith("pos"): p = sc
|
| 287 |
-
elif lbl.startswith("neg"): n = sc
|
| 288 |
-
elif lbl.startswith("neu"): u = sc
|
| 289 |
-
pos_sum += p; neg_sum += n; neu_sum += u
|
| 290 |
-
per.append({"pos": round(p, 4), "neg": round(n, 4), "neu": round(u, 4)})
|
| 291 |
-
n = len(texts)
|
| 292 |
-
return {
|
| 293 |
-
"status": "ok", "model": FINBERT_MODEL_ID, "n": n,
|
| 294 |
-
"net": round((pos_sum - neg_sum) / n, 4),
|
| 295 |
-
"pos": round(pos_sum / n, 4),
|
| 296 |
-
"neg": round(neg_sum / n, 4),
|
| 297 |
-
"neu": round(neu_sum / n, 4),
|
| 298 |
-
"per_text": per,
|
| 299 |
-
}
|
| 300 |
except Exception as e:
|
| 301 |
-
return {"status": "error", "error":
|
| 302 |
|
| 303 |
|
| 304 |
-
def score_sentiment_for_symbol(symbol: str,
|
| 305 |
-
"""Fetch
|
| 306 |
-
|
| 307 |
-
Returns aggregated sentiment + direction signal (+1/−1/0) suitable for ensemble voting.
|
| 308 |
-
"""
|
| 309 |
-
symbol = (symbol or "").strip().upper()
|
| 310 |
-
if not symbol:
|
| 311 |
-
return {"status": "error", "error": "empty symbol"}
|
| 312 |
-
max_items = int(max(1, min(int(max_items or 20), 50)))
|
| 313 |
-
# yfinance news via curl_cffi session if available
|
| 314 |
-
try:
|
| 315 |
-
from curl_cffi import requests as cureq
|
| 316 |
-
session = cureq.Session(impersonate="chrome")
|
| 317 |
-
except Exception:
|
| 318 |
-
session = None
|
| 319 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
t = yf.Ticker(symbol, session=session) if session else yf.Ticker(symbol)
|
| 321 |
-
news =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
except Exception as e:
|
| 323 |
-
return {"status": "error", "
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
if
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
try:
|
| 345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
except Exception as e:
|
| 347 |
-
return {"status": "error", "
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
"
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
gr.
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
|
| 413 |
if __name__ == "__main__":
|
| 414 |
-
demo.launch(server_name="0.0.0.0", server_port=7860
|
|
|
|
| 1 |
+
"""Kronos + Chronos + TimesFM + TiRex + MOMENT + FinBERT + GDELT + Reddit — Investment OS Space."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
|
| 4 |
+
import os, sys, time, json, traceback, threading, warnings
|
| 5 |
+
from typing import List, Optional, Tuple, Dict, Any
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
warnings.filterwarnings("ignore")
|
| 8 |
+
os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
|
| 9 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 10 |
+
# Disable TiRex custom CUDA kernels (we're on CPU)
|
| 11 |
+
os.environ.setdefault("TIREX_NO_CUDA", "1")
|
| 12 |
|
| 13 |
import numpy as np
|
| 14 |
import pandas as pd
|
| 15 |
import torch
|
|
|
|
| 16 |
import gradio as gr
|
| 17 |
|
| 18 |
+
# ============================================================
|
| 19 |
+
# Shared yfinance OHLC loader (new session each call to avoid race)
|
| 20 |
+
# ============================================================
|
| 21 |
|
| 22 |
+
def _load_ohlc(symbol: str, lookback: int = 180) -> pd.DataFrame:
|
| 23 |
+
import yfinance as yf
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
try:
|
| 25 |
+
from curl_cffi import requests as cffi_requests
|
| 26 |
+
session = cffi_requests.Session(impersonate="chrome")
|
| 27 |
except Exception:
|
| 28 |
session = None
|
| 29 |
+
end = pd.Timestamp.utcnow().tz_localize(None)
|
| 30 |
+
start = end - pd.Timedelta(days=int(lookback * 2.2)) # account for weekends/holidays
|
| 31 |
+
kwargs = dict(start=start.strftime("%Y-%m-%d"), end=(end + pd.Timedelta(days=1)).strftime("%Y-%m-%d"),
|
| 32 |
+
interval="1d", progress=False, auto_adjust=False, actions=False, threads=False)
|
| 33 |
+
if session is not None:
|
| 34 |
+
kwargs["session"] = session
|
| 35 |
+
df = yf.download(symbol, **kwargs)
|
| 36 |
+
if df is None or len(df) == 0:
|
| 37 |
+
raise RuntimeError(f"No data for {symbol}")
|
| 38 |
+
if isinstance(df.columns, pd.MultiIndex):
|
| 39 |
+
df.columns = df.columns.get_level_values(0)
|
| 40 |
+
df = df.dropna().tail(lookback).reset_index()
|
| 41 |
+
need = {"Open", "High", "Low", "Close", "Volume"}
|
| 42 |
+
if not need.issubset(set(df.columns)):
|
| 43 |
+
raise RuntimeError(f"Missing columns for {symbol}: got {list(df.columns)}")
|
| 44 |
+
return df
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ============================================================
|
| 48 |
+
# Model 1: Kronos (finance-native OHLCV foundation model)
|
| 49 |
+
# ============================================================
|
| 50 |
+
# Process-wide holder for the lazily built Kronos objects plus the lock
# that serializes their construction.
_kronos_cache = dict(model=None, tok=None, pred=None, lock=threading.Lock())


def _get_kronos():
    """Return the shared KronosPredictor, building it on first use.

    Construction happens while holding the cache lock; the tokenizer, the
    model and the predictor are all stored in ``_kronos_cache``.
    """
    cache = _kronos_cache
    with cache["lock"]:
        needs_build = cache["pred"] is None
        if needs_build:
            from model import Kronos, KronosTokenizer, KronosPredictor

            tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
            network = Kronos.from_pretrained("NeoQuasar/Kronos-small")
            cache["tok"] = tokenizer
            cache["model"] = network
            cache["pred"] = KronosPredictor(
                model=network,
                tokenizer=tokenizer,
                device="cpu",
                max_context=512,
            )
    return cache["pred"]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def forecast(symbol: str, lookback: int = 180, pred_days: int = 30) -> dict:
    """Forecast daily closes for ``symbol`` with the Kronos OHLCV model.

    Args:
        symbol: Ticker symbol forwarded to ``_load_ohlc``.
        lookback: Number of historical daily bars used as context.
        pred_days: Number of business days to predict; coerced to ``int``
            and required to be at least 1.

    Returns:
        On success a dict with ``status == "ok"``, the last observed close,
        the final predicted close, the percentage change between them, a
        ``direction`` vote (+1 above +2 %, -1 below -2 %, otherwise 0) and
        mean/min/max of the predicted close path. On any failure a dict with
        ``status == "error"``, the message and the tail of the traceback.
    """
    try:
        # Same int coercion the sibling forecasters apply; date_range()
        # needs an integer number of periods.
        horizon = int(pred_days)
        if horizon < 1:
            raise ValueError(f"pred_days must be >= 1, got {pred_days!r}")

        df = _load_ohlc(symbol, lookback)
        pred = _get_kronos()

        x_df = df[["Open", "High", "Low", "Close", "Volume"]].copy()
        # The predictor's output is read below as lowercase "close" and the
        # previous loader fed it lowercase columns, so use the same schema.
        # NOTE(review): confirm against model/KronosPredictor.predict.
        x_df.columns = ["open", "high", "low", "close", "volume"]

        x_ts = pd.to_datetime(df["Date"])
        last = x_ts.iloc[-1]
        # Future business days, wrapped in a Series so y_timestamp has the
        # same container type as x_timestamp.
        y_ts = pd.Series(
            pd.date_range(start=last + pd.Timedelta(days=1), periods=horizon, freq="B")
        )
        out = pred.predict(df=x_df, x_timestamp=x_ts, y_timestamp=y_ts,
                           pred_len=horizon, T=1.0, top_p=0.9,
                           sample_count=1, verbose=False)

        last_close = float(x_df["close"].iloc[-1])
        pred_close = float(out["close"].iloc[-1])
        mean_close = float(out["close"].mean())
        min_close = float(out["close"].min())
        max_close = float(out["close"].max())
        pct = (pred_close - last_close) / last_close * 100
        # +/-2 % dead band before casting a directional vote.
        direction = 1 if pct > 2 else (-1 if pct < -2 else 0)
        return {"status": "ok", "symbol": symbol, "model": "NeoQuasar/Kronos-small",
                "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
                "pct_change": round(pct, 3), "direction": direction,
                "n_lookback": int(len(x_df)), "pred_days": horizon,
                "pred_mean_close": round(mean_close, 4), "pred_min_close": round(min_close, 4),
                "pred_max_close": round(max_close, 4)}
    except Exception as e:
        # Endpoint boundary: report the failure as data instead of raising.
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
|
| 89 |
|
| 90 |
|
| 91 |
+
# ============================================================
|
| 92 |
+
# Model 2: Chronos-bolt-tiny (generic TSFM)
|
| 93 |
+
# ============================================================
|
| 94 |
+
# Lazily populated Chronos pipeline and the lock guarding its creation.
_chronos_cache = dict(pipe=None, lock=threading.Lock())


def _get_chronos():
    """Return the shared Chronos-bolt-tiny pipeline, loading it on first use."""
    cache = _chronos_cache
    with cache["lock"]:
        if cache["pipe"] is None:
            from chronos import BaseChronosPipeline

            pipeline_ = BaseChronosPipeline.from_pretrained(
                "amazon/chronos-bolt-tiny",
                device_map="cpu",
                torch_dtype=torch.float32,
            )
            cache["pipe"] = pipeline_
    return cache["pipe"]
|
| 104 |
|
| 105 |
|
| 106 |
+
def forecast_chronos(symbol: str, lookback: int = 180, pred_days: int = 30) -> dict:
    """Forecast closing prices for ``symbol`` with Chronos-bolt-tiny.

    Returns a dict with ``status == "ok"`` plus the last close, the final
    value of the mean forecast, the percentage change, a direction vote
    (+1 / -1 / 0 with a 2 % dead band) and summary levels taken from the
    mean path and the 0.1 / 0.9 quantile paths. Any exception is reported
    as a ``status == "error"`` dict carrying the traceback tail.
    """
    try:
        history = _load_ohlc(symbol, lookback)
        closes = history["Close"].values.astype(np.float32)
        chronos_pipe = _get_chronos()
        context = torch.tensor(closes, dtype=torch.float32)
        quantiles, mean = chronos_pipe.predict_quantiles(
            context=context,
            prediction_length=int(pred_days),
            quantile_levels=[0.1, 0.5, 0.9],
        )
        # First batch element; quantile index 0 is the 0.1 level, 2 the 0.9.
        mean_path = mean[0].numpy()
        lower_path = quantiles[0, :, 0].numpy()
        upper_path = quantiles[0, :, 2].numpy()

        last_close = float(closes[-1])
        pred_close = float(mean_path[-1])
        pct = (pred_close - last_close) / last_close * 100
        if pct > 2:
            direction = 1
        elif pct < -2:
            direction = -1
        else:
            direction = 0

        payload = {"status": "ok", "symbol": symbol, "model": "amazon/chronos-bolt-tiny"}
        payload["last_close"] = round(last_close, 4)
        payload["predicted_close"] = round(pred_close, 4)
        payload["pct_change"] = round(pct, 3)
        payload["direction"] = direction
        payload["n_lookback"] = int(len(closes))
        payload["pred_days"] = int(pred_days)
        payload["pred_mean_close"] = round(float(mean_path.mean()), 4)
        payload["pred_low_close"] = round(float(lower_path.min()), 4)
        payload["pred_high_close"] = round(float(upper_path.max()), 4)
        return payload
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
|
| 130 |
+
|
| 131 |
|
| 132 |
+
# ============================================================
|
| 133 |
+
# Model 3: TimesFM 2.5 via transformers (UPGRADED from 2.0)
|
| 134 |
+
# ============================================================
|
| 135 |
+
# Lazily populated TimesFM model, the version string that was actually
# loaded ("2.5" or "2.0"), and the lock guarding initialization.
_timesfm_cache = {"model": None, "version": None, "lock": threading.Lock()}


def _get_timesfm():
    """Return ``(model, version)`` for TimesFM, loading the model on first use.

    TimesFM 2.5 is tried first. If anything in that attempt fails, the
    failure is reported on stderr and the TimesFM 2.0 checkpoint is loaded
    instead; an exception from the fallback load propagates to the caller.

    Returns:
        Tuple of the float32 model in eval mode and ``"2.5"`` or ``"2.0"``.
    """
    with _timesfm_cache["lock"]:
        if _timesfm_cache["model"] is None:
            try:
                from transformers import TimesFm2_5ModelForPrediction
                m = TimesFm2_5ModelForPrediction.from_pretrained(
                    "google/timesfm-2.5-200m-transformers")
                m = m.to(torch.float32).eval()
                version = "2.5"
            except Exception as err:
                # Keep the best-effort fallback, but say why it was taken so
                # an unexpected downgrade is visible in the logs.
                print(f"[timesfm] 2.5 unavailable ({type(err).__name__}: {err}); "
                      "falling back to 2.0", file=sys.stderr, flush=True)
                from transformers import TimesFmModelForPrediction
                m = TimesFmModelForPrediction.from_pretrained(
                    "google/timesfm-2.0-500m-pytorch")
                m = m.to(torch.float32).eval()
                version = "2.0"
            _timesfm_cache["version"] = version
            _timesfm_cache["model"] = m
    return _timesfm_cache["model"], _timesfm_cache["version"]
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def forecast_timesfm(symbol: str, lookback: int = 180, pred_days: int = 30) -> dict:
    """Forecast closing prices for ``symbol`` with TimesFM (2.5, else 2.0).

    Args:
        symbol: Ticker symbol forwarded to ``_load_ohlc``.
        lookback: Number of historical daily closes used as context.
        pred_days: Requested horizon; coerced to ``int`` and required to be
            at least 1. The result is truncated to the model's own output
            length when that is shorter.

    Returns:
        On success a dict with ``status == "ok"``, the last close, the final
        predicted close, percentage change, a direction vote (+1 / -1 / 0
        with a 2 % dead band), the horizon actually returned and mean/min/max
        of the predicted path. On failure a ``status == "error"`` dict with
        the message and traceback tail.
    """
    try:
        # Reject non-positive horizons up front: a negative value would make
        # the slice below drop the tail instead of limiting the length.
        requested = int(pred_days)
        if requested < 1:
            raise ValueError(f"pred_days must be >= 1, got {pred_days!r}")

        df = _load_ohlc(symbol, lookback)
        closes = df["Close"].values.astype(np.float32)
        model, ver = _get_timesfm()
        past = [torch.tensor(closes, dtype=torch.float32)]
        with torch.no_grad():
            if ver == "2.5":
                outputs = model(past_values=past, forecast_context_len=1024)
                mean_pred = outputs.mean_predictions[0].float().cpu().numpy()
            else:
                # v2.0 transformers API takes an explicit frequency id.
                freq = torch.tensor([0], dtype=torch.long)
                outputs = model(past_values=past, freq=freq, return_dict=True)
                mean_pred = outputs.mean_predictions[0].float().cpu().numpy()

        # Keep only the requested horizon (or less if the model gave less).
        horizon = min(requested, len(mean_pred))
        mean_pred = mean_pred[:horizon]
        last_close = float(closes[-1])
        pred_close = float(mean_pred[-1])
        pct = (pred_close - last_close) / last_close * 100
        direction = 1 if pct > 2 else (-1 if pct < -2 else 0)
        return {"status": "ok", "symbol": symbol,
                "model": f"google/timesfm-{ver}",
                "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
                "pct_change": round(pct, 3), "direction": direction,
                "n_lookback": int(len(closes)), "pred_days": horizon,
                "pred_mean_close": round(float(mean_pred.mean()), 4),
                "pred_min_close": round(float(mean_pred.min()), 4),
                "pred_max_close": round(float(mean_pred.max()), 4)}
    except Exception as e:
        # Endpoint boundary: report the failure as data instead of raising.
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
|
| 191 |
|
| 192 |
|
| 193 |
+
# ============================================================
|
| 194 |
+
# Model 4 (NEW): TiRex (35M xLSTM TSFM, CPU experimental)
|
| 195 |
+
# ============================================================
|
| 196 |
+
# Lazily populated TiRex model and the lock guarding its creation.
_tirex_cache = dict(model=None, lock=threading.Lock())


def _get_tirex():
    """Return the shared TiRex model, loading "NX-AI/TiRex" on first use."""
    cache = _tirex_cache
    with cache["lock"]:
        if cache["model"] is None:
            from tirex import load_model

            loaded = load_model("NX-AI/TiRex")
            cache["model"] = loaded
    return cache["model"]
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def forecast_tirex(symbol: str, lookback: int = 180, pred_days: int = 30) -> dict:
    """Forecast closing prices for ``symbol`` with TiRex.

    Returns a ``status == "ok"`` dict with the last close, the final
    predicted close, percentage change, a direction vote (+1 / -1 / 0 with a
    2 % dead band) and mean/min/max of the predicted path. A forecast
    containing NaN, or any exception, yields a ``status == "error"`` dict.
    """
    try:
        history = _load_ohlc(symbol, lookback)
        closes = history["Close"].values.astype(np.float32)
        tirex_model = _get_tirex()
        # Add a leading batch axis: (seq_len,) -> (1, seq_len).
        context = torch.tensor(closes, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            result = tirex_model.forecast(context=context, prediction_length=int(pred_days))

        # A 2-tuple result is unpacked as (quantiles, mean); anything else
        # is used as the mean forecast directly.
        is_pair = isinstance(result, tuple) and len(result) == 2
        mean_pred = result[1] if is_pair else result

        if hasattr(mean_pred, "cpu"):
            mean_arr = mean_pred[0].float().cpu().numpy()
        else:
            mean_arr = np.asarray(mean_pred)[0]

        # Bail out if the model produced NaN anywhere in the forecast.
        if np.isnan(mean_arr).any():
            return {"status": "error", "symbol": symbol,
                    "error": "TiRex returned NaN (CPU mode is experimental)",
                    "model": "NX-AI/TiRex"}

        last_close = float(closes[-1])
        pred_close = float(mean_arr[-1])
        pct = (pred_close - last_close) / last_close * 100
        if pct > 2:
            direction = 1
        elif pct < -2:
            direction = -1
        else:
            direction = 0

        payload = {"status": "ok", "symbol": symbol, "model": "NX-AI/TiRex"}
        payload["last_close"] = round(last_close, 4)
        payload["predicted_close"] = round(pred_close, 4)
        payload["pct_change"] = round(pct, 3)
        payload["direction"] = direction
        payload["n_lookback"] = int(len(closes))
        payload["pred_days"] = int(pred_days)
        payload["pred_mean_close"] = round(float(mean_arr.mean()), 4)
        payload["pred_min_close"] = round(float(mean_arr.min()), 4)
        payload["pred_max_close"] = round(float(mean_arr.max()), 4)
        return payload
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
|
| 240 |
|
| 241 |
|
| 242 |
+
# ============================================================
|
| 243 |
+
# Model 5 (NEW): MOMENT-1-large as ANOMALY DETECTOR
|
| 244 |
+
# (MOMENT forecasting needs training — anomaly/reconstruction is zero-shot)
|
| 245 |
+
# ============================================================
|
| 246 |
+
# Lazily populated MOMENT pipeline and the lock guarding its creation.
_moment_cache = dict(model=None, lock=threading.Lock())


def _get_moment():
    """Return the shared MOMENT-1-large reconstruction pipeline.

    On first use the pipeline is created with ``task_name="reconstruction"``,
    initialized via ``init()`` and switched to eval mode before being cached.
    """
    cache = _moment_cache
    with cache["lock"]:
        if cache["model"] is None:
            from momentfm import MOMENTPipeline

            moment = MOMENTPipeline.from_pretrained(
                "AutonLab/MOMENT-1-large",
                model_kwargs={"task_name": "reconstruction"},
            )
            moment.init()
            moment.eval()
            cache["model"] = moment
    return cache["model"]
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def anomaly_moment(symbol: str, lookback: int = 512) -> dict:
    """Detect anomalies in recent price action via MOMENT reconstruction error.

    The last up-to-512 closes are z-normalized, reconstructed by the MOMENT
    pipeline, and the mean squared reconstruction error of the most recent
    30 bars is compared with the error over the earlier bars.

    Args:
        symbol: Ticker symbol forwarded to ``_load_ohlc``.
        lookback: Requested history length; at least 512 bars are requested.

    Returns:
        On success a dict with ``status == "ok"``, ``recent_err``,
        ``baseline_err``, their ratio ``err_ratio``, a ``regime`` label
        (``"anomaly"`` above 2.5, ``"elevated"`` above 1.5, else
        ``"normal"``), ``peak_anomaly_days_ago`` and ``n_context`` (number of
        real observations used). On failure a ``status == "error"`` dict.
    """
    try:
        seq_len = 512        # fixed context length fed to the model
        recent_window = 30   # bars treated as "recent"

        df = _load_ohlc(symbol, max(int(lookback), seq_len))
        closes = df["Close"].values.astype(np.float32)[-seq_len:]
        n_obs = int(len(closes))
        if n_obs == 0:
            raise RuntimeError(f"No data for {symbol}")
        pad = seq_len - n_obs

        # Normalize using real observations only, so padding cannot distort
        # the mean/std when fewer than 512 bars exist.
        mean_ = closes.mean()
        std_ = closes.std() or 1.0
        norm = np.zeros(seq_len, dtype=np.float32)
        norm[pad:] = (closes - mean_) / std_

        # Left padding is flagged as unobserved in the mask.
        # NOTE(review): assumes MOMENT's input_mask uses 1 = observed,
        # 0 = padding — confirm against the momentfm documentation.
        observed = np.zeros(seq_len, dtype=np.int64)
        observed[pad:] = 1

        model = _get_moment()
        # Shapes: x is (1, 1, 512), mask is (1, 512).
        x = torch.tensor(norm, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        mask = torch.tensor(observed, dtype=torch.long).unsqueeze(0)
        with torch.no_grad():
            output = model(x_enc=x, input_mask=mask)
        recon = output.reconstruction[0, 0].cpu().numpy()

        # Squared error per timestep, restricted to real observations.
        err = ((norm - recon) ** 2)[pad:]
        recent = err[-recent_window:]
        recent_err = float(recent.mean())
        baseline_err = float(err[:-recent_window].mean()) if len(err) > recent_window else recent_err
        ratio = recent_err / max(baseline_err, 1e-6)

        # Regime label from the recent/baseline error ratio.
        if ratio > 2.5:
            regime = "anomaly"
        elif ratio > 1.5:
            regime = "elevated"
        else:
            regime = "normal"

        # Position of the largest error inside the recent window; with this
        # formula the latest bar reports 1 and the oldest reports 30.
        peak_idx = int(np.argmax(recent))
        return {"status": "ok", "symbol": symbol, "model": "AutonLab/MOMENT-1-large",
                "recent_err": round(recent_err, 4),
                "baseline_err": round(baseline_err, 4),
                "err_ratio": round(ratio, 3),
                "regime": regime,
                "peak_anomaly_days_ago": int(len(recent)) - peak_idx,
                "n_context": n_obs}
    except Exception as e:
        # Endpoint boundary: report the failure as data instead of raising.
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
|
| 308 |
|
| 309 |
|
| 310 |
+
# ============================================================
|
| 311 |
+
# Model 6: FinBERT sentiment (news via yfinance)
|
| 312 |
+
# ============================================================
|
| 313 |
+
_finbert_cache = {"pipe": None, "lock": threading.Lock()}
|
|
|
|
| 314 |
|
| 315 |
|
| 316 |
def _get_finbert():
    """Return the shared FinBERT text-classification pipeline, building it on first use.

    The pipeline is stored in ``_finbert_cache["pipe"]``; the whole check-and-build
    sequence runs under ``_finbert_cache["lock"]`` so only one caller constructs it.
    """
    with _finbert_cache["lock"]:
        cached = _finbert_cache["pipe"]
        if cached is not None:
            return cached

        # Imported lazily so transformers is only loaded when sentiment is requested.
        from transformers import pipeline

        built = pipeline(
            "text-classification",
            model="peejm/finbert-financial-sentiment",
            device=-1,
            top_k=None,
        )
        _finbert_cache["pipe"] = built
        return built
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _score_texts_finbert(texts: List[str]) -> Dict[str, Any]:
|
| 327 |
+
"""Run FinBERT over a list of texts, return aggregate sentiment metrics."""
|
| 328 |
+
if not texts:
|
| 329 |
+
return {"n": 0, "sentiment_net": 0.0, "direction": 0, "pos": 0, "neg": 0, "neu": 0}
|
| 330 |
+
pipe = _get_finbert()
|
| 331 |
+
texts = [t[:512] for t in texts if t and t.strip()]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
if not texts:
|
| 333 |
+
return {"n": 0, "sentiment_net": 0.0, "direction": 0, "pos": 0, "neg": 0, "neu": 0}
|
| 334 |
+
results = pipe(texts, batch_size=8, truncation=True)
|
| 335 |
+
pos = neg = neu = 0
|
| 336 |
+
net = 0.0
|
| 337 |
+
for r in results:
|
| 338 |
+
# Result is list of {label, score} — take top
|
| 339 |
+
top = r[0] if isinstance(r, list) else r
|
| 340 |
+
label = str(top["label"]).lower()
|
| 341 |
+
score = float(top["score"])
|
| 342 |
+
if "pos" in label:
|
| 343 |
+
pos += 1
|
| 344 |
+
net += score
|
| 345 |
+
elif "neg" in label:
|
| 346 |
+
neg += 1
|
| 347 |
+
net -= score
|
| 348 |
+
else:
|
| 349 |
+
neu += 1
|
| 350 |
+
n = len(results)
|
| 351 |
+
mean_net = net / n if n > 0 else 0.0
|
| 352 |
+
direction = 1 if mean_net > 0.15 else (-1 if mean_net < -0.15 else 0)
|
| 353 |
+
return {"n": n, "sentiment_net": round(mean_net, 4), "direction": direction,
|
| 354 |
+
"pos": pos, "neg": neg, "neu": neu}
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def score_sentiment(text: str) -> dict:
    """Score a single piece of text with FinBERT.

    Returns the aggregate metrics from ``_score_texts_finbert`` with
    ``"status": "ok"`` added, or ``{"status": "error", "error": <message>}``
    if scoring raises.
    """
    try:
        payload = {"status": "ok"}
        payload.update(_score_texts_finbert([text]))
    except Exception as exc:
        return {"status": "error", "error": str(exc)}
    return payload
|
| 364 |
|
| 365 |
|
| 366 |
+
def score_sentiment_for_symbol(symbol: str, max_articles: int = 20) -> dict:
    """Fetch yfinance news for a ticker and score it with FinBERT.

    Args:
        symbol: Ticker symbol handed to ``yfinance.Ticker``.
        max_articles: Maximum number of news items to score. Coerced to a
            non-negative int because UI/API callers may deliver a float, which
            cannot be used as a slice bound. ``None`` means "no limit".

    Returns:
        On success a dict with ``status="ok"``, ``symbol``, ``source``,
        ``n_articles``, ``sentiment_net``, ``direction``, ``pos``, ``neg``, ``neu``.
        On failure a dict with ``status="error"``, ``symbol``, ``error`` and, for
        unexpected failures, the last 800 characters of the traceback.
    """
    try:
        import yfinance as yf

        limit = None if max_articles is None else max(int(max_articles), 0)

        try:
            from curl_cffi import requests as cffi_requests
            session = cffi_requests.Session(impersonate="chrome")
        except Exception:
            # curl_cffi is optional here: fall back to yfinance's own session.
            session = None
        t = yf.Ticker(symbol, session=session) if session else yf.Ticker(symbol)

        try:
            news = t.news or []
        except Exception as e:
            return {"status": "error", "symbol": symbol,
                    "error": f"yfinance news fetch failed: {e}"}

        titles = []
        for item in news[:limit]:
            if not isinstance(item, dict):
                continue
            # yfinance news can have content nested under "content" key
            content = item.get("content")
            if isinstance(content, dict):
                title = content.get("title") or ""
                desc = content.get("description") or ""
            else:
                # "or" (not a .get default) so an explicit null never becomes "None".
                title = item.get("title") or ""
                desc = item.get("summary") or ""
            txt = f"{title}. {desc}".strip().strip(".")
            if txt:
                titles.append(txt)

        res = _score_texts_finbert(titles)
        return {"status": "ok", "symbol": symbol, "source": "yfinance_news",
                "n_articles": res["n"], "sentiment_net": res["sentiment_net"],
                "direction": res["direction"],
                "pos": res["pos"], "neg": res["neg"], "neu": res["neu"]}
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e),
                "traceback": traceback.format_exc()[-800:]}
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# ============================================================
|
| 404 |
+
# NEW: GDELT news sentiment (global macro/event stream)
|
| 405 |
+
# ============================================================
|
| 406 |
+
|
| 407 |
+
def news_gdelt_for_symbol(symbol: str, company_name: str = "", days: int = 3,
                          max_articles: int = 30) -> dict:
    """Fetch GDELT articles matching symbol/company and score headline sentiment.

    Free, no key, 15-min refresh, 100+ languages (filtered to English).

    Args:
        symbol: Ticker symbol; used as the search keyword when no company name
            is given and echoed back in the result.
        company_name: Optional company name; preferred as keyword when non-blank.
        days: Look-back window in days (values below 1 are treated as 1;
            7 is sent as one week).
        max_articles: Number of records to request, clamped to 1..250, the
            range the gdeltdoc client accepts.

    Returns:
        A dict with ``status``, ``symbol``, ``source``, ``n_articles``,
        ``sentiment_net``, ``direction``, ``pos``, ``neg``, ``neu``,
        ``top_domains`` (domain -> article count, always a dict),
        ``keyword_used`` and ``timespan``. On failure: ``status="error"`` with
        ``error`` and a traceback tail.
    """
    try:
        from gdeltdoc import GdeltDoc, Filters

        # If company_name given, use it; else just symbol
        keyword = (company_name or "").strip() or symbol
        n_days = max(int(days), 1)
        timespan = "1w" if n_days == 7 else f"{n_days}d"
        n_records = min(max(int(max_articles), 1), 250)

        f = Filters(keyword=keyword, language="eng",
                    timespan=timespan, num_records=n_records)
        articles = GdeltDoc().article_search(f)

        if articles is None or len(articles) == 0:
            # Same schema as the populated result so consumers need no special case.
            return {"status": "ok", "symbol": symbol, "source": "gdelt",
                    "n_articles": 0, "sentiment_net": 0.0, "direction": 0,
                    "pos": 0, "neg": 0, "neu": 0, "top_domains": {},
                    "keyword_used": keyword, "timespan": timespan}

        # Deduplicate on the lower-cased first 120 characters, keeping first occurrence.
        unique_titles = {}
        for title in articles["title"].tolist():
            if isinstance(title, str) and title:
                unique_titles.setdefault(title[:120].lower(), title)
        res = _score_texts_finbert(list(unique_titles.values()))

        # Top source domains, converted to plain str/int so the dict is JSON-safe.
        if "domain" in articles.columns:
            counts = articles["domain"].value_counts().head(5)
            top_domains = {str(domain): int(count) for domain, count in counts.items()}
        else:
            top_domains = {}

        return {"status": "ok", "symbol": symbol, "source": "gdelt",
                "n_articles": res["n"], "sentiment_net": res["sentiment_net"],
                "direction": res["direction"],
                "pos": res["pos"], "neg": res["neg"], "neu": res["neu"],
                "top_domains": top_domains,
                "keyword_used": keyword, "timespan": timespan}
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e),
                "traceback": traceback.format_exc()[-800:]}
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
# ============================================================
|
| 452 |
+
# NEW: Reddit retail sentiment (WSB + ISB + stocks + investing)
|
| 453 |
+
# ============================================================
|
| 454 |
+
|
| 455 |
+
_DEFAULT_SUBS = ["wallstreetbets", "stocks", "investing", "IndianStreetBets",
|
| 456 |
+
"DalalStreetTalks", "IndiaInvestments"]
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def _fetch_reddit_posts(sub: str, query: str, time_filter: str = "week",
                        limit: int = 25) -> list:
    """Fetch posts from Reddit public JSON API — no auth needed.

    Best-effort: a non-200 response, a network error, unparsable JSON or an
    unexpected payload shape all yield an empty list instead of raising.

    Args:
        sub: Subreddit name (without the ``r/`` prefix).
        query: Search query, restricted to ``sub``.
        time_filter: Value sent as Reddit's ``t`` parameter.
        limit: Maximum posts requested; coerced to an int in 1..100.

    Returns:
        A list of dicts with keys ``title``, ``selftext`` (first 1000 chars),
        ``score``, ``num_comments``, ``sub`` and ``url``.
    """
    import requests

    # A float such as 20.0 would otherwise be serialized as "20.0".
    try:
        capped = max(1, min(int(limit), 100))
    except (TypeError, ValueError):
        capped = 25

    url = f"https://www.reddit.com/r/{sub}/search.json"
    params = {"q": query, "restrict_sr": "1", "sort": "top",
              "t": time_filter, "limit": str(capped)}
    headers = {"User-Agent": "InvestmentOS/1.0 (ensemble analysis)"}

    try:
        r = requests.get(url, params=params, headers=headers, timeout=15)
        if r.status_code != 200:
            return []
        data = r.json()
    except (requests.RequestException, ValueError):
        # Network failure or a body that is not valid JSON.
        return []

    listing = data.get("data") if isinstance(data, dict) else None
    children = listing.get("children") if isinstance(listing, dict) else None

    posts = []
    for child in children or []:
        d = child.get("data") if isinstance(child, dict) else None
        if not isinstance(d, dict):
            continue
        # "or" fallbacks: a JSON null must not break slicing or later sums.
        posts.append({
            "title": d.get("title") or "",
            "selftext": (d.get("selftext") or "")[:1000],
            "score": d.get("score") or 0,
            "num_comments": d.get("num_comments") or 0,
            "sub": sub,
            "url": f"https://www.reddit.com{d.get('permalink') or ''}",
        })
    return posts
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def reddit_sentiment_for_symbol(symbol: str, subs_csv: str = "",
                                max_posts_per_sub: int = 20,
                                time_filter: str = "week") -> dict:
    """Search multiple subreddits for symbol mentions and score sentiment.

    Args:
        symbol: Ticker symbol; searched both as ``"SYM"`` and ``"$SYM"``.
        subs_csv: Comma-separated subreddit names. Blank uses ``_DEFAULT_SUBS``.
            Duplicate names are searched once.
        max_posts_per_sub: Posts requested per subreddit (coerced to int).
        time_filter: Reddit time window, e.g. ``"week"``; trimmed and lower-cased,
            blank falls back to ``"week"``.

    Returns:
        A dict with ``status``, ``symbol``, ``source``, ``n_mentions``,
        ``sentiment_net``, ``direction``, ``pos``, ``neg``, ``neu``, ``by_sub``
        (posts fetched per subreddit), ``total_upvotes``, ``total_comments``,
        ``subs_searched``, ``query`` and ``time_filter``. The same keys are
        returned when nothing was found. On failure: ``status="error"`` with
        ``error`` and a traceback tail.
    """
    try:
        import concurrent.futures

        requested = [s.strip() for s in (subs_csv or "").split(",") if s.strip()]
        # dict.fromkeys drops repeats while keeping order, so a subreddit listed
        # twice is not fetched (and scored) twice.
        subs = list(dict.fromkeys(requested)) or list(_DEFAULT_SUBS)
        per_sub = int(max_posts_per_sub)
        time_filter = (time_filter or "").strip().lower() or "week"

        # Query: symbol with optional $ prefix to catch ticker mentions
        query = f'"{symbol}" OR "${symbol}"'

        all_posts = []
        by_sub_count = {}
        with concurrent.futures.ThreadPoolExecutor(max_workers=min(6, len(subs))) as ex:
            futs = {ex.submit(_fetch_reddit_posts, s, query, time_filter, per_sub): s
                    for s in subs}
            for fut in concurrent.futures.as_completed(futs):
                sub = futs[fut]
                try:
                    posts = fut.result()
                except Exception:
                    # One failing subreddit must not sink the whole request.
                    posts = []
                by_sub_count[sub] = len(posts)
                all_posts.extend(posts)

        # Build texts: title alone, or title plus the first 400 chars of the body
        # when the post has one; each text is capped at 512 characters.
        texts = []
        for p in all_posts:
            txt = p["title"]
            if p["selftext"]:
                txt = f"{p['title']}. {p['selftext'][:400]}"
            if txt.strip():
                texts.append(txt[:512])

        # Attention metrics: raw upvote and comment totals over all fetched posts.
        total_score = sum((p["score"] or 0) for p in all_posts)
        total_comments = sum((p["num_comments"] or 0) for p in all_posts)

        if texts:
            res = _score_texts_finbert(texts)
        else:
            res = {"n": 0, "sentiment_net": 0.0, "direction": 0,
                   "pos": 0, "neg": 0, "neu": 0}

        return {"status": "ok", "symbol": symbol, "source": "reddit",
                "n_mentions": res["n"],
                "sentiment_net": res["sentiment_net"],
                "direction": res["direction"],
                "pos": res["pos"], "neg": res["neg"], "neu": res["neu"],
                "by_sub": by_sub_count,
                "total_upvotes": int(total_score),
                "total_comments": int(total_comments),
                "subs_searched": subs,
                "query": query, "time_filter": time_filter}
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e),
                "traceback": traceback.format_exc()[-800:]}
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
# ============================================================
|
| 549 |
+
# Gradio Blocks with MCP exposure
|
| 550 |
+
# ============================================================
|
| 551 |
+
|
| 552 |
+
# One tab per model/data source. Every button is wired with an explicit
# api_name so the endpoint name matches the Python function it calls.
# Integer inputs use precision=0 so Gradio passes ints (not floats) to the callbacks.
with gr.Blocks(title="Investment OS Multi-Model Space") as demo:
    gr.Markdown("# Investment OS: Kronos + Chronos + TimesFM + TiRex + MOMENT + FinBERT + GDELT + Reddit")

    with gr.Tab("Kronos (OHLCV TSFM)"):
        sym = gr.Textbox(label="Symbol", value="AAPL")
        lb = gr.Number(label="Lookback", value=180, precision=0)
        pd_ = gr.Number(label="Pred days", value=30, precision=0)
        out = gr.JSON(label="Forecast")
        gr.Button("Forecast").click(forecast, [sym, lb, pd_], out, api_name="forecast")

    with gr.Tab("Chronos (generic TSFM)"):
        s2 = gr.Textbox(label="Symbol", value="AAPL")
        l2 = gr.Number(label="Lookback", value=180, precision=0)
        p2 = gr.Number(label="Pred days", value=30, precision=0)
        o2 = gr.JSON(label="Forecast")
        gr.Button("Forecast").click(forecast_chronos, [s2, l2, p2], o2, api_name="forecast_chronos")

    with gr.Tab("TimesFM 2.5 (transformers)"):
        s3 = gr.Textbox(label="Symbol", value="AAPL")
        l3 = gr.Number(label="Lookback", value=180, precision=0)
        p3 = gr.Number(label="Pred days", value=30, precision=0)
        o3 = gr.JSON(label="Forecast")
        gr.Button("Forecast").click(forecast_timesfm, [s3, l3, p3], o3, api_name="forecast_timesfm")

    with gr.Tab("TiRex (xLSTM TSFM) NEW"):
        s4 = gr.Textbox(label="Symbol", value="AAPL")
        l4 = gr.Number(label="Lookback", value=180, precision=0)
        p4 = gr.Number(label="Pred days", value=30, precision=0)
        o4 = gr.JSON(label="Forecast")
        gr.Button("Forecast").click(forecast_tirex, [s4, l4, p4], o4, api_name="forecast_tirex")

    with gr.Tab("MOMENT Anomaly NEW"):
        s5 = gr.Textbox(label="Symbol", value="AAPL")
        l5 = gr.Number(label="Lookback (min 512)", value=512, precision=0)
        o5 = gr.JSON(label="Anomaly analysis")
        gr.Button("Detect").click(anomaly_moment, [s5, l5], o5, api_name="anomaly_moment")

    with gr.Tab("FinBERT text"):
        t6 = gr.Textbox(label="Text", value="The company reported record earnings.")
        o6 = gr.JSON(label="Sentiment")
        gr.Button("Score").click(score_sentiment, t6, o6, api_name="score_sentiment")

    with gr.Tab("FinBERT yfinance news"):
        s7 = gr.Textbox(label="Symbol", value="AAPL")
        m7 = gr.Number(label="Max articles", value=20, precision=0)
        o7 = gr.JSON(label="Sentiment")
        gr.Button("Score").click(score_sentiment_for_symbol, [s7, m7], o7, api_name="score_sentiment_for_symbol")

    with gr.Tab("GDELT news NEW"):
        s8 = gr.Textbox(label="Symbol", value="AAPL")
        c8 = gr.Textbox(label="Company name (optional)", value="Apple")
        d8 = gr.Number(label="Days", value=3, precision=0)
        m8 = gr.Number(label="Max articles", value=30, precision=0)
        o8 = gr.JSON(label="GDELT sentiment")
        gr.Button("Fetch").click(news_gdelt_for_symbol, [s8, c8, d8, m8], o8, api_name="news_gdelt_for_symbol")

    with gr.Tab("Reddit sentiment NEW"):
        s9 = gr.Textbox(label="Symbol", value="AAPL")
        # Pre-filled from the module default list so the two cannot drift apart.
        sub9 = gr.Textbox(label="Subs CSV (blank = defaults)",
                          value=",".join(_DEFAULT_SUBS))
        m9 = gr.Number(label="Max posts per sub", value=20, precision=0)
        t9 = gr.Textbox(label="Time filter", value="week")
        o9 = gr.JSON(label="Reddit sentiment")
        gr.Button("Fetch").click(reddit_sentiment_for_symbol, [s9, sub9, m9, t9], o9,
                                 api_name="reddit_sentiment_for_symbol")
|
| 617 |
|
| 618 |
|
| 619 |
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 and start the MCP server with the UI.
    launch_options = {
        "mcp_server": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.launch(**launch_options)
|