"""
Kronos Forecast API — FastAPI version (HF Space backend).

Endpoints:
  GET  /             — landing page
  GET  /health       — liveness check
  POST /api/predict  — full Kronos forecast for a ticker
  POST /api/spot     — just the current/recent price for a ticker (cheap, no model)
"""

import json
import logging
import os
import sys
from datetime import datetime, timezone

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
import numpy as np
import pandas as pd
import yfinance as yf
import torch  # noqa: F401

# The Kronos sources are expected in a "Kronos" directory next to this file.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Kronos"))
from model import Kronos, KronosTokenizer, KronosPredictor  # noqa: E402

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Model loading at startup
# ---------------------------------------------------------------------------
print("Loading Kronos tokenizer + model (this takes ~30s on first run)...")
TOKENIZER = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-2k")
MODEL = Kronos.from_pretrained("NeoQuasar/Kronos-mini")
PREDICTOR = KronosPredictor(MODEL, TOKENIZER, device="cpu", max_context=512)
print("Model ready.")

# Per-interval settings: the yfinance download arguments and the pandas
# frequency alias used to generate future timestamps.
# NOTE: the hourly alias is lowercase because pandas 2.2+ deprecates the
# uppercase "H" spelling.
INTERVAL_MAP = {
    "1h": {"yf_interval": "1h", "yf_period": "60d", "freq": "1h"},
    "1d": {"yf_interval": "1d", "yf_period": "2y", "freq": "1D"},
}

_OHLCV_COLUMNS = ["open", "high", "low", "close", "volume"]


def _interval_config(interval: str) -> dict:
    """Return the INTERVAL_MAP entry for *interval*.

    Raises:
        ValueError: if *interval* is not a key of INTERVAL_MAP. A readable
            message is used instead of the bare KeyError so API clients see
            which values are accepted.
    """
    try:
        return INTERVAL_MAP[interval]
    except KeyError:
        supported = ", ".join(sorted(INTERVAL_MAP))
        raise ValueError(
            f"Unsupported interval {interval!r}; expected one of: {supported}"
        ) from None


