"""Gradio demo app for Kronos candlestick (K-line) forecasting.

Loads the Kronos tokenizer and model once at import time, wraps them in a
``KronosPredictor`` and exposes a small Blocks UI: upload an OHLCV CSV, pick a
forecast horizon, and view the predicted bars next to the most recent history.
"""

import logging
import os

import gradio as gr
import numpy as np  # NOTE(review): currently unused in this file
import pandas as pd
import torch
from huggingface_hub import snapshot_download  # NOTE(review): currently unused

# The classes of the local `model` module can now be imported normally.
from model import Kronos, KronosTokenizer, KronosPredictor

logger = logging.getLogger(__name__)

# Model IDs (Kronos checkpoints on the Hugging Face Hub).
MODEL_ID = "NeoQuasar/Kronos-small"
TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base"

# Columns the uploaded CSV must provide; also the order used for the output.
OHLCV_COLUMNS = ["open", "high", "low", "close", "volume"]
# Number of trailing history rows echoed back next to the forecast.
HISTORY_ROWS = 10

# Device selection: use the Space's GPU when available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and model once at import time so every request reuses them.
tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
# bfloat16 on GPU, float32 on CPU.
# NOTE(review): confirm that Kronos.from_pretrained actually accepts a
# `torch_dtype` keyword; the upstream README calls it with the model ID only.
model = Kronos.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
)
predictor = KronosPredictor(model, tokenizer, device=device)


def _resolve_csv_path(csv_file):
    """Return a filesystem path for a ``gr.File`` callback value.

    Depending on the Gradio version / ``type`` setting, ``gr.File`` passes
    either a path (string) or a tempfile wrapper that exposes the path as
    ``.name``. Real paths are returned untouched (``PathLike.name`` would be
    only the basename); anything else falls back to its ``.name`` attribute.
    """
    if isinstance(csv_file, (str, os.PathLike)):
        return csv_file
    return getattr(csv_file, "name", csv_file)


def predict_kronos(csv_file, prediction_length=5):
    """Run a Kronos forecast on an uploaded CSV and format it for the UI.

    Args:
        csv_file: Value of the ``gr.File`` input (path or tempfile wrapper);
            ``None`` when nothing has been uploaded.
        prediction_length: Number of future steps to forecast. The slider may
            deliver it as a float, so it is coerced to ``int`` before use.

    Returns:
        A ``(predictions_text, history_text)`` pair of strings for the two
        output textboxes. On any failure the first element holds the error
        message and the second is empty; errors are shown in the UI rather
        than raised.
    """
    if csv_file is None:
        return "预测错误:请先上传CSV文件", ""

    try:
        # range() below rejects floats, and sliders can hand back e.g. 5.0.
        prediction_length = int(prediction_length)
        csv_path = _resolve_csv_path(csv_file)

        # Read and validate the input before running the model so a bad
        # file fails fast with an explicit message.
        raw_df = pd.read_csv(csv_path)
        missing = [col for col in OHLCV_COLUMNS if col not in raw_df.columns]
        if missing:
            return f"预测错误:CSV缺少必要列 {missing}", ""

        # NOTE(review): the upstream Kronos README calls
        # predictor.predict(df=..., x_timestamp=..., y_timestamp=...,
        # pred_len=..., sample_count=...); confirm this `model` module's
        # predict() really accepts csv_data / prediction_length / num_samples.
        predictions = predictor.predict(
            csv_data=csv_path,
            prediction_length=prediction_length,
            num_samples=10,
        )

        # Assumes `predictions` is array-like with one row per future step and
        # one column per OHLCV field. NOTE(review): if predict() returns a
        # DataFrame, `columns=` here selects by label and would yield NaNs.
        pred_df = pd.DataFrame(
            predictions,
            columns=[f"pred_{col}" for col in OHLCV_COLUMNS],
            index=[f"t+{step}" for step in range(1, prediction_length + 1)],
        )

        # Show the most recent history rows alongside the forecast.
        history_df = raw_df[OHLCV_COLUMNS].tail(HISTORY_ROWS)
        return pred_df.round(2).to_string(), history_df.round(2).to_string()
    except Exception as e:
        # UI boundary: keep the traceback in the server logs, show only the
        # message to the user.
        logger.exception("Kronos prediction failed")
        return f"预测错误:{e}", ""


# Build the Gradio interface.
with gr.Blocks(title="Kronos金融预测") as demo:
    gr.Markdown("# 📈 Kronos K线预测模型")
    gr.Markdown("上传包含**open/high/low/close/volume**列的CSV文件,获取未来K线预测")
    with gr.Row():
        csv_input = gr.File(label="上传CSV文件", file_types=[".csv"])
        pred_len = gr.Slider(1, 30, 5, label="预测天数", step=1)
    with gr.Row():
        predict_btn = gr.Button("开始预测", variant="primary")
        clear_btn = gr.Button("清空")
    with gr.Row():
        pred_output = gr.Textbox(label="预测结果", lines=12)
        history_output = gr.Textbox(label=f"最近{HISTORY_ROWS}条历史数据", lines=12)

    # Wire the buttons: predict fills both textboxes, clear empties them.
    predict_btn.click(predict_kronos, [csv_input, pred_len], [pred_output, history_output])
    clear_btn.click(lambda: ("", ""), outputs=[pred_output, history_output])

# Launch the app (bind address and port used for the Space).
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)