""" Kronos Forecast API — FastAPI version (HF Space backend). Endpoints: GET / — landing page GET /health — liveness check POST /api/predict — full Kronos forecast for a ticker POST /api/spot — just the current/recent price for a ticker (cheap, no model) """ import os import sys import json from datetime import datetime from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse, JSONResponse from pydantic import BaseModel import numpy as np import pandas as pd import yfinance as yf import torch # noqa: F401 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Kronos")) from model import Kronos, KronosTokenizer, KronosPredictor # noqa: E402 # --------------------------------------------------------------------------- # Model loading at startup # --------------------------------------------------------------------------- print("Loading Kronos tokenizer + model (this takes ~30s on first run)...") TOKENIZER = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-2k") MODEL = Kronos.from_pretrained("NeoQuasar/Kronos-mini") PREDICTOR = KronosPredictor(MODEL, TOKENIZER, device="cpu", max_context=512) print("Model ready.") INTERVAL_MAP = { "1h": {"yf_interval": "1h", "yf_period": "60d", "freq": "1H"}, "1d": {"yf_interval": "1d", "yf_period": "2y", "freq": "1D"}, } def fetch_ohlcv(ticker: str, interval: str = "1h") -> pd.DataFrame: cfg = INTERVAL_MAP[interval] df = yf.download( ticker, period=cfg["yf_period"], interval=cfg["yf_interval"], auto_adjust=False, progress=False, ) if df.empty: raise ValueError(f"No data for ticker {ticker!r} at interval {interval}") if isinstance(df.columns, pd.MultiIndex): df.columns = df.columns.get_level_values(0) df = df.rename( columns={ "Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume", } ) df = df[["open", "high", "low", "close", "volume"]].dropna() df["amount"] = df["close"] * df["volume"] df.index.name = "timestamp" return df def run_forecast( ticker: str = "SPY", interval: str = "1d", lookback: int = 200, horizon: int = 5, n_samples: int = 30, temperature: float = 1.0, top_p: float = 0.9, ): df = fetch_ohlcv(ticker, interval=interval) df = df.tail(lookback).reset_index() df = df.rename(columns={df.columns[0]: "timestamp"}) x_df = df[["open", "high", "low", "close", "volume", "amount"]] x_ts = pd.to_datetime(df["timestamp"]) freq = INTERVAL_MAP[interval]["freq"] last_ts = x_ts.iloc[-1] y_ts = pd.Series(pd.date_range(start=last_ts, periods=horizon + 1, freq=freq)[1:]) preds = [] for _ in range(n_samples): out = PREDICTOR.predict( df=x_df, x_timestamp=x_ts, y_timestamp=y_ts, pred_len=horizon, T=temperature, top_p=top_p, sample_count=1, verbose=False, ) preds.append(out["close"].values) preds = np.stack(preds, axis=0) mean = preds.mean(axis=0) low = np.percentile(preds, 10, axis=0) high = np.percentile(preds, 90, axis=0) last_close = float(x_df["close"].iloc[-1]) terminal = preds[:, -1] bullish_prob = float((terminal > last_close).mean()) recent_returns = np.diff(np.log(x_df["close"].values[-horizon:])) recent_vol = float(np.std(recent_returns)) if len(recent_returns) > 1 else 0.0 pred_returns = np.diff(np.log(preds), axis=1) pred_vols = np.std(pred_returns, axis=1) vol_expansion_prob = ( float((pred_vols > recent_vol).mean()) if recent_vol > 0 else 0.5 ) expected_change_pct = float((mean[-1] - last_close) / last_close * 100.0) history = [ { "t": ts.isoformat(), "open": float(o), "high": float(h), "low": float(l), "close": float(c), } for ts, o, h, l, c in zip( x_ts, x_df["open"], x_df["high"], x_df["low"], x_df["close"] ) ] forecast_mean = [ {"t": ts.isoformat(), "close": float(v)} for ts, v in zip(y_ts, mean) ] forecast_low = [ {"t": ts.isoformat(), "close": float(v)} for ts, v in zip(y_ts, low) ] forecast_high = [ {"t": ts.isoformat(), "close": float(v)} for ts, v in zip(y_ts, high) ] return { "ticker": ticker, "interval": interval, "generated_at": datetime.utcnow().isoformat() + "Z", "last_close": last_close, "history": history, "forecast_mean": forecast_mean, "forecast_low": forecast_low, "forecast_high": forecast_high, "metrics": { "bullish_prob": bullish_prob, "vol_expansion_prob": vol_expansion_prob, "expected_change_pct": expected_change_pct, "n_samples": n_samples, "horizon": horizon, "lookback": lookback, }, } # --------------------------------------------------------------------------- # FastAPI app # --------------------------------------------------------------------------- app = FastAPI(title="Kronos Forecast API") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) class PredictRequest(BaseModel): # {"data": [ticker, interval, lookback, horizon, n_samples]} data: list class SpotRequest(BaseModel): # {"data": [ticker, interval]} — interval optional, defaults to "1d" data: list @app.get("/", response_class=HTMLResponse) def root(): return """
Endpoints:
POST /api/predict — run a Kronos forecast (~30s)POST /api/spot — just current price (fast, no model)GET /health — liveness check