Spaces:
Running
Running
| """Kronos + Chronos + TimesFM + TiRex + MOMENT + FinBERT + GDELT + Reddit — Investment OS Space.""" | |
| from __future__ import annotations | |
| import os, sys, time, json, traceback, threading, warnings | |
| from typing import List, Optional, Tuple, Dict, Any | |
| warnings.filterwarnings("ignore") | |
| os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1") | |
| os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") | |
| # Disable TiRex custom CUDA kernels (we're on CPU) | |
| os.environ.setdefault("TIREX_NO_CUDA", "1") | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import gradio as gr | |
| # ============================================================ | |
| # Shared yfinance OHLC loader (new session each call to avoid race) | |
| # ============================================================ | |
| def _load_ohlc(symbol: str, lookback: int = 180) -> pd.DataFrame: | |
| import yfinance as yf | |
| try: | |
| from curl_cffi import requests as cffi_requests | |
| session = cffi_requests.Session(impersonate="chrome") | |
| except Exception: | |
| session = None | |
| end = pd.Timestamp.utcnow().tz_localize(None) | |
| start = end - pd.Timedelta(days=int(lookback * 2.2)) # account for weekends/holidays | |
| kwargs = dict(start=start.strftime("%Y-%m-%d"), end=(end + pd.Timedelta(days=1)).strftime("%Y-%m-%d"), | |
| interval="1d", progress=False, auto_adjust=False, actions=False, threads=False) | |
| if session is not None: | |
| kwargs["session"] = session | |
| df = yf.download(symbol, **kwargs) | |
| if df is None or len(df) == 0: | |
| raise RuntimeError(f"No data for {symbol}") | |
| if isinstance(df.columns, pd.MultiIndex): | |
| df.columns = df.columns.get_level_values(0) | |
| df = df.dropna().tail(lookback).reset_index() | |
| need = {"Open", "High", "Low", "Close", "Volume"} | |
| if not need.issubset(set(df.columns)): | |
| raise RuntimeError(f"Missing columns for {symbol}: got {list(df.columns)}") | |
| return df | |
| # ============================================================ | |
| # Model 1: Kronos (finance-native OHLCV foundation model) | |
| # ============================================================ | |
# Lazily filled singleton slots for the Kronos stack, guarded by "lock".
_kronos_cache = {
    "model": None,
    "tok": None,
    "pred": None,
    "lock": threading.Lock(),
}


def _get_kronos():
    """Return the shared KronosPredictor, building it on first use.

    The tokenizer, model and predictor are created once under the cache lock
    and stored in ``_kronos_cache``; later calls reuse the stored predictor.
    """
    with _kronos_cache["lock"]:
        if _kronos_cache["pred"] is not None:
            return _kronos_cache["pred"]

        # Imported lazily so the module loads even before Kronos is needed.
        from model import Kronos, KronosTokenizer, KronosPredictor

        tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
        network = Kronos.from_pretrained("NeoQuasar/Kronos-small")
        _kronos_cache["tok"] = tokenizer
        _kronos_cache["model"] = network
        _kronos_cache["pred"] = KronosPredictor(
            model=network, tokenizer=tokenizer, device="cpu", max_context=512
        )
        return _kronos_cache["pred"]
