"""Gradio demo that plots a Kronos forecast for an uploaded OHLC CSV.

At import time the script makes sure the Kronos ``model`` package is present
(cloning it from GitHub if necessary), loads the pretrained tokenizer and
model, builds the Gradio UI and launches it.
"""

import os
import shutil
import subprocess

KRONOS_REPO_URL = "https://github.com/shiyu-coder/Kronos"
KRONOS_CLONE_DIR = "kronos_tmp"
KRONOS_MODEL_DIR = "model"

# Columns the uploaded CSV must contain for `forecast` to work.
REQUIRED_COLUMNS = ("timestamps", "open", "high", "low", "close")
PRICE_COLUMNS = ["open", "high", "low", "close"]

# Upper bound on the number of history rows handed to the predictor.
MAX_LOOKBACK = 400


def _ensure_kronos_source():
    """Fetch the Kronos ``model`` package into the working directory if missing.

    Clones the Kronos repository into a temporary directory, moves its
    ``model`` sub-directory to ``./model`` and removes the temporary clone.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with a non-zero
            status (previously this was ignored and surfaced later as a
            confusing ``os.rename`` failure).
    """
    if os.path.exists(KRONOS_MODEL_DIR):
        return

    # A clone directory left over from an interrupted run would make
    # `git clone` refuse to write into it, so start from a clean slate.
    shutil.rmtree(KRONOS_CLONE_DIR, ignore_errors=True)
    try:
        subprocess.run(
            ["git", "clone", "--depth=1", KRONOS_REPO_URL, KRONOS_CLONE_DIR],
            check=True,
        )
        os.rename(os.path.join(KRONOS_CLONE_DIR, "model"), KRONOS_MODEL_DIR)
    finally:
        # Always drop the temporary clone, even when the clone or move failed.
        shutil.rmtree(KRONOS_CLONE_DIR, ignore_errors=True)


# Download Kronos model code at startup. This has to run before the
# `from model import ...` line below, which is why the remaining imports are
# not at the very top of the file.
_ensure_kronos_source()

# NOTE(review): `spaces` is kept as the first third-party import, exactly as
# in the original ordering; presumably it must precede CUDA-using libraries
# on ZeroGPU hardware -- confirm before reordering.
import spaces  # noqa: E402
import gradio as gr  # noqa: E402
import pandas as pd  # noqa: E402
import plotly.graph_objects as go  # noqa: E402

from model import Kronos, KronosTokenizer, KronosPredictor  # noqa: E402

tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
predictor = KronosPredictor(model, tokenizer, device="cuda", max_context=512)


@spaces.GPU
def forecast(file, pred_len, temperature, top_p):
    """Run Kronos on the head of an uploaded CSV and plot the prediction.

    The first ``min(400, rows - pred_len)`` rows are used as history; the
    ``pred_len`` rows that follow them in the file supply the timestamps the
    forecast is plotted against. The CSV therefore needs more than
    ``pred_len`` rows.

    Args:
        file: Value of the ``gr.File`` input; either a path string or an
            object exposing the path as ``.name``. ``None`` when nothing was
            uploaded.
        pred_len: Number of steps to forecast (converted with ``int``).
        temperature: Sampling temperature passed to the predictor as ``T``.
        top_p: Nucleus-sampling threshold passed to the predictor.

    Returns:
        A ``plotly.graph_objects.Figure`` with the historical close prices
        and the forecast close prices.

    Raises:
        gr.Error: If no file was uploaded, a required column is missing, or
            the CSV has too few rows for the requested horizon.
    """
    if file is None:
        raise gr.Error("Please upload a CSV file before running the forecast.")

    horizon = int(pred_len)

    # Plain path strings have no `.name`; file-like upload objects do.
    csv_path = getattr(file, "name", file)
    df = pd.read_csv(csv_path)

    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise gr.Error(
            "CSV is missing required column(s): " + ", ".join(missing)
        )

    df['timestamps'] = pd.to_datetime(df['timestamps'])

    lookback = min(MAX_LOOKBACK, len(df) - horizon)
    if lookback < 1:
        # Without this guard an empty history frame reaches the predictor.
        raise gr.Error(
            f"CSV has {len(df)} rows; more than {horizon} rows are needed "
            f"to forecast {horizon} steps."
        )

    # Positional slicing: the first `lookback` rows are history, the next
    # `horizon` rows provide the timestamps of the predicted steps.
    history = df.iloc[:lookback]
    x_df = history[PRICE_COLUMNS]
    x_ts = history['timestamps']
    y_ts = df['timestamps'].iloc[lookback:lookback + horizon]

    pred = predictor.predict(
        df=x_df,
        x_timestamp=x_ts,
        y_timestamp=y_ts,
        pred_len=horizon,
        T=float(temperature),
        top_p=float(top_p),
        sample_count=1,
    )

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=x_ts,
        y=x_df['close'],
        name="Historical",
        line=dict(color='steelblue'),
    ))
    fig.add_trace(go.Scatter(
        x=y_ts,
        y=pred['close'],
        name="Forecast",
        line=dict(color='orange', dash='dot'),
    ))
    fig.update_layout(
        title="Kronos Forecast",
        xaxis_title="Time",
        yaxis_title="Price",
        template="plotly_dark",
    )
    return fig


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📈 Kronos Financial Forecaster")
    gr.Markdown("Upload a CSV with columns: `timestamps, open, high, low, close`")

    with gr.Row():
        file = gr.File(label="📂 Upload CSV")

    with gr.Row():
        pred_len = gr.Slider(10, 120, value=30, step=1, label="Forecast Steps")
        temp = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    btn = gr.Button("🔮 Run Forecast", variant="primary")
    plot = gr.Plot(label="Forecast Output")

    btn.click(
        fn=forecast,
        inputs=[file, pred_len, temp, top_p],
        outputs=plot,
    )

# Launched unconditionally, as before, so running or importing the script
# starts the app exactly like the original did.
demo.launch()