Spaces:
Running
Running
Make device and @spaces.GPU decorator conditional — supports both ZeroGPU and CPU-basic tiers
Browse files
app.py
CHANGED
|
@@ -31,12 +31,28 @@ if not hasattr(_hfh, 'HfFolder'):
|
|
| 31 |
|
| 32 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 33 |
|
| 34 |
-
import spaces
|
| 35 |
import gradio as gr
|
| 36 |
import json
|
| 37 |
import os
|
|
|
|
| 38 |
import pandas as pd
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# Kronos model source is in ./model/ (copied from repo)
|
| 41 |
sys.path.insert(0, os.path.dirname(__file__))
|
| 42 |
from model import Kronos, KronosTokenizer, KronosPredictor
|
|
@@ -46,14 +62,14 @@ MODEL_ID = "NeoQuasar/Kronos-mini"
|
|
| 46 |
LOOKBACK = 800
|
| 47 |
PRED_LEN = 3
|
| 48 |
|
| 49 |
-
# Load weights at startup
|
| 50 |
print("Loading Kronos-mini weights...", flush=True)
|
| 51 |
tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
|
| 52 |
model = Kronos.from_pretrained(MODEL_ID)
|
| 53 |
print("Weights loaded.", flush=True)
|
| 54 |
|
| 55 |
|
| 56 |
-
@
|
| 57 |
def forecast(payload_json: str) -> str:
|
| 58 |
"""
|
| 59 |
Input: JSON string — list of:
|
|
@@ -63,7 +79,7 @@ def forecast(payload_json: str) -> str:
|
|
| 63 |
{ "symbol": "MSFT.US",
|
| 64 |
"forecast": [{"open":..,"high":..,"low":..,"close":..,"date":"YYYY-MM-DD"}, ...] }
|
| 65 |
"""
|
| 66 |
-
predictor = KronosPredictor(model, tokenizer, device=
|
| 67 |
payload = json.loads(payload_json)
|
| 68 |
|
| 69 |
symbols, df_list, x_ts_list, y_ts_list = [], [], [], []
|
|
@@ -98,7 +114,7 @@ def forecast(payload_json: str) -> str:
|
|
| 98 |
T=0.6,
|
| 99 |
top_k=0,
|
| 100 |
top_p=0.9,
|
| 101 |
-
sample_count=5,
|
| 102 |
verbose=False
|
| 103 |
)
|
| 104 |
|
|
@@ -123,7 +139,8 @@ demo = gr.Interface(
|
|
| 123 |
"**Kronos-mini** (NeoQuasar/Kronos-mini) — AAAI 2026 autoregressive transformer "
|
| 124 |
"trained on 12B+ K-line records from 45 global exchanges.\n\n"
|
| 125 |
"POST OHLCV data for up to 6 symbols → receive 3-day OHLC predictions.\n"
|
| 126 |
-
f"Settings: lookback={LOOKBACK} bars | pred_len={PRED_LEN} | T=0.6 |
|
|
|
|
| 127 |
),
|
| 128 |
flagging_mode="never",
|
| 129 |
api_name="predict"
|
|
|
|
| 31 |
|
| 32 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 33 |
|
|
|
|
| 34 |
import gradio as gr
|
| 35 |
import json
|
| 36 |
import os
|
| 37 |
+
import torch
|
| 38 |
import pandas as pd
|
| 39 |
|
| 40 |
+
# ── ZeroGPU / CPU-basic compatibility ────────────────────────────────────────
|
| 41 |
+
# On ZeroGPU spaces, `spaces` is importable and @spaces.GPU allocates an H200.
|
| 42 |
+
# On CPU-basic spaces, the `spaces` package is absent — we fall back gracefully.
|
| 43 |
+
try:
|
| 44 |
+
import spaces
|
| 45 |
+
_GPU_AVAILABLE = True
|
| 46 |
+
except ImportError:
|
| 47 |
+
_GPU_AVAILABLE = False
|
| 48 |
+
|
| 49 |
+
def _gpu_if_available(fn):
|
| 50 |
+
"""Decorator: use @spaces.GPU on ZeroGPU, identity on CPU-basic."""
|
| 51 |
+
return spaces.GPU(fn) if _GPU_AVAILABLE else fn
|
| 52 |
+
|
| 53 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 54 |
+
print(f"Running on device: {DEVICE}", flush=True)
|
| 55 |
+
|
| 56 |
# Kronos model source is in ./model/ (copied from repo)
|
| 57 |
sys.path.insert(0, os.path.dirname(__file__))
|
| 58 |
from model import Kronos, KronosTokenizer, KronosPredictor
|
|
|
|
| 62 |
LOOKBACK = 800
|
| 63 |
PRED_LEN = 3
|
| 64 |
|
| 65 |
+
# Load weights at startup (CPU) — moved to GPU at call time if ZeroGPU
|
| 66 |
print("Loading Kronos-mini weights...", flush=True)
|
| 67 |
tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
|
| 68 |
model = Kronos.from_pretrained(MODEL_ID)
|
| 69 |
print("Weights loaded.", flush=True)
|
| 70 |
|
| 71 |
|
| 72 |
+
@_gpu_if_available
|
| 73 |
def forecast(payload_json: str) -> str:
|
| 74 |
"""
|
| 75 |
Input: JSON string — list of:
|
|
|
|
| 79 |
{ "symbol": "MSFT.US",
|
| 80 |
"forecast": [{"open":..,"high":..,"low":..,"close":..,"date":"YYYY-MM-DD"}, ...] }
|
| 81 |
"""
|
| 82 |
+
predictor = KronosPredictor(model, tokenizer, device=DEVICE, max_context=2048, clip=5)
|
| 83 |
payload = json.loads(payload_json)
|
| 84 |
|
| 85 |
symbols, df_list, x_ts_list, y_ts_list = [], [], [], []
|
|
|
|
| 114 |
T=0.6,
|
| 115 |
top_k=0,
|
| 116 |
top_p=0.9,
|
| 117 |
+
sample_count=3 if DEVICE == "cpu" else 5, # 3 on CPU to keep latency reasonable
|
| 118 |
verbose=False
|
| 119 |
)
|
| 120 |
|
|
|
|
| 139 |
"**Kronos-mini** (NeoQuasar/Kronos-mini) — AAAI 2026 autoregressive transformer "
|
| 140 |
"trained on 12B+ K-line records from 45 global exchanges.\n\n"
|
| 141 |
"POST OHLCV data for up to 6 symbols → receive 3-day OHLC predictions.\n"
|
| 142 |
+
f"Settings: lookback={LOOKBACK} bars | pred_len={PRED_LEN} | T=0.6 | "
|
| 143 |
+
f"device={DEVICE} | n={'5' if DEVICE == 'cuda' else '3'} samples"
|
| 144 |
),
|
| 145 |
flagging_mode="never",
|
| 146 |
api_name="predict"
|