Spaces:
Runtime error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| # 现在可以正常导入model模块的类了 | |
| from model import Kronos, KronosTokenizer, KronosPredictor | |
| from huggingface_hub import snapshot_download | |
| import os | |
# Hugging Face Hub repository IDs of the Kronos checkpoints used by this app.
MODEL_ID = "NeoQuasar/Kronos-small"
TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and model once at import time so every request reuses them.
tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
# NOTE: no ``torch_dtype`` argument here. ``Kronos.from_pretrained`` is presumably
# huggingface_hub's PyTorchModelHubMixin, which forwards unknown keyword arguments
# to ``Kronos.__init__`` (``torch_dtype`` is a transformers-only option), making the
# Space crash at startup. Keeping the model in its default dtype also keeps it
# consistent with the tokenizer, which is loaded without any dtype override.
model = Kronos.from_pretrained(MODEL_ID)
predictor = KronosPredictor(model, tokenizer, device=device)
def _build_timestamps(context_df, n_pred):
    """Return ``(x_timestamp, y_timestamp)`` datetime Series for history and forecast.

    Uses the first timestamp-like column found in ``context_df``. The forecast
    step is the median spacing of that history, or 1 day if it cannot be
    inferred. Without such a column, consecutive daily bars ending today are
    assumed.
    """
    ts_col = next(
        (c for c in ("timestamps", "timestamp", "datetime", "date") if c in context_df.columns),
        None,
    )
    if ts_col is None:
        # NOTE(review): no timestamp column, so daily bars ending today are assumed.
        # Kronos presumably derives calendar features from these, so they are a guess.
        x_ts = pd.Series(
            pd.date_range(end=pd.Timestamp.today().normalize(), periods=len(context_df), freq="D")
        )
        step = pd.Timedelta(days=1)
    else:
        x_ts = pd.to_datetime(context_df[ts_col]).reset_index(drop=True)
        step = x_ts.diff().median()
        # A single row (NaT) or non-increasing timestamps give no usable spacing.
        if pd.isna(step) or step <= pd.Timedelta(0):
            step = pd.Timedelta(days=1)
    y_ts = pd.Series(pd.date_range(start=x_ts.iloc[-1] + step, periods=n_pred, freq=step))
    return x_ts, y_ts


def predict_kronos(csv_file, prediction_length=5, lookback=512, sample_count=10):
    """Forecast the next ``prediction_length`` K-line bars from an uploaded CSV.

    Bound to the Gradio UI: ``csv_file`` comes from ``gr.File`` and
    ``prediction_length`` from ``gr.Slider``.

    Args:
        csv_file: The uploaded file. This may be a filesystem path, an object
            exposing the path as ``.name`` (older Gradio file wrappers), any
            object ``pd.read_csv`` accepts, or ``None`` when nothing was uploaded.
        prediction_length: Number of future bars to forecast. It is coerced to
            ``int`` because slider values may arrive as floats.
        lookback: Maximum number of most recent rows passed to the model as
            context. 512 is assumed to match Kronos-small's context length; confirm.
        sample_count: Forwarded to ``KronosPredictor.predict`` as ``sample_count``.

    Returns:
        A ``(prediction_text, history_text)`` tuple for the two output
        Textboxes. ``prediction_text`` has columns ``pred_open`` …
        ``pred_volume`` and rows ``t+1`` … ``t+N``. ``history_text`` shows the
        last 10 input rows. On any failure it returns ``("预测错误:<message>", "")``.
    """
    price_cols = ["open", "high", "low", "close", "volume"]
    try:
        if csv_file is None:
            raise ValueError("请先上传CSV文件")
        n_pred = int(prediction_length)
        if n_pred < 1:
            raise ValueError("预测长度必须至少为1")

        # Older Gradio versions pass a tempfile wrapper whose path is in ``.name``.
        raw_df = pd.read_csv(getattr(csv_file, "name", csv_file))
        # Accept headers such as "Open" or " CLOSE" by normalising to lowercase.
        raw_df.columns = [str(c).strip().lower() for c in raw_df.columns]
        missing = [c for c in price_cols if c not in raw_df.columns]
        if missing:
            raise ValueError(f"CSV缺少必需的列: {', '.join(missing)}")
        if raw_df.empty:
            raise ValueError("CSV文件中没有数据行")

        context_df = raw_df.tail(int(lookback)).reset_index(drop=True)
        x_timestamp, y_timestamp = _build_timestamps(context_df, n_pred)
        # Pass "amount" through when present; otherwise the predictor derives it.
        model_cols = price_cols + (["amount"] if "amount" in context_df.columns else [])

        predictions = predictor.predict(
            df=context_df[model_cols],
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=n_pred,
            sample_count=int(sample_count),
        )
        # Present the forecast with the same labels the UI has always shown.
        pred_df = predictions[price_cols].copy()
        pred_df.columns = [f"pred_{c}" for c in price_cols]
        pred_df.index = [f"t+{i + 1}" for i in range(n_pred)]

        history_df = raw_df[price_cols].tail(10)
        return pred_df.round(2).to_string(), history_df.round(2).to_string()
    except Exception as e:
        # UI boundary: surface the error in the output box instead of crashing the app.
        return f"预测错误:{str(e)}", ""
# Initial forecast horizon; shared by the slider and the clear button's reset.
DEFAULT_PRED_LEN = 5

# Build the Gradio interface.
with gr.Blocks(title="Kronos金融预测") as demo:
    gr.Markdown("# 📈 Kronos K线预测模型")
    gr.Markdown("上传包含**open/high/low/close/volume**列的CSV文件,获取未来K线预测")
    with gr.Row():
        csv_input = gr.File(label="上传CSV文件", file_types=[".csv"])
        pred_len = gr.Slider(
            minimum=1, maximum=30, value=DEFAULT_PRED_LEN, step=1, label="预测天数"
        )
    with gr.Row():
        predict_btn = gr.Button("开始预测", variant="primary")
        clear_btn = gr.Button("清空")
    with gr.Row():
        pred_output = gr.Textbox(label="预测结果", lines=12)
        history_output = gr.Textbox(label="最近10条历史数据", lines=12)

    # Run the forecast on the uploaded file with the chosen horizon.
    predict_btn.click(predict_kronos, [csv_input, pred_len], [pred_output, history_output])
    # Reset the upload and slider as well as both outputs. Previously only the
    # outputs were cleared, leaving the old file selected.
    clear_btn.click(
        lambda: (None, DEFAULT_PRED_LEN, "", ""),
        outputs=[csv_input, pred_len, pred_output, history_output],
    )
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 so the app is reachable inside the Space.
    demo.launch(server_port=7860, server_name="0.0.0.0")