Spaces:
Sleeping
Sleeping
| """ | |
| Kronos Forecast API — FastAPI version (HF Space backend). | |
| Endpoints: | |
| GET / — landing page | |
| GET /health — liveness check | |
| POST /api/predict — full Kronos forecast for a ticker | |
| POST /api/spot — just the current/recent price for a ticker (cheap, no model) | |
| """ | |
| import os | |
| import sys | |
| import json | |
| from datetime import datetime | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from pydantic import BaseModel | |
| import numpy as np | |
| import pandas as pd | |
| import yfinance as yf | |
| import torch # noqa: F401 | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Kronos")) | |
| from model import Kronos, KronosTokenizer, KronosPredictor # noqa: E402 | |
# ---------------------------------------------------------------------------
# Model loading at startup
# ---------------------------------------------------------------------------
# The tokenizer, model and predictor are created once at import time and
# shared by every request through the PREDICTOR global.
print("Loading Kronos tokenizer + model (this takes ~30s on first run)...")
TOKENIZER = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-2k")
MODEL = Kronos.from_pretrained("NeoQuasar/Kronos-mini")
PREDICTOR = KronosPredictor(MODEL, TOKENIZER, device="cpu", max_context=512)
print("Model ready.")

# Supported bar intervals. For each one:
#   yf_interval / yf_period -> arguments forwarded to yf.download
#   freq                    -> pandas offset alias used to build the future
#                              timestamps of a forecast
# The hourly alias is written lowercase ("1h"): the uppercase "H" spelling
# is deprecated since pandas 2.2 and emits a FutureWarning.
INTERVAL_MAP = {
    "1h": {"yf_interval": "1h", "yf_period": "60d", "freq": "1h"},
    "1d": {"yf_interval": "1d", "yf_period": "2y", "freq": "1D"},
}
def fetch_ohlcv(ticker: str, interval: str = "1h") -> pd.DataFrame:
    """Download recent OHLCV bars for *ticker* via ``yf.download``.

    Args:
        ticker: Symbol forwarded to ``yf.download``.
        interval: Key of ``INTERVAL_MAP``; selects the yfinance interval and
            the history period to request.

    Returns:
        A DataFrame whose index is named ``timestamp`` and whose columns are
        ``open``, ``high``, ``low``, ``close``, ``volume`` and ``amount``
        (``close * volume``). Rows with any missing value are dropped.

    Raises:
        ValueError: If *interval* is not a key of ``INTERVAL_MAP``, or if no
            usable rows are available for the ticker.
    """
    cfg = INTERVAL_MAP.get(interval)
    if cfg is None:
        # A bare KeyError would stringify to just the quoted key, which is a
        # useless message once it reaches the API error envelope.
        supported = ", ".join(sorted(INTERVAL_MAP))
        raise ValueError(
            f"Unsupported interval {interval!r}; expected one of: {supported}"
        )

    raw = yf.download(
        ticker,
        period=cfg["yf_period"],
        interval=cfg["yf_interval"],
        auto_adjust=False,
        progress=False,
    )
    # NOTE(review): the None check is defensive; confirm whether the installed
    # yfinance version can actually return None on failure.
    if raw is None or raw.empty:
        raise ValueError(f"No data for ticker {ticker!r} at interval {interval}")

    # Keep only the first level when the columns come back as a MultiIndex.
    if isinstance(raw.columns, pd.MultiIndex):
        raw.columns = raw.columns.get_level_values(0)

    renamed = raw.rename(
        columns={
            "Open": "open",
            "High": "high",
            "Low": "low",
            "Close": "close",
            "Volume": "volume",
        }
    )
    bars = renamed[["open", "high", "low", "close", "volume"]].dropna()
    if bars.empty:
        # Every row had a gap; fail here rather than letting callers hit an
        # out-of-bounds .iloc[-1] later.
        raise ValueError(f"No data for ticker {ticker!r} at interval {interval}")

    bars = bars.assign(amount=bars["close"] * bars["volume"])
    bars.index.name = "timestamp"
    return bars
def _close_points(timestamps, values):
    """Pair timestamps with close values as JSON-ready ``{"t", "close"}`` dicts."""
    return [
        {"t": ts.isoformat(), "close": float(v)}
        for ts, v in zip(timestamps, values)
    ]


