File size: 2,292 Bytes
5f3989e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import pandas as pd
import torch
import sys
import os
from datetime import timedelta
# Make the vendored Kronos checkout (./Kronos, next to this file) importable.
# It is inserted at the FRONT of sys.path: the package is named generically
# (`model`), and with append() any other `model` module already reachable on
# the path would be imported instead.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "Kronos"))
from model import Kronos, KronosTokenizer, KronosPredictor

app = FastAPI()

# Load the tokenizer and model once, at import time, so every request reuses
# the same predictor instead of reloading weights.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on {device}...")
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
# If the free Hugging Face CPU tier cannot run the base model, switch back to
# the small variant.
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
model = model.to(device)
model.eval()  # switch the module to inference mode
predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
print("Model loaded!")
class CandleData(BaseModel):
    # Request body accepted by POST /predict.
    # presumably a trading-pair symbol; it is not read by predict() in this file.
    pair: str
    # One dict per candle; expected keys: open, high, low, close, volume, date.
    data: list[dict]
# Price/volume columns fed to Kronos, in this order.
_INPUT_COLS = ["open", "high", "low", "close", "volume"]
# Number of future candles forecast per request.
_PRED_STEPS = 3
# Candle spacing assumed when it cannot be inferred from the payload
# (fewer than two candles, or timestamps that do not increase).
_DEFAULT_INTERVAL = timedelta(minutes=5)


def _build_frame(rows: list[dict]) -> pd.DataFrame:
    """Build the input DataFrame from the request's candle dicts.

    Adds a ``timestamps`` column parsed from each candle's ``date`` field.

    Raises:
        ValueError: ``rows`` is empty, a required field is missing, or
            ``pd.to_datetime`` cannot parse a ``date`` value.
    """
    if not rows:
        raise ValueError("data must contain at least one candle")
    df = pd.DataFrame(rows)
    missing = [col for col in ("date", *_INPUT_COLS) if col not in df.columns]
    if missing:
        raise ValueError("data is missing required field(s): " + ", ".join(missing))
    df["timestamps"] = pd.to_datetime(df["date"])
    return df


def _future_timestamps(x_timestamp: pd.Series, steps: int) -> pd.Series:
    """Return ``steps`` timestamps continuing on from the last observed candle.

    The spacing is the median gap between consecutive observed timestamps, so
    the forecast horizon follows the timeframe actually sent (a fixed 5-minute
    step would mislabel e.g. hourly candles). Falls back to
    ``_DEFAULT_INTERVAL`` when no positive spacing can be inferred.
    """
    interval = x_timestamp.diff().median()
    if pd.isna(interval) or interval <= pd.Timedelta(0):
        interval = _DEFAULT_INTERVAL
    last_time = x_timestamp.iloc[-1]
    return pd.Series([last_time + interval * (i + 1) for i in range(steps)])


@app.post("/predict")
def predict(payload: CandleData):
    """Forecast the next ``_PRED_STEPS`` candles and return the mean predicted close.

    Returns:
        ``{"prediction": <float>}``: the mean of the forecast ``close`` values
        over the whole horizon.

    Raises:
        HTTPException: 422 if the candle payload is empty, lacks a required
            field, or has unparseable dates; 500 if forecasting fails.
    """
    # 1. Rebuild and validate the DataFrame. Bad input is the caller's fault,
    #    so report it as 422 instead of letting it surface as a generic 500.
    try:
        df = _build_frame(payload.data)
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e

    try:
        # 2. Timestamps of the observed candles and of the forecast horizon.
        x_timestamp = df["timestamps"]
        y_timestamp = _future_timestamps(x_timestamp, _PRED_STEPS)
        # 3. Inference (no autograd bookkeeping needed).
        with torch.no_grad():
            forecast = predictor.predict(
                df=df[_INPUT_COLS],
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=_PRED_STEPS,
                T=0.8,
                top_p=0.9,
                sample_count=10,  # lower to 3 or 5 if CPU inference is too slow
            )
        # 4. Collapse the forecast horizon into a single number.
        raw_pred = forecast["close"].mean()
        return {"prediction": float(raw_pred)}
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e