Spaces:
Running on Zero
Running on Zero
| import os | |
| import subprocess | |
import shutil


# Download the Kronos model code (its `model` package) at startup; the
# `from model import ...` further down depends on it.
def _ensure_kronos_model_code(
    repo_url="https://github.com/shiyu-coder/Kronos",
    target="model",
    tmp_dir="kronos_tmp",
):
    """Make sure the Kronos ``model`` package exists at ``target``.

    Does nothing if ``target`` already exists. Otherwise shallow-clones
    ``repo_url`` into ``tmp_dir``, moves the clone's ``model`` directory to
    ``target`` and deletes the rest of the clone (also on failure).

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits non-zero.
        OSError: if ``git`` cannot be executed or the cloned ``model``
            directory cannot be moved into place.
    """
    if os.path.exists(target):
        return
    # A clone left behind by an interrupted start would make `git clone`
    # refuse to run ("destination path already exists"), so clear it first.
    shutil.rmtree(tmp_dir, ignore_errors=True)
    try:
        # check=True: fail loudly here instead of with a confusing rename
        # error on a directory that was never created.
        subprocess.run(
            ["git", "clone", "--depth=1", repo_url, tmp_dir],
            check=True,
        )
        os.rename(os.path.join(tmp_dir, "model"), target)
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)


_ensure_kronos_model_code()
| import spaces | |
| import gradio as gr | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from model import Kronos, KronosTokenizer, KronosPredictor | |
def _load_kronos():
    """Load the pretrained Kronos tokenizer and model and build a predictor.

    Returns:
        ``(tokenizer, model, predictor)``, where the predictor wraps the
        model and tokenizer with ``device="cuda"`` and ``max_context=512``.
    """
    tok = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    mdl = Kronos.from_pretrained("NeoQuasar/Kronos-base")
    pred = KronosPredictor(mdl, tok, device="cuda", max_context=512)
    return tok, mdl, pred


# Loaded once at import time; `forecast` reuses the shared `predictor`.
tokenizer, model, predictor = _load_kronos()
| def forecast(file, pred_len, temperature, top_p): | |
| df = pd.read_csv(file.name) | |
| df['timestamps'] = pd.to_datetime(df['timestamps']) | |
| lookback = min(400, len(df) - int(pred_len)) | |
| x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close']] | |
| x_ts = df.loc[:lookback-1, 'timestamps'] | |
| y_ts = df.loc[lookback:lookback+int(pred_len)-1, 'timestamps'] | |
| pred = predictor.predict( | |
| df=x_df, | |
| x_timestamp=x_ts, | |
| y_timestamp=y_ts, | |
| pred_len=int(pred_len), | |
| T=float(temperature), | |
| top_p=float(top_p), | |
| sample_count=1 | |
| ) | |
| fig = go.Figure() | |
| fig.add_trace(go.Scatter( | |
| x=x_ts, y=x_df['close'], | |
| name="Historical", line=dict(color='steelblue') | |
| )) | |
| fig.add_trace(go.Scatter( | |
| x=y_ts, y=pred['close'], | |
| name="Forecast", line=dict(color='orange', dash='dot') | |
| )) | |
| fig.update_layout( | |
| title="Kronos Forecast", | |
| xaxis_title="Time", | |
| yaxis_title="Price", | |
| template="plotly_dark" | |
| ) | |
| return fig | |
# Gradio UI: upload a CSV, pick sampling settings, and plot the forecast.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📈 Kronos Financial Forecaster")
    gr.Markdown("Upload a CSV with columns: `timestamps, open, high, low, close`")
    with gr.Row():
        file = gr.File(label="📁 Upload CSV", file_types=[".csv"])
    with gr.Row():
        pred_len = gr.Slider(10, 120, value=30, step=1, label="Forecast Steps")
        temp = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
    btn = gr.Button("🔮 Run Forecast", variant="primary")
    plot = gr.Plot(label="Forecast Output")
    btn.click(
        # On ZeroGPU hardware a GPU is only attached while a function wrapped
        # by spaces.GPU is running; `spaces` was imported but never applied.
        fn=spaces.GPU(forecast),
        inputs=[file, pred_len, temp, top_p],
        outputs=plot,
    )

if __name__ == "__main__":
    demo.launch()