Spaces:
Running
Running
| """ | |
HuggingFace Space: Kronos-mini 3-day price forecast endpoint.
Accepts a JSON payload of OHLCV data for multiple symbols and
returns 3-day OHLC forecasts using Kronos-mini on ZeroGPU (H200),
falling back to CPU when the `spaces` package is unavailable.
API: POST /predict with a JSON string → returns a JSON string.
| """ | |
| import sys | |
| import types | |
# ── Compatibility shims (must run before any gradio import) ──────────────────
# 1. Stub audioop/pyaudioop — removed from the Python 3.13 stdlib; pydub tries
#    to import them. Registering an empty module object satisfies the import.
for _mod in ('audioop', 'pyaudioop'):
    if _mod not in sys.modules:
        sys.modules[_mod] = types.ModuleType(_mod)

# 2. Stub HfFolder — removed from huggingface_hub>=0.25; gradio 4.44.0 imports it.
import huggingface_hub as _hfh

if not hasattr(_hfh, 'HfFolder'):
    class _HfFolder:
        """Minimal stand-in for the removed ``huggingface_hub.HfFolder``.

        The real class exposes these as classmethods called on the class
        itself (``HfFolder.get_token()``). They must be classmethods here too:
        plain functions with a ``cls`` parameter raise ``TypeError`` when
        called that way.
        """

        @classmethod
        def get_token(cls):
            # No stored token: behave as if logged out.
            return None

        @classmethod
        def save_token(cls, token):
            # No-op: nothing is persisted.
            pass

        @classmethod
        def delete_token(cls):
            # No-op: there is never a stored token to delete.
            pass

    _hfh.HfFolder = _HfFolder
# ─────────────────────────────────────────────────────────────────────────────
| import gradio as gr | |
| import json | |
| import os | |
| import torch | |
| import pandas as pd | |
| # โโ ZeroGPU / CPU-basic compatibility โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| # On ZeroGPU spaces, `spaces` is importable and @spaces.GPU allocates an H200. | |
| # On CPU-basic spaces, the `spaces` package is absent โ we fall back gracefully. | |
| try: | |
| import spaces | |
| _GPU_AVAILABLE = True | |
| except ImportError: | |
| _GPU_AVAILABLE = False | |
| def _gpu_if_available(fn): | |
| """Decorator: use @spaces.GPU on ZeroGPU, identity on CPU-basic.""" | |
| return spaces.GPU(fn) if _GPU_AVAILABLE else fn | |
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Running on device: {DEVICE}", flush=True)

# The Kronos model package lives in ./model/ next to this file (copied from
# the upstream repo); prepend this file's directory so `model` resolves to it.
sys.path.insert(0, os.path.dirname(__file__))
from model import Kronos, KronosTokenizer, KronosPredictor

# Pretrained checkpoint IDs handed to `from_pretrained`.
TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-2k"
MODEL_ID = "NeoQuasar/Kronos-mini"
# LOOKBACK: bars of history required/used per symbol.
# PRED_LEN: bars forecast per symbol.
LOOKBACK, PRED_LEN = 800, 3


def _load_weights():
    """Load the Kronos tokenizer and model once, logging progress to stdout."""
    print("Loading Kronos-mini weights...", flush=True)
    tok = KronosTokenizer.from_pretrained(TOKENIZER_ID)
    mdl = Kronos.from_pretrained(MODEL_ID)
    print("Weights loaded.", flush=True)
    return tok, mdl


# Weights are loaded at startup (CPU). They are moved to the GPU at call time
# on ZeroGPU, when the predictor is constructed with device=DEVICE.
tokenizer, model = _load_weights()
# Feature columns Kronos consumes, in the order they are passed to the predictor.
_FEATURE_COLS = ['open', 'high', 'low', 'close', 'volume', 'amount']


def _prepare_series(ohlcv):
    """Convert one symbol's raw bar list into model inputs.

    Returns ``(features, input_timestamps, forecast_timestamps)`` built from
    the last LOOKBACK bars (sorted by timestamp). Returns ``None`` when the
    series has fewer than LOOKBACK bars. Forecast timestamps are the next
    PRED_LEN business days (Mon-Fri, no holiday calendar) after the last bar.
    """
    df = pd.DataFrame(ohlcv)
    # Checked before touching any column: an empty or short series is simply
    # skipped instead of raising KeyError on a missing 'timestamp' column.
    if len(df) < LOOKBACK:
        return None
    df['timestamp'] = pd.to_datetime(df['timestamp'])
    df = df.sort_values('timestamp').reset_index(drop=True)
    df[_FEATURE_COLS] = df[_FEATURE_COLS].astype(float)

    window = df.iloc[-LOOKBACK:]
    y_ts = pd.Series(pd.bdate_range(
        start=df['timestamp'].iloc[-1] + pd.Timedelta(days=1), periods=PRED_LEN))
    return window[_FEATURE_COLS], window['timestamp'], y_ts


