# Kronos_Model / app.py
# Hugging Face Space by muhamm3dahmed — commit ea054c7 (verified):
# "Rename App.py to app.py". File size: 2.47 kB.
import os
import shutil
import subprocess
# Download the Kronos model code at startup. The app imports the `model`
# package directly (see `from model import ...` below), so on first start we
# shallow-clone the upstream repository, keep only its `model/` directory and
# discard the rest of the checkout.
if not os.path.exists("model"):
    # An interrupted earlier start may have left a partial checkout behind;
    # `git clone` refuses to clone into a non-empty directory.
    shutil.rmtree("kronos_tmp", ignore_errors=True)
    try:
        # check=True turns a failed clone (no network, git missing, ...) into a
        # clear CalledProcessError instead of a confusing error from rename.
        subprocess.run(
            [
                "git", "clone", "--depth=1",
                "https://github.com/shiyu-coder/Kronos",
                "kronos_tmp",
            ],
            check=True,
        )
        os.rename(os.path.join("kronos_tmp", "model"), "model")
    finally:
        # Best-effort removal of the temporary checkout, even if the clone or
        # rename failed; shutil avoids depending on an external `rm` binary.
        shutil.rmtree("kronos_tmp", ignore_errors=True)
import spaces
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from model import Kronos, KronosTokenizer, KronosPredictor
# Load the pretrained Kronos base model and its tokenizer once, at import
# time, and wrap them in a single shared predictor used by every request.
# NOTE(review): the "NeoQuasar/..." ids look like Hugging Face Hub repo ids —
# confirm against the Kronos `from_pretrained` implementation.
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
predictor = KronosPredictor(
    model,
    tokenizer,
    device="cuda",
    max_context=512,
)
@spaces.GPU
def forecast(file, pred_len, temperature, top_p):
    """Forecast close prices from an uploaded OHLC CSV and plot the result.

    The first ``min(400, len(df) - pred_len)`` rows are the historical context
    fed to the model; the ``timestamps`` of the ``pred_len`` rows immediately
    after them are the timestamps the model forecasts for.

    Args:
        file: Value of the ``gr.File`` component. Gradio 4+ passes a filepath
            string, Gradio 3.x a tempfile-like object exposing ``.name``;
            ``None`` when nothing has been uploaded.
        pred_len: Number of steps to forecast (slider value, cast to ``int``).
        temperature: Sampling temperature, passed to the predictor as ``T``.
        top_p: Nucleus-sampling threshold, passed to the predictor.

    Returns:
        A Plotly figure with the historical close prices and the forecast.

    Raises:
        gr.Error: If no file was uploaded, required columns are missing, or
            the CSV has too few rows for the requested forecast length.
    """
    if file is None:
        raise gr.Error("Please upload a CSV file first.")
    # Accept both a plain path (Gradio 4+) and a tempfile wrapper (Gradio 3.x).
    path = getattr(file, "name", file)
    steps = int(pred_len)

    df = pd.read_csv(path)
    required = ["timestamps", "open", "high", "low", "close"]
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise gr.Error(f"CSV is missing required column(s): {', '.join(missing)}")
    df["timestamps"] = pd.to_datetime(df["timestamps"])

    # Up to 400 rows of context, leaving `steps` rows after them whose
    # timestamps form the forecast horizon. A non-positive lookback would give
    # the model an empty context, so reject it with a readable message.
    lookback = min(400, len(df) - steps)
    if lookback < 1:
        raise gr.Error(
            f"CSV has {len(df)} rows; at least {steps + 1} are needed "
            f"to forecast {steps} steps."
        )

    # Positional slices: context rows [0, lookback), horizon rows
    # [lookback, lookback + steps).
    context = df.iloc[:lookback]
    x_df = context[["open", "high", "low", "close"]]
    x_ts = context["timestamps"]
    y_ts = df["timestamps"].iloc[lookback:lookback + steps]

    pred = predictor.predict(
        df=x_df,
        x_timestamp=x_ts,
        y_timestamp=y_ts,
        pred_len=steps,
        T=float(temperature),
        top_p=float(top_p),
        sample_count=1,
    )

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=x_ts, y=x_df["close"],
        name="Historical", line=dict(color="steelblue"),
    ))
    fig.add_trace(go.Scatter(
        x=y_ts, y=pred["close"],
        name="Forecast", line=dict(color="orange", dash="dot"),
    ))
    fig.update_layout(
        title="Kronos Forecast",
        xaxis_title="Time",
        yaxis_title="Price",
        template="plotly_dark",
    )
    return fig
# Build the Gradio UI: a CSV upload, sampling controls, a run button and a
# plot output wired to `forecast`; then start the server.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📈 Kronos Financial Forecaster")
    gr.Markdown("Upload a CSV with columns: `timestamps, open, high, low, close`")
    with gr.Row():
        # Only CSV files are accepted, since `forecast` reads with pd.read_csv.
        file = gr.File(label="📂 Upload CSV", file_types=[".csv"])
    with gr.Row():
        pred_len = gr.Slider(10, 120, value=30, step=1, label="Forecast Steps")
        temp = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
    btn = gr.Button("🔮 Run Forecast", variant="primary")
    plot = gr.Plot(label="Forecast Output")
    # Event listeners must be registered inside the Blocks context.
    btn.click(
        fn=forecast,
        inputs=[file, pred_len, temp, top_p],
        outputs=plot,
    )
demo.launch()