def _utc_now_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string with a 'Z' suffix."""
    # datetime.utcnow() is deprecated; dropping tzinfo from an aware UTC
    # datetime yields the exact same string format as before.
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"


def _tail_with_timestamp(df: pd.DataFrame, n: int) -> pd.DataFrame:
    """Return the last *n* rows with the index moved into a 'timestamp' column."""
    tail = df.tail(n).reset_index()
    # The first column is the former index, whatever it was called.
    return tail.rename(columns={tail.columns[0]: "timestamp"})


def _ohlc_records(timestamps, frame: pd.DataFrame) -> list:
    """Build JSON-ready OHLC dicts, one per row, keyed t/open/high/low/close."""
    return [
        {
            "t": pd.Timestamp(ts).isoformat(),
            "open": float(o),
            "high": float(h),
            "low": float(lo),
            "close": float(c),
        }
        for ts, o, h, lo, c in zip(
            timestamps, frame["open"], frame["high"], frame["low"], frame["close"]
        )
    ]


def _close_records(timestamps, values) -> list:
    """Build JSON-ready {"t", "close"} dicts from parallel sequences."""
    return [
        {"t": ts.isoformat(), "close": float(v)} for ts, v in zip(timestamps, values)
    ]


def fetch_ohlcv(ticker: str, interval: str = "1h") -> pd.DataFrame:
    """Download OHLCV bars for *ticker* through yfinance.

    Args:
        ticker: symbol passed to ``yf.download``.
        interval: key of INTERVAL_MAP selecting the bar size and history period.

    Returns:
        DataFrame indexed by "timestamp" with columns open, high, low, close,
        volume and amount (close * volume). Rows containing NaN are dropped.

    Raises:
        ValueError: if *interval* is unsupported or no usable rows come back.
    """
    cfg = _interval_config(interval)
    df = yf.download(
        ticker,
        period=cfg["yf_period"],
        interval=cfg["yf_interval"],
        auto_adjust=False,
        progress=False,
    )
    if df.empty:
        raise ValueError(f"No data for ticker {ticker!r} at interval {interval}")

    # When the download yields MultiIndex columns, keep only the first level
    # (the field names).
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(0)

    df = df.rename(
        columns={
            "Open": "open",
            "High": "high",
            "Low": "low",
            "Close": "close",
            "Volume": "volume",
        }
    )
    df = df[_OHLCV_COLUMNS].dropna()
    if df.empty:
        # Every row had a NaN; treat it like an empty download so callers get
        # a clear error rather than an IndexError further down.
        raise ValueError(f"No data for ticker {ticker!r} at interval {interval}")

    df["amount"] = df["close"] * df["volume"]
    df.index.name = "timestamp"
    return df


def run_forecast(
    ticker: str = "SPY",
    interval: str = "1d",
    lookback: int = 200,
    horizon: int = 5,
    n_samples: int = 30,
    temperature: float = 1.0,
    top_p: float = 0.9,
):
    """Run *n_samples* Kronos forecasts for *ticker* and summarise them.

    Args:
        ticker: symbol to forecast.
        interval: key of INTERVAL_MAP.
        lookback: number of most recent bars given to the predictor.
        horizon: number of future bars to predict.
        n_samples: number of independent predictor calls to aggregate.
        temperature: forwarded to the predictor as ``T``.
        top_p: forwarded to the predictor as ``top_p``.

    Returns:
        dict with the input history, the mean / 10th-percentile /
        90th-percentile close paths, and summary metrics.

    Raises:
        ValueError: if lookback, horizon or n_samples is below 1, the interval
            is unsupported, or no data is available.
    """
    # Reject degenerate sizes before doing any network or model work; left
    # unchecked they surface later as opaque IndexError / np.stack errors.
    for name, value in (
        ("lookback", lookback),
        ("horizon", horizon),
        ("n_samples", n_samples),
    ):
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value}")
    # NOTE(review): there is no upper bound on n_samples/horizon/lookback;
    # consider capping them if this endpoint is publicly reachable.

    freq = _interval_config(interval)["freq"]
    df = _tail_with_timestamp(fetch_ohlcv(ticker, interval=interval), lookback)

    x_df = df[["open", "high", "low", "close", "volume", "amount"]]
    x_ts = pd.to_datetime(df["timestamp"])

    # Future timestamps: generate horizon + 1 steps from the last bar and drop
    # the first one, which is the last bar itself.
    last_ts = x_ts.iloc[-1]
    y_ts = pd.Series(pd.date_range(start=last_ts, periods=horizon + 1, freq=freq)[1:])

    # One predictor call per sample so the spread across samples can be measured.
    preds = np.stack(
        [
            PREDICTOR.predict(
                df=x_df,
                x_timestamp=x_ts,
                y_timestamp=y_ts,
                pred_len=horizon,
                T=temperature,
                top_p=top_p,
                sample_count=1,
                verbose=False,
            )["close"].values
            for _ in range(n_samples)
        ],
        axis=0,
    )  # rows = samples, columns = forecast steps

    mean = preds.mean(axis=0)
    low = np.percentile(preds, 10, axis=0)
    high = np.percentile(preds, 90, axis=0)

    last_close = float(x_df["close"].iloc[-1])
    # Share of samples whose final predicted close is above the last close.
    bullish_prob = float((preds[:, -1] > last_close).mean())

    # Volatility of log returns over the last `horizon` observed closes.
    recent_returns = np.diff(np.log(x_df["close"].values[-horizon:]))
    recent_vol = float(np.std(recent_returns)) if len(recent_returns) > 1 else 0.0

    if recent_vol > 0:
        # Share of samples whose predicted log-return volatility exceeds the
        # recent one. Computed only here so short horizons do not take the
        # std of an empty slice.
        pred_vols = np.std(np.diff(np.log(preds), axis=1), axis=1)
        vol_expansion_prob = float((pred_vols > recent_vol).mean())
    else:
        vol_expansion_prob = 0.5

    expected_change_pct = float((mean[-1] - last_close) / last_close * 100.0)

    return {
        "ticker": ticker,
        "interval": interval,
        "generated_at": _utc_now_iso(),
        "last_close": last_close,
        "history": _ohlc_records(x_ts, x_df),
        "forecast_mean": _close_records(y_ts, mean),
        "forecast_low": _close_records(y_ts, low),
        "forecast_high": _close_records(y_ts, high),
        "metrics": {
            "bullish_prob": bullish_prob,
            "vol_expansion_prob": vol_expansion_prob,
            "expected_change_pct": expected_change_pct,
            "n_samples": n_samples,
            "horizon": horizon,
            "lookback": lookback,
        },
    }


# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(title="Kronos Forecast API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


class PredictRequest(BaseModel):
    # {"data": [ticker, interval, lookback, horizon, n_samples]}
    data: list


class SpotRequest(BaseModel):
    # {"data": [ticker, interval]} — interval optional, defaults to "1d"
    data: list


# Body served by the landing page endpoint.
_LANDING_PAGE = """ Kronos Forecast API