@_gpu_if_available  # on ZeroGPU, a GPU is only attached inside @spaces.GPU calls
def forecast(payload_json: str) -> str:
    """
    Run a batched Kronos-mini forecast for every symbol in the payload.

    Input: JSON string — list of:
      { "symbol": "MSFT.US",
        "ohlcv": [{"open":..,"high":..,"low":..,"close":..,"volume":..,"amount":..,"timestamp":"YYYY-MM-DD"}, ...] }
      A single such object (not wrapped in a list) is also accepted.
    Output: JSON string — list of:
      { "symbol": "MSFT.US",
        "forecast": [{"open":..,"high":..,"low":..,"close":..,"date":"YYYY-MM-DD"}, ...] }

    Symbols with fewer than LOOKBACK bars are omitted from the output. If no
    symbol qualifies, "[]" is returned.

    Raises:
        json.JSONDecodeError: if ``payload_json`` is not valid JSON.
        KeyError / ValueError: if an item or bar is missing a required field
            or holds a non-numeric / unparseable value.
    """
    payload = json.loads(payload_json)
    if isinstance(payload, dict):
        payload = [payload]  # tolerate a single un-wrapped symbol object

    symbols, df_list, x_ts_list, y_ts_list = [], [], [], []
    for item in payload:
        prepared = _prepare_series(item['ohlcv'])
        if prepared is None:
            continue  # skip symbols with insufficient history
        feat_df, x_ts, y_ts = prepared
        symbols.append(item['symbol'])
        df_list.append(feat_df)
        x_ts_list.append(x_ts)
        y_ts_list.append(y_ts)

    if not symbols:
        return json.dumps([])

    # Constructed inside the (possibly GPU-decorated) call with device=DEVICE.
    # Presumably the predictor moves the model to that device, which on
    # ZeroGPU requires the GPU to be attached.
    predictor = KronosPredictor(model, tokenizer, device=DEVICE, max_context=2048, clip=5)
    preds = predictor.predict_batch(
        df_list=df_list,
        x_timestamp_list=x_ts_list,
        y_timestamp_list=y_ts_list,
        pred_len=PRED_LEN,
        T=0.6,
        top_k=0,
        top_p=0.9,
        sample_count=3 if DEVICE == "cpu" else 5,  # 3 on CPU to keep latency reasonable
        verbose=False,
    )

    results = []
    for sym, pred_df, y_ts in zip(symbols, preds, y_ts_list):
        # Cast to float64 before rounding: if the model returns float32,
        # values rounded in float32 serialize as e.g. 412.3699951171875
        # instead of 412.37.
        pred_ohlc = pred_df[['open', 'high', 'low', 'close']].astype(float).round(2)
        # Plain list rather than a Series: a Series would be aligned on index.
        pred_ohlc['date'] = [t.strftime('%Y-%m-%d') for t in y_ts]
        results.append({
            'symbol': sym,
            'forecast': pred_ohlc.to_dict(orient='records'),
        })
    return json.dumps(results)
# Gradio UI + HTTP API. The app is served by running this file directly.
demo = gr.Interface(
    fn=forecast,
    inputs=gr.Textbox(label="OHLCV payload (JSON)", lines=4, placeholder='[{"symbol":"MSFT.US","ohlcv":[...]}]'),
    outputs=gr.Textbox(label="Forecast output (JSON)"),
    title=f"Kronos-mini · {PRED_LEN}-Day Price Forecast",
    description=(
        "**Kronos-mini** (NeoQuasar/Kronos-mini) — AAAI 2026 autoregressive transformer "
        "trained on 12B+ K-line records from 45 global exchanges.\n\n"
        f"POST OHLCV data for up to 6 symbols → receive {PRED_LEN}-day OHLC predictions.\n"
        f"Settings: lookback={LOOKBACK} bars | pred_len={PRED_LEN} | T=0.6 | "
        # Same predicate as forecast()'s sample_count, so the two cannot drift.
        f"device={DEVICE} | n={'3' if DEVICE == 'cpu' else '5'} samples"
    ),
    # NOTE(review): `flagging_mode` is the Gradio 5 spelling. Gradio 4.x (the
    # version named in the HfFolder shim comment) calls it `allow_flagging`.
    # Confirm against the pinned gradio version.
    flagging_mode="never",
    # Exposes `forecast` at the /predict API route.
    api_name="predict",
)

if __name__ == "__main__":
    demo.launch()