def run_forecast(
    ticker: str = "SPY",
    interval: str = "1d",
    lookback: int = 200,
    horizon: int = 5,
    n_samples: int = 30,
    temperature: float = 1.0,
    top_p: float = 0.9,
):
    """Sample Kronos forecasts for a ticker and summarise them.

    The last *lookback* bars from ``fetch_ohlcv`` are fed to
    ``PREDICTOR.predict`` *n_samples* times (one sampled path per call), and
    the predicted ``close`` paths are aggregated.

    Args:
        ticker: Symbol to forecast.
        interval: Key of ``INTERVAL_MAP``.
        lookback: Number of most recent bars used as model context.
        horizon: Number of future bars to predict.
        n_samples: Number of independently sampled paths.
        temperature: Forwarded to the predictor as ``T``.
        top_p: Forwarded to the predictor as ``top_p``.

    Returns:
        A JSON-serialisable dict with the history bars, the per-step mean and
        10th/90th percentile of the sampled closes, and a ``metrics`` dict:
        ``bullish_prob`` (share of paths ending above the last close),
        ``vol_expansion_prob`` (share of paths whose log-return std exceeds
        that of the last *horizon* closes; 0.5 when that reference is not
        positive) and ``expected_change_pct``.

    Raises:
        ValueError: If *lookback*, *horizon* or *n_samples* is below 1, or if
            ``fetch_ohlcv`` rejects the request.
    """
    from datetime import timezone  # function-scope: only needed for the stamp

    # Fail early with a readable message instead of an out-of-bounds iloc or
    # "need at least one array to stack" deep inside the computation.
    for name, value in (
        ("lookback", lookback),
        ("horizon", horizon),
        ("n_samples", n_samples),
    ):
        if value < 1:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")

    bars = fetch_ohlcv(ticker, interval=interval).tail(lookback).reset_index()
    bars = bars.rename(columns={bars.columns[0]: "timestamp"})
    x_df = bars[["open", "high", "low", "close", "volume", "amount"]]
    x_ts = pd.to_datetime(bars["timestamp"])

    # Future timestamps: horizon steps after the last observed bar (the first
    # generated point is the last bar itself, hence the [1:]).
    freq = INTERVAL_MAP[interval]["freq"]
    last_ts = x_ts.iloc[-1]
    y_ts = pd.Series(
        pd.date_range(start=last_ts, periods=horizon + 1, freq=freq)[1:]
    )

    # One predictor call per path so that each row of `preds` is a single
    # sampled trajectory; result shape is (n_samples, horizon).
    preds = np.stack(
        [
            PREDICTOR.predict(
                df=x_df,
                x_timestamp=x_ts,
                y_timestamp=y_ts,
                pred_len=horizon,
                T=temperature,
                top_p=top_p,
                sample_count=1,
                verbose=False,
            )["close"].values
            for _ in range(n_samples)
        ],
        axis=0,
    )

    mean = preds.mean(axis=0)
    low = np.percentile(preds, 10, axis=0)
    high = np.percentile(preds, 90, axis=0)

    last_close = float(x_df["close"].iloc[-1])
    bullish_prob = float((preds[:, -1] > last_close).mean())

    recent_returns = np.diff(np.log(x_df["close"].values[-horizon:]))
    recent_vol = float(np.std(recent_returns)) if len(recent_returns) > 1 else 0.0
    if recent_vol > 0:
        pred_vols = np.std(np.diff(np.log(preds), axis=1), axis=1)
        vol_expansion_prob = float((pred_vols > recent_vol).mean())
    else:
        # No usable reference volatility; skip the per-path computation
        # (it would run on empty slices when horizon == 1).
        vol_expansion_prob = 0.5

    expected_change_pct = float((mean[-1] - last_close) / last_close * 100.0)

    history = [
        {
            "t": ts.isoformat(),
            "open": float(o),
            "high": float(h),
            "low": float(lo),
            "close": float(c),
        }
        for ts, o, h, lo, c in zip(
            x_ts, x_df["open"], x_df["high"], x_df["low"], x_df["close"]
        )
    ]

    # Same "naive ISO + Z" format as before, without the deprecated utcnow().
    generated_at = (
        datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    )

    return {
        "ticker": ticker,
        "interval": interval,
        "generated_at": generated_at,
        "last_close": last_close,
        "history": history,
        "forecast_mean": _close_points(y_ts, mean),
        "forecast_low": _close_points(y_ts, low),
        "forecast_high": _close_points(y_ts, high),
        "metrics": {
            "bullish_prob": bullish_prob,
            "vol_expansion_prob": vol_expansion_prob,
            "expected_change_pct": expected_change_pct,
            "n_samples": n_samples,
            "horizon": horizon,
            "lookback": lookback,
        },
    }
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(title="Kronos Forecast API")

# Open CORS policy: any origin, method and header is accepted, and
# credentialed requests are disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,
)
class PredictRequest(BaseModel):
    # Positional payload for the predict endpoint:
    # {"data": [ticker, interval, lookback, horizon, n_samples]}
    data: list
class SpotRequest(BaseModel):
    # Positional payload for the spot endpoint:
    # {"data": [ticker, interval]} — interval optional, defaults to "1d"
    data: list
