"""FastAPI service exposing a Kronos-based close-price forecast endpoint.

The Kronos tokenizer and model are loaded once at import time and shared by
every request handled by this process.
"""
import logging
import os
import sys
from datetime import timedelta

import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Make the Kronos checkout that sits next to this file importable.
sys.path.append(os.path.join(os.path.dirname(__file__), "Kronos"))
from model import Kronos, KronosTokenizer, KronosPredictor  # noqa: E402

logger = logging.getLogger(__name__)

app = FastAPI()

# Feature columns handed to the predictor, in this order.
INPUT_COLS = ["open", "high", "low", "close", "volume"]
# Columns that may carry the candle time; the first one present is used.
# NOTE(review): "timestamps"/"timestamp" are assumed fallback names for callers
# that do not send "date" -- confirm against the actual client payload.
TIMESTAMP_COLS = ("date", "timestamps", "timestamp")
# Number of future candles to forecast.
PRED_STEPS = 3
# Spacing between forecast candles (the service assumes 5-minute candles).
CANDLE_INTERVAL = timedelta(minutes=5)
# Sampling settings; on a slow CPU, lowering SAMPLE_COUNT to 3 or 5 helps.
TEMPERATURE = 0.8
TOP_P = 0.9
SAMPLE_COUNT = 10

# Load the model globally (once, at startup).
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on {device}...")
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
# If the free Hugging Face CPU tier cannot run the base model, switch back to
# the small one.
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
model = model.to(device)
model.eval()
predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
print("Model loaded!")


class CandleData(BaseModel):
    """Request body for ``/predict``."""

    # Trading pair identifier; accepted but not used by the handler.
    pair: str
    # One dict per candle with open, high, low, close, volume and a time
    # column (see TIMESTAMP_COLS).
    data: list[dict]


def _build_frame(rows: list[dict]) -> pd.DataFrame:
    """Turn raw candle dicts into a DataFrame with a ``timestamps`` column.

    Raises:
        ValueError: if ``rows`` is empty, a required price/volume column is
            missing, no time column is present, or the time values cannot be
            parsed by ``pd.to_datetime``.
    """
    if not rows:
        raise ValueError("data must contain at least one candle")

    df = pd.DataFrame(rows)

    missing = [col for col in INPUT_COLS if col not in df.columns]
    if missing:
        raise ValueError(f"data is missing required columns: {missing}")

    # The original fallback branch re-read 'date' even when that column was
    # absent, which always raised KeyError; pick whichever time column exists.
    time_col = next((col for col in TIMESTAMP_COLS if col in df.columns), None)
    if time_col is None:
        raise ValueError(
            f"data must include one of the time columns: {list(TIMESTAMP_COLS)}"
        )
    df["timestamps"] = pd.to_datetime(df[time_col])
    return df


@app.post("/predict")
def predict(payload: CandleData):
    """Forecast the next candles and return the mean predicted close.

    Args:
        payload: trading pair plus the historical candles to condition on.

    Returns:
        ``{"prediction": <float>}`` where the value is the mean of the
        ``close`` column over the ``PRED_STEPS`` forecast candles.

    Raises:
        HTTPException: 400 when the payload is empty or malformed, 500 when
            inference itself fails.
    """
    # 1. Rebuild the DataFrame; input problems are the caller's fault (400).
    try:
        df = _build_frame(payload.data)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    try:
        # 2. Prepare timestamps: history plus the future steps to predict.
        x_timestamp = df["timestamps"]
        last_time = x_timestamp.iloc[-1]
        y_timestamp = pd.Series(
            [last_time + CANDLE_INTERVAL * (i + 1) for i in range(PRED_STEPS)]
        )

        # 3. Inference; gradients are not needed.
        with torch.no_grad():
            forecast = predictor.predict(
                df=df[INPUT_COLS],
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=PRED_STEPS,
                T=TEMPERATURE,
                top_p=TOP_P,
                sample_count=SAMPLE_COUNT,
            )

        # 4. Return the result.
        raw_pred = forecast["close"].mean()
        return {"prediction": float(raw_pred)}
    except Exception as e:
        # Top-level boundary: log with traceback, then report as a 500.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e