Minyans committed on
Commit
f244203
·
1 Parent(s): 202b84c

Make device and @spaces.GPU decorator conditional — supports both ZeroGPU and CPU-basic tiers

Browse files
Files changed (1) hide show
  1. app.py +23 -6
app.py CHANGED
@@ -31,12 +31,28 @@ if not hasattr(_hfh, 'HfFolder'):
31
 
32
  # ─────────────────────────────────────────────────────────────────────────────
33
 
34
- import spaces
35
  import gradio as gr
36
  import json
37
  import os
 
38
  import pandas as pd
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  # Kronos model source is in ./model/ (copied from repo)
41
  sys.path.insert(0, os.path.dirname(__file__))
42
  from model import Kronos, KronosTokenizer, KronosPredictor
@@ -46,14 +62,14 @@ MODEL_ID = "NeoQuasar/Kronos-mini"
46
  LOOKBACK = 800
47
  PRED_LEN = 3
48
 
49
- # Load weights at startup on CPU — @spaces.GPU moves computation to CUDA at call time
50
  print("Loading Kronos-mini weights...", flush=True)
51
  tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
52
  model = Kronos.from_pretrained(MODEL_ID)
53
  print("Weights loaded.", flush=True)
54
 
55
 
56
- @spaces.GPU
57
  def forecast(payload_json: str) -> str:
58
  """
59
  Input: JSON string — list of:
@@ -63,7 +79,7 @@ def forecast(payload_json: str) -> str:
63
  { "symbol": "MSFT.US",
64
  "forecast": [{"open":..,"high":..,"low":..,"close":..,"date":"YYYY-MM-DD"}, ...] }
65
  """
66
- predictor = KronosPredictor(model, tokenizer, device="cuda", max_context=2048, clip=5)
67
  payload = json.loads(payload_json)
68
 
69
  symbols, df_list, x_ts_list, y_ts_list = [], [], [], []
@@ -98,7 +114,7 @@ def forecast(payload_json: str) -> str:
98
  T=0.6,
99
  top_k=0,
100
  top_p=0.9,
101
- sample_count=5, # GPU is fast — use full 5 samples
102
  verbose=False
103
  )
104
 
@@ -123,7 +139,8 @@ demo = gr.Interface(
123
  "**Kronos-mini** (NeoQuasar/Kronos-mini) — AAAI 2026 autoregressive transformer "
124
  "trained on 12B+ K-line records from 45 global exchanges.\n\n"
125
  "POST OHLCV data for up to 6 symbols → receive 3-day OHLC predictions.\n"
126
- f"Settings: lookback={LOOKBACK} bars | pred_len={PRED_LEN} | T=0.6 | n=5 samples"
 
127
  ),
128
  flagging_mode="never",
129
  api_name="predict"
 
31
 
32
  # ─────────────────────────────────────────────────────────────────────────────
33
 
 
34
  import gradio as gr
35
  import json
36
  import os
37
+ import torch
38
  import pandas as pd
39
 
40
+ # ── ZeroGPU / CPU-basic compatibility ────────────────────────────────────────
41
+ # On ZeroGPU spaces, `spaces` is importable and @spaces.GPU allocates an H200.
42
+ # On CPU-basic spaces, the `spaces` package is absent — we fall back gracefully.
43
+ try:
44
+ import spaces
45
+ _GPU_AVAILABLE = True
46
+ except ImportError:
47
+ _GPU_AVAILABLE = False
48
+
49
+ def _gpu_if_available(fn):
50
+ """Decorator: use @spaces.GPU on ZeroGPU, identity on CPU-basic."""
51
+ return spaces.GPU(fn) if _GPU_AVAILABLE else fn
52
+
53
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
54
+ print(f"Running on device: {DEVICE}", flush=True)
55
+
56
  # Kronos model source is in ./model/ (copied from repo)
57
  sys.path.insert(0, os.path.dirname(__file__))
58
  from model import Kronos, KronosTokenizer, KronosPredictor
 
62
  LOOKBACK = 800
63
  PRED_LEN = 3
64
 
65
+ # Load weights at startup (CPU) — moved to GPU at call time if ZeroGPU
66
  print("Loading Kronos-mini weights...", flush=True)
67
  tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
68
  model = Kronos.from_pretrained(MODEL_ID)
69
  print("Weights loaded.", flush=True)
70
 
71
 
72
+ @_gpu_if_available
73
  def forecast(payload_json: str) -> str:
74
  """
75
  Input: JSON string — list of:
 
79
  { "symbol": "MSFT.US",
80
  "forecast": [{"open":..,"high":..,"low":..,"close":..,"date":"YYYY-MM-DD"}, ...] }
81
  """
82
+ predictor = KronosPredictor(model, tokenizer, device=DEVICE, max_context=2048, clip=5)
83
  payload = json.loads(payload_json)
84
 
85
  symbols, df_list, x_ts_list, y_ts_list = [], [], [], []
 
114
  T=0.6,
115
  top_k=0,
116
  top_p=0.9,
117
+ sample_count=3 if DEVICE == "cpu" else 5, # 3 on CPU to keep latency reasonable
118
  verbose=False
119
  )
120
 
 
139
  "**Kronos-mini** (NeoQuasar/Kronos-mini) — AAAI 2026 autoregressive transformer "
140
  "trained on 12B+ K-line records from 45 global exchanges.\n\n"
141
  "POST OHLCV data for up to 6 symbols → receive 3-day OHLC predictions.\n"
142
+ f"Settings: lookback={LOOKBACK} bars | pred_len={PRED_LEN} | T=0.6 | "
143
+ f"device={DEVICE} | n={'5' if DEVICE == 'cuda' else '3'} samples"
144
  ),
145
  flagging_mode="never",
146
  api_name="predict"