def forecast(symbol: str, lookback: int = 180, pred_days: int = 30) -> dict:
    """Forecast daily closes for ``symbol`` with Kronos-small.

    Args:
        symbol: Ticker understood by yfinance.
        lookback: Number of historical daily bars used as context.
        pred_days: Number of business days to predict (must be >= 1).

    Returns:
        On success a dict with ``status == "ok"``, the last observed close,
        the predicted close on the final horizon day, the percentage change
        between them, ``direction`` (1 if the change is above +2 %, -1 if
        below -2 %, else 0) and mean/min/max of the predicted closes.
        On any failure a dict with ``status == "error"``, the error message
        and the tail of the traceback; this function never raises.
    """
    try:
        # UI / API clients may send floats; pd.date_range(periods=...) and the
        # predictor need real integers (the sibling forecasters coerce too).
        lookback = int(lookback)
        pred_days = int(pred_days)
        if pred_days < 1:
            raise ValueError(f"pred_days must be >= 1, got {pred_days}")

        df = _load_ohlc(symbol, lookback)
        pred = _get_kronos()
        x_df = df[["Open", "High", "Low", "Close", "Volume"]].copy()
        x_ts = pd.to_datetime(df["Date"])
        last = x_ts.iloc[-1]
        # Future timestamps: the next `pred_days` business days after the last bar.
        # Wrapped in a Series so it has the same type as x_ts.
        # NOTE(review): upstream Kronos appears to read `.dt` on both timestamp
        # arguments, which a bare DatetimeIndex lacks — confirm against the
        # bundled `model` package.
        y_ts = pd.Series(
            pd.date_range(start=last + pd.Timedelta(days=1), periods=pred_days, freq="B")
        )
        out = pred.predict(df=x_df, x_timestamp=x_ts, y_timestamp=y_ts,
                           pred_len=pred_days, T=1.0, top_p=0.9,
                           sample_count=1, verbose=False)

        last_close = float(x_df["Close"].iloc[-1])
        pred_close = float(out["close"].iloc[-1])
        mean_close = float(out["close"].mean())
        min_close = float(out["close"].min())
        max_close = float(out["close"].max())
        pct = (pred_close - last_close) / last_close * 100
        # Moves inside the +/-2 % band are reported as neutral.
        direction = 1 if pct > 2 else (-1 if pct < -2 else 0)
        return {"status": "ok", "symbol": symbol, "model": "NeoQuasar/Kronos-small",
                "last_close": round(last_close, 4), "predicted_close": round(pred_close, 4),
                "pct_change": round(pct, 3), "direction": direction,
                "n_lookback": int(len(x_df)), "pred_days": pred_days,
                "pred_mean_close": round(mean_close, 4), "pred_min_close": round(min_close, 4),
                "pred_max_close": round(max_close, 4)}
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
| # ============================================================ | |
| # Model 2: Chronos-bolt-tiny (generic TSFM) | |
| # ============================================================ | |
# Lazily filled singleton slot for the Chronos pipeline, guarded by "lock".
_chronos_cache = {
    "pipe": None,
    "lock": threading.Lock(),
}


def _get_chronos():
    """Return the shared Chronos-bolt-tiny pipeline, loading it on first use."""
    with _chronos_cache["lock"]:
        pipeline_ = _chronos_cache["pipe"]
        if pipeline_ is None:
            # Imported lazily so the dependency is only needed when used.
            from chronos import BaseChronosPipeline

            pipeline_ = BaseChronosPipeline.from_pretrained(
                "amazon/chronos-bolt-tiny",
                device_map="cpu",
                torch_dtype=torch.float32,
            )
            _chronos_cache["pipe"] = pipeline_
    return pipeline_
def forecast_chronos(symbol: str, lookback: int = 180, pred_days: int = 30) -> dict:
    """Forecast daily closes for ``symbol`` with amazon/chronos-bolt-tiny.

    Uses only the close series as context and requests the 0.1 / 0.5 / 0.9
    quantiles plus the mean path. Returns an ``"ok"`` dict with the last
    close, the final mean prediction, percentage change, a +/-2 % direction
    flag, the mean of the mean path, the minimum of the 0.1 quantile path
    and the maximum of the 0.9 quantile path; on failure returns an
    ``"error"`` dict with message and traceback tail instead of raising.
    """
    try:
        history = _load_ohlc(symbol, lookback)["Close"].values.astype(np.float32)
        pipeline_ = _get_chronos()
        quantiles, mean = pipeline_.predict_quantiles(
            context=torch.tensor(history, dtype=torch.float32),
            prediction_length=int(pred_days),
            quantile_levels=[0.1, 0.5, 0.9],
        )
        # Index 0 selects the single series in the batch.
        mean_path = mean[0].numpy()
        lower_band = quantiles[0, :, 0].numpy()
        upper_band = quantiles[0, :, 2].numpy()

        anchor = float(history[-1])
        terminal = float(mean_path[-1])
        pct = (terminal - anchor) / anchor * 100
        if pct > 2:
            direction = 1
        elif pct < -2:
            direction = -1
        else:
            direction = 0

        return {
            "status": "ok",
            "symbol": symbol,
            "model": "amazon/chronos-bolt-tiny",
            "last_close": round(anchor, 4),
            "predicted_close": round(terminal, 4),
            "pct_change": round(pct, 3),
            "direction": direction,
            "n_lookback": int(len(history)),
            "pred_days": int(pred_days),
            "pred_mean_close": round(float(mean_path.mean()), 4),
            "pred_low_close": round(float(lower_band.min()), 4),
            "pred_high_close": round(float(upper_band.max()), 4),
        }
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
| # ============================================================ | |
| # Model 3: TimesFM 2.5 via transformers (UPGRADED from 2.0) | |
| # ============================================================ | |
# Lazily filled singleton slot for TimesFM; a "version" key is added on load.
_timesfm_cache = {
    "model": None,
    "lock": threading.Lock(),
}


