File size: 5,540 Bytes
839fb4b
 
 
 
 
 
 
 
 
ead67a3
 
 
202b84c
 
 
ead67a3
 
 
 
202b84c
 
 
 
 
 
 
 
 
 
 
 
 
 
839fb4b
 
 
f244203
839fb4b
 
f244203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
839fb4b
 
 
 
 
 
 
 
 
f244203
839fb4b
 
 
 
 
 
f244203
839fb4b
 
 
 
 
 
 
 
 
f244203
839fb4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f244203
839fb4b
 
 
 
 
d4ebbf3
839fb4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f244203
 
839fb4b
d4ebbf3
839fb4b
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
HuggingFace Space: Kronos-mini 3-day price forecast endpoint.

Accepts a JSON payload of OHLCV data for multiple symbols,
returns 3-day OHLC forecasts using Kronos-mini on ZeroGPU (H200).

API: POST /predict with JSON string → returns JSON string
"""

import sys
import types

# ── Compatibility shims (must run before any gradio import) ───────────────────

# 1. audioop / pyaudioop were removed from the Python 3.13 stdlib, but pydub
#    still tries to import them: register empty placeholder modules unless a
#    real module is already loaded under that name.
for _mod in ('audioop', 'pyaudioop'):
    sys.modules.setdefault(_mod, types.ModuleType(_mod))

# 2. HfFolder was removed from huggingface_hub>=0.25, yet gradio 4.44.0 still
#    imports it: install a do-nothing replacement when the attribute is absent.
import huggingface_hub as _hfh

if not hasattr(_hfh, 'HfFolder'):
    class _HfFolder:
        """Inert placeholder for the missing ``huggingface_hub.HfFolder``.

        Every method is a no-op: no token is ever stored, returned or removed.
        """

        @classmethod
        def get_token(cls):
            return None

        @classmethod
        def save_token(cls, token):
            return None

        @classmethod
        def delete_token(cls):
            return None

    setattr(_hfh, 'HfFolder', _HfFolder)

# ──────────────────────────────────────────────────────────────────────────────

import gradio as gr
import json
import os
import torch
import pandas as pd

# ── ZeroGPU / CPU-basic compatibility ─────────────────────────────────────────
# On ZeroGPU spaces, `spaces` is importable and @spaces.GPU allocates an H200.
# On CPU-basic spaces, the `spaces` package is absent — we fall back gracefully.
try:
    import spaces
except ImportError:
    # No `spaces` package installed: decorated functions are left untouched.
    _GPU_AVAILABLE = False
else:
    _GPU_AVAILABLE = True


def _gpu_if_available(fn):
    """Wrap *fn* with ``spaces.GPU`` when `spaces` imported, else return it as is."""
    if not _GPU_AVAILABLE:
        return fn
    return spaces.GPU(fn)

# Inference device, resolved once at import time from what torch reports.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on device: {DEVICE}", flush=True)

# Kronos model source is in ./model/ (copied from repo). This file's directory
# is put first on sys.path before the import below; the order matters.
sys.path.insert(0, os.path.dirname(__file__))
from model import Kronos, KronosTokenizer, KronosPredictor

# Identifiers passed to from_pretrained() below.
TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-2k"
MODEL_ID     = "NeoQuasar/Kronos-mini"
# Number of most recent bars fed to the model per symbol; symbols with fewer
# bars are skipped by forecast().
LOOKBACK     = 800
# Number of future business days forecast per symbol.
PRED_LEN     = 3

# Load weights at startup (CPU) — moved to GPU at call time if ZeroGPU
print("Loading Kronos-mini weights...", flush=True)
tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
model     = Kronos.from_pretrained(MODEL_ID)
print("Weights loaded.", flush=True)


@_gpu_if_available
def forecast(payload_json: str) -> str:
    """Forecast the next ``PRED_LEN`` business days of OHLC prices per symbol.

    Args:
        payload_json: JSON string holding a list of objects shaped like
            ``{"symbol": "MSFT.US",
               "ohlcv": [{"open": .., "high": .., "low": .., "close": ..,
                          "volume": .., "amount": ..,
                          "timestamp": "YYYY-MM-DD"}, ...]}``.

    Returns:
        JSON string holding a list of
        ``{"symbol": "MSFT.US",
           "forecast": [{"open": .., "high": .., "low": .., "close": ..,
                         "date": "YYYY-MM-DD"}, ...]}``
        in payload order. Symbols with fewer than ``LOOKBACK`` bars (including
        an empty ``ohlcv`` list) are left out; when no symbol qualifies the
        result is ``"[]"``.

    Raises:
        json.JSONDecodeError: if ``payload_json`` is not valid JSON.
        KeyError: if an item has no ``"ohlcv"`` key, or an item with enough
            history lacks ``"symbol"`` or one of the expected columns.
    """
    feature_cols = ['open', 'high', 'low', 'close', 'volume', 'amount']
    payload = json.loads(payload_json)

    symbols, df_list, x_ts_list, y_ts_list = [], [], [], []
    for item in payload:
        df = pd.DataFrame(item['ohlcv'])

        # Check the history length before touching any column: an empty
        # "ohlcv" list produces a frame with no columns at all, and such a
        # symbol must be skipped instead of aborting the whole batch.
        if len(df) < LOOKBACK:
            continue  # skip symbols with insufficient history

        df['timestamp'] = pd.to_datetime(df['timestamp'])
        df = df.sort_values('timestamp').reset_index(drop=True)
        df = df.astype({col: float for col in feature_cols})

        # Only the most recent LOOKBACK bars are handed to the model.
        window = df.iloc[-LOOKBACK:]
        # Forecast dates: the PRED_LEN business days following the last bar.
        y_ts = pd.Series(pd.bdate_range(
            start=window['timestamp'].iloc[-1] + pd.Timedelta(days=1),
            periods=PRED_LEN))

        symbols.append(item['symbol'])
        df_list.append(window[feature_cols])
        x_ts_list.append(window['timestamp'])
        y_ts_list.append(y_ts)

    if not symbols:
        return json.dumps([])

    # Built only once there is something to forecast.
    predictor = KronosPredictor(model, tokenizer, device=DEVICE, max_context=2048, clip=5)
    preds = predictor.predict_batch(
        df_list=df_list,
        x_timestamp_list=x_ts_list,
        y_timestamp_list=y_ts_list,
        pred_len=PRED_LEN,
        T=0.6,
        top_k=0,
        top_p=0.9,
        sample_count=3 if DEVICE == "cpu" else 5,  # 3 on CPU to keep latency reasonable
        verbose=False
    )

    results = []
    for sym, pred_df, y_ts in zip(symbols, preds, y_ts_list):
        pred_ohlc = pred_df[['open', 'high', 'low', 'close']].round(2).copy()
        # Positional assignment (plain list) so the frame's own index is irrelevant.
        pred_ohlc['date'] = y_ts.dt.strftime('%Y-%m-%d').tolist()
        results.append({
            'symbol': sym,
            'forecast': pred_ohlc.to_dict(orient='records')
        })

    return json.dumps(results)


# Gradio front end: a single JSON textbox in, a single JSON textbox out, with
# forecast() exposed under the "predict" API name.
demo = gr.Interface(
    fn=forecast,
    inputs=gr.Textbox(label="OHLCV payload (JSON)", lines=4, placeholder='[{"symbol":"MSFT.US","ohlcv":[...]}]'),
    outputs=gr.Textbox(label="Forecast output (JSON)"),
    # The day count is taken from PRED_LEN so the text cannot drift from the
    # actual forecast horizon.
    title=f"Kronos-mini · {PRED_LEN}-Day Price Forecast",
    description=(
        "**Kronos-mini** (NeoQuasar/Kronos-mini) — AAAI 2026 autoregressive transformer "
        "trained on 12B+ K-line records from 45 global exchanges.\n\n"
        f"POST OHLCV data for up to 6 symbols → receive {PRED_LEN}-day OHLC predictions.\n"
        f"Settings: lookback={LOOKBACK} bars | pred_len={PRED_LEN} | T=0.6 | "
        f"device={DEVICE} | n={'5' if DEVICE == 'cuda' else '3'} samples"
    ),
    flagging_mode="never",
    api_name="predict"
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()