▲ Kronos Forecast API

Endpoints:

Interactive API docs →

 """


def _error_response(exc: Exception) -> JSONResponse:
    """Wrap *exc* in the same {"data": [json-string]} envelope used on success."""
    return JSONResponse(
        status_code=500,
        content={"data": [json.dumps({"error": str(exc)})]},
    )


@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the landing page."""
    return _LANDING_PAGE


@app.get("/health")
def health():
    """Liveness check."""
    return {"status": "ok", "model": "Kronos-mini", "device": "cpu"}


@app.post("/api/predict")
def api_predict(req: PredictRequest):
    """Returns Gradio-compatible envelope so the existing frontend works."""
    try:
        if len(req.data) < 5:
            raise ValueError(
                "Expected 5 args: [ticker, interval, lookback, horizon, n_samples]"
            )
        ticker, interval, lookback, horizon, n_samples = req.data[:5]
        result = run_forecast(
            ticker=str(ticker).upper().strip(),
            interval=str(interval),
            lookback=int(lookback),
            horizon=int(horizon),
            n_samples=int(n_samples),
        )
        return {"data": [json.dumps(result)]}
    except Exception as e:
        # Top-level boundary: keep the traceback in the logs and hand the
        # client the message in the usual envelope.
        logger.exception("Forecast request failed")
        return _error_response(e)


@app.post("/api/spot")
def api_spot(req: SpotRequest):
    """
    Cheap endpoint: returns the most recent close + a few recent bars for a ticker.
    Used by the History view to compare predictions against actuals without
    burning a full Kronos forecast.
    """
    try:
        if len(req.data) < 1:
            raise ValueError("Expected at least 1 arg: [ticker, interval?]")
        ticker = str(req.data[0]).upper().strip()
        interval = str(req.data[1]) if len(req.data) > 1 else "1d"

        # Return the last 30 bars so the client can compare predicted vs actual
        df = _tail_with_timestamp(fetch_ohlcv(ticker, interval=interval), 30)

        result = {
            "ticker": ticker,
            "interval": interval,
            "fetched_at": _utc_now_iso(),
            "last_close": float(df["close"].iloc[-1]),
            "last_t": pd.Timestamp(df["timestamp"].iloc[-1]).isoformat(),
            "bars": _ohlc_records(df["timestamp"], df),
        }
        return {"data": [json.dumps(result)]}
    except Exception as e:
        # Top-level boundary: keep the traceback in the logs and hand the
        # client the message in the usual envelope.
        logger.exception("Spot request failed")
        return _error_response(e)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)