def _get_timesfm():
    """Return ``(model, version)`` for TimesFM, loading the model on first use.

    Tries the 2.5 checkpoint first; if anything in that path raises, falls
    back to the 2.0 checkpoint. ``version`` is the string ``"2.5"`` or
    ``"2.0"`` matching whichever load succeeded.
    """

    def _remember(network, version):
        # Cast to float32 and switch to eval mode before caching.
        _timesfm_cache["model"] = network.to(torch.float32).eval()
        _timesfm_cache["version"] = version

    with _timesfm_cache["lock"]:
        if _timesfm_cache["model"] is None:
            try:
                from transformers import TimesFm2_5ModelForPrediction

                _remember(
                    TimesFm2_5ModelForPrediction.from_pretrained(
                        "google/timesfm-2.5-200m-transformers"),
                    "2.5",
                )
            except Exception:
                # Fallback to 2.0 if 2.5 is unavailable in this transformers version.
                from transformers import TimesFmModelForPrediction

                _remember(
                    TimesFmModelForPrediction.from_pretrained(
                        "google/timesfm-2.0-500m-pytorch"),
                    "2.0",
                )
    return _timesfm_cache["model"], _timesfm_cache["version"]
def forecast_timesfm(symbol: str, lookback: int = 180, pred_days: int = 30) -> dict:
    """Forecast daily closes for ``symbol`` with TimesFM (2.5, or 2.0 fallback).

    The model's mean prediction is truncated to at most ``pred_days`` steps;
    the reported ``pred_days`` is the horizon actually used. Returns an
    ``"ok"`` dict with last close, final predicted close, percentage change,
    a +/-2 % direction flag and mean/min/max of the predicted path; on
    failure returns an ``"error"`` dict instead of raising.
    """
    try:
        history = _load_ohlc(symbol, lookback)["Close"].values.astype(np.float32)
        model, ver = _get_timesfm()
        past = [torch.tensor(history, dtype=torch.float32)]

        with torch.no_grad():
            if ver == "2.5":
                outputs = model(past_values=past, forecast_context_len=1024)
            else:
                # v2.0 transformers API takes an explicit frequency tensor.
                outputs = model(
                    past_values=past,
                    freq=torch.tensor([0], dtype=torch.long),
                    return_dict=True,
                )
            full_path = outputs.mean_predictions[0].float().cpu().numpy()

        # Keep no more steps than were requested (or than the model produced).
        horizon = min(int(pred_days), len(full_path))
        path = full_path[:horizon]

        anchor = float(history[-1])
        terminal = float(path[-1])
        pct = (terminal - anchor) / anchor * 100
        if pct > 2:
            direction = 1
        elif pct < -2:
            direction = -1
        else:
            direction = 0

        return {
            "status": "ok",
            "symbol": symbol,
            "model": f"google/timesfm-{ver}",
            "last_close": round(anchor, 4),
            "predicted_close": round(terminal, 4),
            "pct_change": round(pct, 3),
            "direction": direction,
            "n_lookback": int(len(history)),
            "pred_days": horizon,
            "pred_mean_close": round(float(path.mean()), 4),
            "pred_min_close": round(float(path.min()), 4),
            "pred_max_close": round(float(path.max()), 4),
        }
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
| # ============================================================ | |
| # Model 4 (NEW): TiRex (35M xLSTM TSFM, CPU experimental) | |
| # ============================================================ | |
# Lazily filled singleton slot for the TiRex model, guarded by "lock".
_tirex_cache = {
    "model": None,
    "lock": threading.Lock(),
}


def _get_tirex():
    """Return the shared TiRex model, loading ``NX-AI/TiRex`` on first use."""
    with _tirex_cache["lock"]:
        network = _tirex_cache["model"]
        if network is None:
            # Imported lazily so the dependency is only needed when used.
            from tirex import load_model

            network = load_model("NX-AI/TiRex")
            _tirex_cache["model"] = network
    return network
