kronos-forecast / app.py
Minyans's picture
Make device and @spaces.GPU decorator conditional — supports both ZeroGPU and CPU-basic tiers
f244203
"""
HuggingFace Space: Kronos-mini 3-day price forecast endpoint.
Accepts a JSON payload of OHLCV data for multiple symbols,
returns 3-day OHLC forecasts using Kronos-mini on ZeroGPU (H200), or on CPU when no GPU is available.
API: POST /predict with JSON string → returns JSON string
"""
import sys
import types
# ── Compatibility shims (must run before any gradio import) ──────────────────
# 1. audioop / pyaudioop were removed from the Python 3.13 stdlib, yet pydub
#    still tries to import them; register empty placeholder modules for any
#    of the two names that is not already loaded.
for _mod in ('audioop', 'pyaudioop'):
    sys.modules.setdefault(_mod, types.ModuleType(_mod))
# 2. huggingface_hub>=0.25 removed HfFolder, but gradio 4.44.0 still imports
#    it; install a do-nothing stand-in when the attribute is missing.
import huggingface_hub as _hfh

if not hasattr(_hfh, 'HfFolder'):

    class _HfFolder:
        """Inert replacement for the removed ``huggingface_hub.HfFolder``."""

        @classmethod
        def get_token(cls):
            """Always report that no token is stored."""
            return None

        @classmethod
        def save_token(cls, token):
            """Accept ``token`` and discard it."""

        @classmethod
        def delete_token(cls):
            """Do nothing; this stub never stores a token."""

    _hfh.HfFolder = _HfFolder
# ─────────────────────────────────────────────────────────────────────────────
import gradio as gr
import json
import os
import torch
import pandas as pd
# ── ZeroGPU / CPU-basic compatibility ────────────────────────────────────────
# On ZeroGPU Spaces the `spaces` package is importable and @spaces.GPU
# allocates an H200. On CPU-basic Spaces the package is absent, so the import
# failure is recorded in a flag instead of aborting startup.
try:
    import spaces
except ImportError:
    _GPU_AVAILABLE = False
else:
    _GPU_AVAILABLE = True
def _gpu_if_available(fn):
    """Decorator: apply ``spaces.GPU`` when `spaces` was imported, else return ``fn`` unchanged."""
    if not _GPU_AVAILABLE:
        # CPU-basic tier: nothing to wrap with, hand the function back as-is.
        return fn
    return spaces.GPU(fn)
# Pick the torch device once at import time: CUDA when torch reports it as
# available, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Running on device: {DEVICE}", flush=True)
# Forecast configuration: Hugging Face Hub ids of the pretrained tokenizer and
# model, the number of most recent bars fed to the model per symbol, and the
# number of future bars to predict.
LOOKBACK = 800
PRED_LEN = 3
TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-2k"
MODEL_ID = "NeoQuasar/Kronos-mini"

# The Kronos model source lives in ./model/ (copied from the repo); put this
# file's directory first on sys.path so the `model` package resolves to it.
_APP_DIR = os.path.dirname(__file__)
sys.path.insert(0, _APP_DIR)
from model import Kronos, KronosPredictor, KronosTokenizer

