# Kronosdemo / app.py -- Hugging Face Space by leehao163 ("Update app.py", commit 707f2c6, verified)
import gradio as gr
import torch
import numpy as np
import pandas as pd
# 现在可以正常导入model模块的类了
from model import Kronos, KronosTokenizer, KronosPredictor
from huggingface_hub import snapshot_download
import os
# Hugging Face Hub repository IDs of the pretrained Kronos checkpoints.
MODEL_ID = "NeoQuasar/Kronos-small"
TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base"

# Use the GPU when PyTorch can see one, so the same file runs on CPU and GPU
# Spaces; bfloat16 weights are requested only on the GPU path.
if torch.cuda.is_available():
    device = "cuda"
    _model_dtype = torch.bfloat16
else:
    device = "cpu"
    _model_dtype = torch.float32

# Load everything once at import time; the request handler reuses `predictor`.
tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
# NOTE(review): confirm Kronos.from_pretrained accepts/forwards `torch_dtype`;
# the upstream Kronos README loads the model without this argument.
model = Kronos.from_pretrained(MODEL_ID, torch_dtype=_model_dtype)
predictor = KronosPredictor(model, tokenizer, device=device)
# Prediction handler wired to the Gradio "predict" button.
def predict_kronos(csv_file, prediction_length=5):
    """Run the Kronos predictor on an uploaded CSV and format results as text.

    Args:
        csv_file: Value of the ``gr.File`` input. This is a file-path string
            (newer Gradio), an object exposing the path as ``.name`` (older
            Gradio tempfile wrappers), or ``None`` when nothing was uploaded.
        prediction_length: Number of future steps to forecast. Gradio sliders
            may deliver this as a float, so it is coerced to ``int``.

    Returns:
        A ``(prediction_text, history_text)`` tuple of strings. On any failure
        the first element holds an error message and the second is empty.
    """
    import logging

    if csv_file is None:
        return "预测错误:请先上传CSV文件", ""

    # Accept both plain path strings and wrapper objects carrying `.name`.
    csv_path = getattr(csv_file, "name", csv_file)
    history_cols = ["open", "high", "low", "close", "volume"]
    try:
        steps = int(prediction_length)

        # Validate the input before running (potentially slow) model inference.
        raw_df = pd.read_csv(csv_path)
        missing = [col for col in history_cols if col not in raw_df.columns]
        if missing:
            return f"预测错误:CSV缺少列 {missing}", ""

        # NOTE(review): upstream Kronos documents
        # KronosPredictor.predict(df, x_timestamp, y_timestamp, pred_len, ...);
        # confirm the local `model` module accepts these keyword names.
        predictions = predictor.predict(
            csv_data=csv_path,
            prediction_length=steps,
            num_samples=10,
        )
        pred_df = pd.DataFrame(
            predictions,
            columns=["pred_open", "pred_high", "pred_low", "pred_close", "pred_volume"],
            index=[f"t+{i}" for i in range(1, steps + 1)],
        )
        history_df = raw_df[history_cols].tail(10)
        return pred_df.round(2).to_string(), history_df.round(2).to_string()
    except Exception as e:
        # UI boundary: report the error to the user, but keep the traceback
        # in the server log instead of discarding it.
        logging.getLogger(__name__).exception("Kronos prediction failed")
        return f"预测错误:{e}", ""
# Build the Gradio interface.
with gr.Blocks(title="Kronos金融预测") as demo:
    gr.Markdown("# 📈 Kronos K线预测模型")
    gr.Markdown("上传包含**open/high/low/close/volume**列的CSV文件,获取未来K线预测")
    with gr.Row():
        csv_input = gr.File(label="上传CSV文件", file_types=[".csv"])
        pred_len = gr.Slider(minimum=1, maximum=30, value=5, step=1, label="预测天数")
    with gr.Row():
        predict_btn = gr.Button("开始预测", variant="primary")
        clear_btn = gr.Button("清空")
    with gr.Row():
        pred_output = gr.Textbox(label="预测结果", lines=12)
        history_output = gr.Textbox(label="最近10条历史数据", lines=12)
    # Wire the buttons. Predict reads the uploaded file and the slider value.
    predict_btn.click(
        predict_kronos,
        inputs=[csv_input, pred_len],
        outputs=[pred_output, history_output],
    )
    # Clear also resets the upload, so a later predict click cannot silently
    # reuse the previous file after the user pressed "clear".
    clear_btn.click(
        lambda: (None, "", ""),
        outputs=[csv_input, pred_output, history_output],
    )
# Script entry point: bind to all interfaces on port 7860, the address/port
# used for the Hugging Face Space deployment.
if __name__ == "__main__":
    launch_kwargs = dict(server_name="0.0.0.0", server_port=7860)
    demo.launch(**launch_kwargs)