def forecast_tirex(symbol: str, lookback: int = 180, pred_days: int = 30) -> dict:
    """Forecast daily closes for ``symbol`` with TiRex.

    Accepts either a ``(quantiles, mean)`` tuple or a bare mean prediction
    from ``model.forecast``. If the mean path contains NaN an ``"error"``
    dict is returned. Otherwise returns an ``"ok"`` dict with last close,
    final predicted close, percentage change, a +/-2 % direction flag and
    mean/min/max of the predicted path. Never raises; other failures are
    reported as an ``"error"`` dict with the traceback tail.
    """
    try:
        history = _load_ohlc(symbol, lookback)["Close"].values.astype(np.float32)
        model = _get_tirex()
        # Add a leading batch axis: (seq_len,) -> (1, seq_len).
        batch = torch.tensor(history, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            result = model.forecast(context=batch, prediction_length=int(pred_days))

        # A 2-tuple is treated as (quantiles, mean); anything else as the mean.
        if isinstance(result, tuple) and len(result) == 2:
            mean_pred = result[1]
        else:
            mean_pred = result

        if hasattr(mean_pred, "cpu"):
            path = mean_pred[0].float().cpu().numpy()
        else:
            path = np.asarray(mean_pred)[0]

        # Check for NaN (TiRex CPU may degrade).
        if np.isnan(path).any():
            return {
                "status": "error",
                "symbol": symbol,
                "error": "TiRex returned NaN (CPU mode is experimental)",
                "model": "NX-AI/TiRex",
            }

        anchor = float(history[-1])
        terminal = float(path[-1])
        pct = (terminal - anchor) / anchor * 100
        if pct > 2:
            direction = 1
        elif pct < -2:
            direction = -1
        else:
            direction = 0

        return {
            "status": "ok",
            "symbol": symbol,
            "model": "NX-AI/TiRex",
            "last_close": round(anchor, 4),
            "predicted_close": round(terminal, 4),
            "pct_change": round(pct, 3),
            "direction": direction,
            "n_lookback": int(len(history)),
            "pred_days": int(pred_days),
            "pred_mean_close": round(float(path.mean()), 4),
            "pred_min_close": round(float(path.min()), 4),
            "pred_max_close": round(float(path.max()), 4),
        }
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
| # ============================================================ | |
| # Model 5 (NEW): MOMENT-1-large as ANOMALY DETECTOR | |
| # (MOMENT forecasting needs training — anomaly/reconstruction is zero-shot) | |
| # ============================================================ | |
# Lazily filled singleton slot for the MOMENT pipeline, guarded by "lock".
_moment_cache = {
    "model": None,
    "lock": threading.Lock(),
}


def _get_moment():
    """Return the shared MOMENT-1-large pipeline in reconstruction mode.

    On first use the pipeline is created with ``task_name="reconstruction"``,
    initialised via ``init()``, switched to eval mode and cached.
    """
    with _moment_cache["lock"]:
        if _moment_cache["model"] is not None:
            return _moment_cache["model"]

        # Imported lazily so the dependency is only needed when used.
        from momentfm import MOMENTPipeline

        pipeline_ = MOMENTPipeline.from_pretrained(
            "AutonLab/MOMENT-1-large",
            model_kwargs={"task_name": "reconstruction"},
        )
        pipeline_.init()
        pipeline_.eval()
        _moment_cache["model"] = pipeline_
        return _moment_cache["model"]
def anomaly_moment(symbol: str, lookback: int = 512) -> dict:
    """Detect anomalies in recent price action via MOMENT reconstruction error.

    The close series is z-normalised, reconstructed by MOMENT and compared
    point-wise; the mean squared error of the last 30 observed bars is
    divided by the mean error of the earlier observed bars.

    Args:
        symbol: Ticker understood by yfinance.
        lookback: Requested history length; at least 512 bars are always
            requested because the model is fed exactly 512 timesteps.

    Returns:
        On success a dict with ``status == "ok"``, ``recent_err``,
        ``baseline_err``, ``err_ratio``, ``regime`` (``"anomaly"`` when the
        ratio exceeds 2.5, ``"elevated"`` above 1.5, else ``"normal"``),
        ``peak_anomaly_days_ago`` (position of the largest recent error,
        counted back from the end of the window, 1 = latest bar),
        ``n_context`` (model input length) and ``n_observed`` (real bars in
        that input). On failure a dict with ``status == "error"``; this
        function never raises.
    """
    try:
        seq_len = 512        # MOMENT is fed exactly this many timesteps
        recent_window = 30   # bars treated as "recent"

        df = _load_ohlc(symbol, max(int(lookback), seq_len))
        closes = df["Close"].values.astype(np.float32)[-seq_len:]
        n_obs = len(closes)
        if n_obs == 0:
            raise RuntimeError(f"No close prices for {symbol}")

        model = _get_moment()

        # Normalise with statistics of the observed bars only, so left padding
        # (added below for short histories) cannot distort mean and std.
        mean_, std_ = closes.mean(), closes.std() or 1.0
        norm = np.zeros(seq_len, dtype=np.float32)
        norm[-n_obs:] = (closes - mean_) / std_

        # Shape (batch, n_channels, seq_len) as expected by the pipeline.
        x = torch.tensor(norm, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        # 1 = observed timestep, 0 = padding, so padded zeros are not treated
        # as real prices.
        mask = torch.zeros((1, seq_len), dtype=torch.long)
        mask[:, -n_obs:] = 1

        with torch.no_grad():
            output = model(x_enc=x, input_mask=mask)
        recon = output.reconstruction[0, 0].cpu().numpy()

        # Squared reconstruction error on observed bars only.
        err = ((norm - recon) ** 2)[-n_obs:]
        recent = err[-recent_window:]
        recent_err = float(recent.mean())
        baseline_err = float(err[:-recent_window].mean()) if len(err) > recent_window else recent_err
        ratio = recent_err / max(baseline_err, 1e-6)

        # Regime label derived from the recent/baseline error ratio.
        if ratio > 2.5:
            regime = "anomaly"
        elif ratio > 1.5:
            regime = "elevated"
        else:
            regime = "normal"

        # Peak error inside the recent window, counted back from its end.
        peak_idx = int(np.argmax(recent))
        return {"status": "ok", "symbol": symbol, "model": "AutonLab/MOMENT-1-large",
                "recent_err": round(recent_err, 4),
                "baseline_err": round(baseline_err, 4),
                "err_ratio": round(ratio, 3),
                "regime": regime,
                "peak_anomaly_days_ago": len(recent) - peak_idx,
                "n_context": seq_len,
                "n_observed": int(n_obs)}
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
| # ============================================================ | |
| # Model 6: FinBERT sentiment (news via yfinance) | |
| # ============================================================ | |
# Lazily filled singleton slot for the FinBERT pipeline, guarded by "lock".
_finbert_cache = {
    "pipe": None,
    "lock": threading.Lock(),
}


def _get_finbert():
    """Return the shared FinBERT text-classification pipeline (CPU).

    Built on first use with ``top_k=None`` and ``device=-1`` and cached in
    ``_finbert_cache``.
    """
    with _finbert_cache["lock"]:
        classifier = _finbert_cache["pipe"]
        if classifier is None:
            # Imported lazily so the dependency is only needed when used.
            from transformers import pipeline

            classifier = pipeline(
                "text-classification",
                model="peejm/finbert-financial-sentiment",
                device=-1,
                top_k=None,
            )
            _finbert_cache["pipe"] = classifier
    return classifier
| def _score_texts_finbert(texts: List[str]) -> Dict[str, Any]: | |
| """Run FinBERT over a list of texts, return aggregate sentiment metrics.""" | |
| if not texts: | |
| return {"n": 0, "sentiment_net": 0.0, "direction": 0, "pos": 0, "neg": 0, "neu": 0} | |
| pipe = _get_finbert() | |
| texts = [t[:512] for t in texts if t and t.strip()] | |
| if not texts: | |
| return {"n": 0, "sentiment_net": 0.0, "direction": 0, "pos": 0, "neg": 0, "neu": 0} | |
| results = pipe(texts, batch_size=8, truncation=True) | |
| pos = neg = neu = 0 | |
| net = 0.0 | |
| for r in results: | |
| # Result is list of {label, score} — take top | |
| top = r[0] if isinstance(r, list) else r | |
| label = str(top["label"]).lower() | |
| score = float(top["score"]) | |
| if "pos" in label: | |
| pos += 1 | |
| net += score | |
| elif "neg" in label: | |
| neg += 1 | |
| net -= score | |
| else: | |
| neu += 1 | |
| n = len(results) | |
| mean_net = net / n if n > 0 else 0.0 | |
| direction = 1 if mean_net > 0.15 else (-1 if mean_net < -0.15 else 0) | |
| return {"n": n, "sentiment_net": round(mean_net, 4), "direction": direction, | |
| "pos": pos, "neg": neg, "neu": neu} | |
def score_sentiment(text: str) -> dict:
    """Score a single piece of text with FinBERT.

    Returns the aggregate metrics from ``_score_texts_finbert`` with
    ``status == "ok"`` prepended, or ``{"status": "error", "error": ...}``
    if scoring raises.
    """
    try:
        payload = {"status": "ok"}
        payload.update(_score_texts_finbert([text]))
        return payload
    except Exception as e:
        return {"status": "error", "error": str(e)}
def score_sentiment_for_symbol(symbol: str, max_articles: int = 20) -> dict:
    """Fetch yfinance news for ``symbol`` and score the headlines with FinBERT.

    Args:
        symbol: Ticker understood by yfinance.
        max_articles: Maximum number of news items to score. Coerced to a
            non-negative ``int`` because UI / API clients may send floats.

    Returns:
        On success a dict with ``status == "ok"``, ``source``,
        ``n_articles`` and the aggregate sentiment fields. If the news fetch
        or anything else fails, a dict with ``status == "error"``; this
        function never raises.
    """
    try:
        # Slicing with a float raises TypeError, so normalise up front.
        max_articles = max(int(max_articles), 0)

        import yfinance as yf
        try:
            from curl_cffi import requests as cffi_requests
            session = cffi_requests.Session(impersonate="chrome")
        except Exception:
            # Best effort: fall back to yfinance's own session handling.
            session = None
        t = yf.Ticker(symbol, session=session) if session else yf.Ticker(symbol)

        try:
            news = t.news or []
        except Exception as e:
            return {"status": "error", "symbol": symbol,
                    "error": f"yfinance news fetch failed: {e}"}

        titles = []
        for item in news[:max_articles]:
            if not isinstance(item, dict):
                continue
            # yfinance news can have its fields nested under a "content" key.
            content = item.get("content")
            if isinstance(content, dict):
                title = content.get("title") or ""
                desc = content.get("description") or ""
            else:
                # `or ""` also covers keys that exist with a None value, which
                # would otherwise be rendered as the text "None".
                title = item.get("title") or ""
                desc = item.get("summary") or ""
            txt = f"{title}. {desc}".strip().strip(".")
            if txt:
                titles.append(txt)

        res = _score_texts_finbert(titles)
        return {"status": "ok", "symbol": symbol, "source": "yfinance_news",
                "n_articles": res["n"], "sentiment_net": res["sentiment_net"],
                "direction": res["direction"],
                "pos": res["pos"], "neg": res["neg"], "neu": res["neu"]}
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
| # ============================================================ | |
| # NEW: GDELT news sentiment (global macro/event stream) | |
| # ============================================================ | |
def news_gdelt_for_symbol(symbol: str, company_name: str = "", days: int = 3,
                          max_articles: int = 30) -> dict:
    """Fetch English GDELT articles for a symbol/company and score their titles.

    Args:
        symbol: Ticker; used as the search keyword when no company name is given.
        company_name: Optional company name; preferred as keyword when non-blank.
        days: Look-back window in days (values below 1 are treated as 1).
        max_articles: Maximum number of records requested from GDELT.

    Returns:
        On success a dict with ``status == "ok"``, ``source``, ``n_articles``,
        the aggregate sentiment fields, ``top_domains`` (domain -> article
        count for up to five domains; an empty dict when unavailable),
        ``keyword_used`` and ``timespan``. On failure a dict with
        ``status == "error"``; this function never raises.
    """
    try:
        from gdeltdoc import GdeltDoc, Filters

        # Prefer the company name; fall back to the ticker. None-safe.
        keyword = (company_name or "").strip() or symbol
        n_days = max(int(days), 1)
        timespan_map = {1: "1d", 2: "2d", 3: "3d", 7: "1w"}
        timespan = timespan_map.get(n_days, f"{n_days}d")

        f = Filters(keyword=keyword, language="eng",
                    timespan=timespan, num_records=int(max_articles))
        articles = GdeltDoc().article_search(f)
        if articles is None or len(articles) == 0:
            # Same field types as the populated result (top_domains is a dict).
            return {"status": "ok", "symbol": symbol, "source": "gdelt",
                    "n_articles": 0, "sentiment_net": 0.0, "direction": 0,
                    "pos": 0, "neg": 0, "neu": 0, "top_domains": {},
                    "keyword_used": keyword, "timespan": timespan}

        titles = [t for t in articles["title"].tolist() if isinstance(t, str) and t]
        # Deduplicate on the lower-cased first 120 characters, keeping order.
        seen = set()
        deduped = []
        for t in titles:
            key = t[:120].lower()
            if key not in seen:
                seen.add(key)
                deduped.append(t)

        res = _score_texts_finbert(deduped)

        # Top source domains by article count.
        if "domain" in articles.columns:
            top_domains = articles["domain"].value_counts().head(5).to_dict()
        else:
            top_domains = {}

        return {"status": "ok", "symbol": symbol, "source": "gdelt",
                "n_articles": res["n"], "sentiment_net": res["sentiment_net"],
                "direction": res["direction"],
                "pos": res["pos"], "neg": res["neg"], "neu": res["neu"],
                "top_domains": top_domains,
                "keyword_used": keyword, "timespan": timespan}
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
| # ============================================================ | |
| # NEW: Reddit retail sentiment (WSB + ISB + stocks + investing) | |
| # ============================================================ | |
| _DEFAULT_SUBS = ["wallstreetbets", "stocks", "investing", "IndianStreetBets", | |
| "DalalStreetTalks", "IndiaInvestments"] | |
def _fetch_reddit_posts(sub: str, query: str, time_filter: str = "week",
                        limit: int = 25) -> list:
    """Search one subreddit via Reddit's public JSON endpoint (no auth).

    Args:
        sub: Subreddit name without the ``r/`` prefix.
        query: Search query string.
        time_filter: Value for Reddit's ``t`` parameter (e.g. ``"week"``).
        limit: Maximum number of posts to request; coerced to ``int`` and
            capped at 100.

    Returns:
        A list of dicts with ``title``, ``selftext`` (first 1000 characters),
        ``score``, ``num_comments``, ``sub`` and ``url``. Best effort: any
        non-200 response or error yields an empty list.
    """
    import requests

    url = f"https://www.reddit.com/r/{sub}/search.json"
    headers = {"User-Agent": "InvestmentOS/1.0 (ensemble analysis)"}
    try:
        # A float such as 20.0 would otherwise be sent as the string "20.0".
        params = {"q": query, "restrict_sr": "1", "sort": "top",
                  "t": time_filter, "limit": str(min(int(limit), 100))}
        r = requests.get(url, params=params, headers=headers, timeout=15)
        if r.status_code != 200:
            return []
        data = r.json()
        posts = []
        for child in data.get("data", {}).get("children", []):
            d = child.get("data", {})
            # `or` fallbacks cover keys present with a null value, so a single
            # odd post cannot discard the whole subreddit's results.
            posts.append({
                "title": d.get("title") or "",
                "selftext": (d.get("selftext") or "")[:1000],
                "score": d.get("score") or 0,
                "num_comments": d.get("num_comments") or 0,
                "sub": sub,
                "url": f"https://www.reddit.com{d.get('permalink', '')}",
            })
        return posts
    except Exception:
        # Deliberate best effort: the caller treats a failed subreddit as empty.
        return []
def reddit_sentiment_for_symbol(symbol: str, subs_csv: str = "",
                                max_posts_per_sub: int = 20,
                                time_filter: str = "week") -> dict:
    """Search several subreddits for ``symbol`` and score the posts with FinBERT.

    ``subs_csv`` is a comma-separated subreddit list; when blank the default
    list is used. Subreddits are queried concurrently. Each post contributes
    its title, plus the first 400 characters of its body when it has one.
    Returns an ``"ok"`` dict with mention count, aggregate sentiment,
    per-subreddit post counts and (when anything was found) total upvotes and
    comments; on failure returns an ``"error"`` dict instead of raising.
    """
    try:
        import concurrent.futures

        requested = [name.strip() for name in (subs_csv or "").split(",") if name.strip()]
        subs = requested if requested else _DEFAULT_SUBS

        # Match the bare ticker as well as the $-prefixed form.
        query = f'"{symbol}" OR "${symbol}"'

        all_posts = []
        by_sub_count = {}
        with concurrent.futures.ThreadPoolExecutor(max_workers=6) as pool:
            pending = {}
            for name in subs:
                job = pool.submit(_fetch_reddit_posts, name, query, time_filter, max_posts_per_sub)
                pending[job] = name
            for job in concurrent.futures.as_completed(pending):
                try:
                    fetched = job.result()
                except Exception:
                    fetched = []
                by_sub_count[pending[job]] = len(fetched)
                all_posts.extend(fetched)

        # One text per post: title alone, or title plus a body excerpt.
        texts = []
        for post in all_posts:
            if post["selftext"]:
                combined = f"{post['title']}. {post['selftext'][:400]}"
            else:
                combined = post["title"]
            if combined.strip():
                texts.append(combined[:512])

        if not texts:
            return {"status": "ok", "symbol": symbol, "source": "reddit",
                    "n_mentions": 0, "sentiment_net": 0.0, "direction": 0,
                    "pos": 0, "neg": 0, "neu": 0, "by_sub": by_sub_count,
                    "subs_searched": subs}

        res = _score_texts_finbert(texts)

        # Attention metrics summed over every fetched post.
        upvotes = sum(post["score"] for post in all_posts)
        comments = sum(post["num_comments"] for post in all_posts)
        return {
            "status": "ok",
            "symbol": symbol,
            "source": "reddit",
            "n_mentions": res["n"],
            "sentiment_net": res["sentiment_net"],
            "direction": res["direction"],
            "pos": res["pos"],
            "neg": res["neg"],
            "neu": res["neu"],
            "by_sub": by_sub_count,
            "total_upvotes": int(upvotes),
            "total_comments": int(comments),
            "subs_searched": subs,
            "query": query,
            "time_filter": time_filter,
        }
    except Exception as e:
        return {"status": "error", "symbol": symbol, "error": str(e), "traceback": traceback.format_exc()[-800:]}
| # ============================================================ | |
| # Gradio Blocks with MCP exposure | |
| # ============================================================ | |
# One tab per endpoint; every button click is also exposed under `api_name`.
# Integer inputs use precision=0 so handlers receive ints rather than floats
# (floats break slicing, DataFrame.tail and pd.date_range downstream).
with gr.Blocks(title="Investment OS Multi-Model Space") as demo:
    gr.Markdown("# Investment OS: Kronos + Chronos + TimesFM + TiRex + MOMENT + FinBERT + GDELT + Reddit")

    with gr.Tab("Kronos (OHLCV TSFM)"):
        sym = gr.Textbox(label="Symbol", value="AAPL")
        lb = gr.Number(label="Lookback", value=180, precision=0)
        pd_ = gr.Number(label="Pred days", value=30, precision=0)
        out = gr.JSON(label="Forecast")
        gr.Button("Forecast").click(forecast, [sym, lb, pd_], out, api_name="forecast")

    with gr.Tab("Chronos (generic TSFM)"):
        s2 = gr.Textbox(label="Symbol", value="AAPL")
        l2 = gr.Number(label="Lookback", value=180, precision=0)
        p2 = gr.Number(label="Pred days", value=30, precision=0)
        o2 = gr.JSON(label="Forecast")
        gr.Button("Forecast").click(forecast_chronos, [s2, l2, p2], o2, api_name="forecast_chronos")

    with gr.Tab("TimesFM 2.5 (transformers)"):
        s3 = gr.Textbox(label="Symbol", value="AAPL")
        l3 = gr.Number(label="Lookback", value=180, precision=0)
        p3 = gr.Number(label="Pred days", value=30, precision=0)
        o3 = gr.JSON(label="Forecast")
        gr.Button("Forecast").click(forecast_timesfm, [s3, l3, p3], o3, api_name="forecast_timesfm")

    with gr.Tab("TiRex (xLSTM TSFM) NEW"):
        s4 = gr.Textbox(label="Symbol", value="AAPL")
        l4 = gr.Number(label="Lookback", value=180, precision=0)
        p4 = gr.Number(label="Pred days", value=30, precision=0)
        o4 = gr.JSON(label="Forecast")
        gr.Button("Forecast").click(forecast_tirex, [s4, l4, p4], o4, api_name="forecast_tirex")

    with gr.Tab("MOMENT Anomaly NEW"):
        s5 = gr.Textbox(label="Symbol", value="AAPL")
        l5 = gr.Number(label="Lookback (min 512)", value=512, precision=0)
        o5 = gr.JSON(label="Anomaly analysis")
        gr.Button("Detect").click(anomaly_moment, [s5, l5], o5, api_name="anomaly_moment")

    with gr.Tab("FinBERT text"):
        t6 = gr.Textbox(label="Text", value="The company reported record earnings.")
        o6 = gr.JSON(label="Sentiment")
        gr.Button("Score").click(score_sentiment, t6, o6, api_name="score_sentiment")

    with gr.Tab("FinBERT yfinance news"):
        s7 = gr.Textbox(label="Symbol", value="AAPL")
        m7 = gr.Number(label="Max articles", value=20, precision=0)
        o7 = gr.JSON(label="Sentiment")
        gr.Button("Score").click(score_sentiment_for_symbol, [s7, m7], o7, api_name="score_sentiment_for_symbol")

    with gr.Tab("GDELT news NEW"):
        s8 = gr.Textbox(label="Symbol", value="AAPL")
        c8 = gr.Textbox(label="Company name (optional)", value="Apple")
        d8 = gr.Number(label="Days", value=3, precision=0)
        m8 = gr.Number(label="Max articles", value=30, precision=0)
        o8 = gr.JSON(label="GDELT sentiment")
        gr.Button("Fetch").click(news_gdelt_for_symbol, [s8, c8, d8, m8], o8, api_name="news_gdelt_for_symbol")

    with gr.Tab("Reddit sentiment NEW"):
        s9 = gr.Textbox(label="Symbol", value="AAPL")
        sub9 = gr.Textbox(label="Subs CSV (blank = defaults)",
                          value="wallstreetbets,stocks,investing,IndianStreetBets,DalalStreetTalks,IndiaInvestments")
        m9 = gr.Number(label="Max posts per sub", value=20, precision=0)
        t9 = gr.Textbox(label="Time filter", value="week")
        o9 = gr.JSON(label="Reddit sentiment")
        gr.Button("Fetch").click(reddit_sentiment_for_symbol, [s9, sub9, m9, t9], o9,
                                 api_name="reddit_sentiment_for_symbol")
if __name__ == "__main__":
    # Serve the Blocks app on all interfaces, port 7860, with the MCP server
    # option passed to launch().
    demo.launch(
        mcp_server=True,
        server_name="0.0.0.0",
        server_port=7860,
    )