# Load weights once at startup (CPU) — moved to GPU at call time if ZeroGPU.
print("Loading Kronos-mini weights...", flush=True)
tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
model = Kronos.from_pretrained(MODEL_ID)
print("Weights loaded.", flush=True)
@_gpu_if_available
def forecast(payload_json: str) -> str:
    """Forecast the next ``PRED_LEN`` business days of OHLC prices per symbol.

    Args:
        payload_json: JSON string encoding a list of objects shaped like
            {"symbol": "MSFT.US",
             "ohlcv": [{"open":.., "high":.., "low":.., "close":..,
                        "volume":.., "amount":..,
                        "timestamp": "YYYY-MM-DD"}, ...]}

    Returns:
        JSON string encoding a list of objects shaped like
            {"symbol": "MSFT.US",
             "forecast": [{"open":.., "high":.., "low":.., "close":..,
                           "date": "YYYY-MM-DD"}, ...]}
        Symbols with fewer than ``LOOKBACK`` rows of history (including an
        empty or null "ohlcv") are left out of the result. If no symbol
        qualifies, the result is "[]".

    Raises:
        json.JSONDecodeError: if ``payload_json`` is not valid JSON.
        KeyError: if an item lacks "ohlcv", or an item with enough history
            lacks "symbol" or one of the required columns.
    """
    predictor = KronosPredictor(model, tokenizer, device=DEVICE, max_context=2048, clip=5)
    payload = json.loads(payload_json)

    feature_cols = ['open', 'high', 'low', 'close', 'volume', 'amount']
    symbols, df_list, x_ts_list, y_ts_list = [], [], [], []
    for item in payload:
        df = pd.DataFrame(item['ohlcv'])
        # Check the history length before touching any column. An empty or
        # null "ohlcv" yields a frame with no columns, and indexing
        # 'timestamp' on it would raise KeyError and abort the whole batch
        # instead of skipping just this symbol.
        if len(df) < LOOKBACK:
            continue  # skip symbols with insufficient history

        df['timestamp'] = pd.to_datetime(df['timestamp'])
        df = df.sort_values('timestamp').reset_index(drop=True)
        for col in feature_cols:
            df[col] = df[col].astype(float)

        # Model input: the most recent LOOKBACK bars, oldest first.
        window = df.iloc[-LOOKBACK:]
        # Forecast dates: PRED_LEN business days starting the day after the
        # last observed bar.
        y_ts = pd.Series(pd.bdate_range(
            start=df['timestamp'].iloc[-1] + pd.Timedelta(days=1),
            periods=PRED_LEN))

        symbols.append(item['symbol'])
        df_list.append(window[feature_cols])
        x_ts_list.append(window['timestamp'])
        y_ts_list.append(y_ts)

    if not symbols:
        return json.dumps([])

    preds = predictor.predict_batch(
        df_list=df_list,
        x_timestamp_list=x_ts_list,
        y_timestamp_list=y_ts_list,
        pred_len=PRED_LEN,
        T=0.6,
        top_k=0,
        top_p=0.9,
        sample_count=3 if DEVICE == "cpu" else 5,  # 3 on CPU to keep latency reasonable
        verbose=False
    )

    results = []
    for sym, pred_df, y_ts in zip(symbols, preds, y_ts_list):
        pred_ohlc = pred_df[['open', 'high', 'low', 'close']].round(2).copy()
        # Assign the dates as a plain list so they are placed positionally,
        # whatever index the predictor put on pred_df.
        pred_ohlc['date'] = [t.strftime('%Y-%m-%d') for t in y_ts]
        results.append({
            'symbol': sym,
            'forecast': pred_ohlc.to_dict(orient='records')
        })
    return json.dumps(results)
# Gradio front end: one JSON text box in, one JSON text box out, with
# `forecast` exposed under the API name "predict".
# The title and description are user-facing; their punctuation (middle dot,
# em dash, arrow) is written as real Unicode characters so the page does not
# render mis-encoded byte sequences.
demo = gr.Interface(
    fn=forecast,
    inputs=gr.Textbox(
        label="OHLCV payload (JSON)",
        lines=4,
        placeholder='[{"symbol":"MSFT.US","ohlcv":[...]}]',
    ),
    outputs=gr.Textbox(label="Forecast output (JSON)"),
    title="Kronos-mini · 3-Day Price Forecast",
    description=(
        "**Kronos-mini** (NeoQuasar/Kronos-mini) — AAAI 2026 autoregressive transformer "
        "trained on 12B+ K-line records from 45 global exchanges.\n\n"
        "POST OHLCV data for up to 6 symbols → receive 3-day OHLC predictions.\n"
        f"Settings: lookback={LOOKBACK} bars | pred_len={PRED_LEN} | T=0.6 | "
        f"device={DEVICE} | n={'5' if DEVICE == 'cuda' else '3'} samples"
    ),
    flagging_mode="never",
    api_name="predict"
)

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()