# kronos / app.py
# ariesljm's picture
# Create app.py
# 5f3989e verified
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import pandas as pd
import torch
import sys
import os
from datetime import timedelta
# Import Kronos (the bundled "Kronos" directory is added to the import path first)
sys.path.append(os.path.join(os.path.dirname(__file__), "Kronos"))
from model import Kronos, KronosTokenizer, KronosPredictor
# FastAPI application instance; the /predict route below is registered on it.
app = FastAPI()
# Load the model globally (loaded once, when the module is imported at startup).
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on {device}...")
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
# If the free-tier Hugging Face CPU cannot run the Base model, switch back to small.
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
model = model.to(device)
# Presumably a torch.nn.Module, in which case eval() selects inference mode - TODO confirm.
model.eval()
# Shared predictor reused by every request; built with max_context=512.
predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
print("Model loaded!")
class CandleData(BaseModel):
    # Format of the data received in the request body of the /predict endpoint.
    # NOTE(review): `pair` is accepted but not read by the visible handler code.
    pair: str
    data: list[dict]  # list of records containing open, high, low, close, volume, date
@app.post("/predict")
def predict(payload: CandleData):
    """Forecast upcoming candles and return the mean predicted close.

    The submitted records are rebuilt into a DataFrame, the time field is
    parsed into a ``timestamps`` column, and the module-level ``predictor``
    is asked for ``pred_steps`` future steps.

    Args:
        payload: Request body. ``payload.data`` holds one dict per candle
            with ``open``, ``high``, ``low``, ``close`` and ``volume`` values
            plus a time field named ``date``. ``timestamps`` and
            ``timestamp`` are accepted as fallbacks when ``date`` is absent.

    Returns:
        ``{"prediction": <float>}``, where the value is the mean of
        ``forecast["close"]`` over the predicted steps.

    Raises:
        HTTPException: 400 when ``data`` is empty or a required field is
            missing; 500 for any other failure, with the error text as the
            detail.
    """
    input_cols = ["open", "high", "low", "close", "volume"]
    # Accepted names for the time field, in priority order.
    time_cols = ("date", "timestamps", "timestamp")
    pred_steps = 3

    try:
        # 1. Rebuild the DataFrame.
        if not payload.data:
            raise HTTPException(
                status_code=400,
                detail="'data' must contain at least one candle",
            )
        df = pd.DataFrame(payload.data)

        # The original fallback branch re-read df['date'] even though it only
        # ran when that column was missing, so it could never succeed.
        time_col = next((col for col in time_cols if col in df.columns), None)
        if time_col is None:
            raise HTTPException(
                status_code=400,
                detail=f"Each candle needs a time field named one of {list(time_cols)}",
            )
        df["timestamps"] = pd.to_datetime(df[time_col])

        missing = [col for col in input_cols if col not in df.columns]
        if missing:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required candle fields: {missing}",
            )

        # 2. Prepare the timestamps.
        x_timestamp = df["timestamps"]
        last_time = x_timestamp.iloc[-1]
        # NOTE(review): future timestamps are spaced 5 minutes apart, which
        # assumes the submitted candles are 5-minute candles - confirm.
        y_timestamp = pd.Series([
            last_time + timedelta(minutes=5 * step)
            for step in range(1, pred_steps + 1)
        ])

        # 3. Inference (gradients are not needed).
        with torch.no_grad():
            forecast = predictor.predict(
                df=df[input_cols],
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=pred_steps,
                T=0.8,
                top_p=0.9,
                sample_count=10,  # if the CPU is too slow, lower this to 3 or 5
            )

        # 4. Return the result: a single number averaged over the predicted closes.
        raw_pred = forecast["close"].mean()
        return {"prediction": float(raw_pred)}
    except HTTPException:
        # Deliberate client errors raised above must keep their own status
        # code instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e