# Registered as GET / (the module docstring lists it as the landing page);
# HTMLResponse makes the returned string be served as text/html.
@app.get("/", response_class=HTMLResponse)
def root():
    """Return the static HTML landing page that lists the API endpoints."""
    return """
<html><head><title>Kronos Forecast API</title>
<style>
body { font-family: ui-monospace, monospace; background: #0e0e0c; color: #e8e8e6;
padding: 40px; max-width: 720px; margin: auto; line-height: 1.6; }
h1 { color: #ff6b00; }
code { background: #1a1a17; padding: 2px 6px; color: #ff9b50; }
pre { background: #16161300; border: 1px solid #26261f; padding: 16px; overflow-x: auto; }
a { color: #ff6b00; }
</style></head><body>
<h1>▲ Kronos Forecast API</h1>
<p>Endpoints:</p>
<ul>
<li><code>POST /api/predict</code> — run a Kronos forecast (~30s)</li>
<li><code>POST /api/spot</code> — just current price (fast, no model)</li>
<li><code>GET /health</code> — liveness check</li>
</ul>
<p><a href="/docs">Interactive API docs →</a></p>
</body></html>
"""
# Registered as GET /health (liveness check per the module docstring).
@app.get("/health")
def health():
    """Return a static liveness payload naming the model and device."""
    return {"status": "ok", "model": "Kronos-mini", "device": "cpu"}
# Registered as POST /api/predict (full forecast per the module docstring).
@app.post("/api/predict")
def api_predict(req: PredictRequest):
    """Run a full forecast and return it in a Gradio-compatible envelope.

    ``req.data`` is positional: ``[ticker, interval, lookback, horizon,
    n_samples]``; extra trailing items are ignored.

    Returns:
        ``{"data": [<JSON string of the forecast>]}`` on success. On any
        failure, a 500 ``JSONResponse`` with the same envelope wrapping
        ``{"error": <message>}``.
    """
    try:
        if len(req.data) < 5:
            raise ValueError(
                "Expected 5 args: [ticker, interval, lookback, horizon, n_samples]"
            )
        ticker, interval, lookback, horizon, n_samples = req.data[:5]
        forecast = run_forecast(
            ticker=str(ticker).upper().strip(),
            interval=str(interval),
            lookback=int(lookback),
            horizon=int(horizon),
            n_samples=int(n_samples),
        )
        return {"data": [json.dumps(forecast)]}
    except Exception as e:
        # Endpoint boundary: every failure is reported to the client inside
        # the same envelope shape instead of propagating.
        return JSONResponse(
            status_code=500,
            content={"data": [json.dumps({"error": str(e)})]},
        )
# Registered as POST /api/spot (cheap price lookup per the module docstring).
@app.post("/api/spot")
def api_spot(req: SpotRequest):
    """
    Cheap endpoint: returns the most recent close + a few recent bars for a
    ticker. Used by the History view to compare predictions against actuals
    without burning a full Kronos forecast.

    ``req.data`` is positional: ``[ticker, interval?]``; the interval
    defaults to ``"1d"``.

    Returns:
        ``{"data": [<JSON string>]}`` where the JSON holds the ticker, the
        interval, a fetch timestamp, the last close and its timestamp, and up
        to 30 recent bars. On any failure, a 500 ``JSONResponse`` with the
        same envelope wrapping ``{"error": <message>}``.
    """
    from datetime import timezone  # function-scope: only needed for the stamp

    try:
        if not req.data:
            raise ValueError("Expected at least 1 arg: [ticker, interval?]")
        ticker = str(req.data[0]).upper().strip()
        interval = str(req.data[1]) if len(req.data) > 1 else "1d"

        # Return the last 30 bars so the client can compare predicted vs actual
        recent = fetch_ohlcv(ticker, interval=interval).tail(30).reset_index()
        recent = recent.rename(columns={recent.columns[0]: "timestamp"})

        # Column-wise zip avoids building a Series per row as iterrows() does.
        bars = [
            {
                "t": pd.Timestamp(ts).isoformat(),
                "open": float(o),
                "high": float(h),
                "low": float(lo),
                "close": float(c),
            }
            for ts, o, h, lo, c in zip(
                recent["timestamp"],
                recent["open"],
                recent["high"],
                recent["low"],
                recent["close"],
            )
        ]

        # Same "naive ISO + Z" format as before, without the deprecated
        # datetime.utcnow().
        fetched_at = (
            datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
        )
        result = {
            "ticker": ticker,
            "interval": interval,
            "fetched_at": fetched_at,
            "last_close": float(recent["close"].iloc[-1]),
            "last_t": pd.Timestamp(recent["timestamp"].iloc[-1]).isoformat(),
            "bars": bars,
        }
        return {"data": [json.dumps(result)]}
    except Exception as e:
        # Endpoint boundary: every failure is reported to the client inside
        # the same envelope shape instead of propagating.
        return JSONResponse(
            status_code=500,
            content={"data": [json.dumps({"error": str(e)})]},
        )
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces,
    # port 7860.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )