ariesljm commited on
Commit
5f3989e
·
verified ·
1 Parent(s): c281567

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import pandas as pd
4
+ import torch
5
+ import sys
6
+ import os
7
+ from datetime import timedelta
8
+
9
+ # 引入 Kronos
10
+ sys.path.append(os.path.join(os.path.dirname(__file__), "Kronos"))
11
+ from model import Kronos, KronosTokenizer, KronosPredictor
12
+
13
+ app = FastAPI()
14
+
15
+ # 全局加载模型 (启动时加载一次)
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ print(f"Loading model on {device}...")
18
+
19
+ tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
20
+ # 如果 HF 免费版 CPU 跑不动 Base,可以改回 small
21
+ model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
22
+ model = model.to(device)
23
+ model.eval()
24
+
25
+ predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
26
+ print("Model loaded!")
27
+
28
+ class CandleData(BaseModel):
29
+ # 接收的数据格式
30
+ pair: str
31
+ data: list[dict] # 包含 open, high, low, close, volume, date 的列表
32
+
33
+ @app.post("/predict")
34
+ def predict(payload: CandleData):
35
+ try:
36
+ # 1. 重建 DataFrame
37
+ df = pd.DataFrame(payload.data)
38
+ if 'date' in df.columns:
39
+ df['timestamps'] = pd.to_datetime(df['date'])
40
+ else:
41
+ # 如果传来的是时间戳
42
+ df['timestamps'] = pd.to_datetime(df['date'])
43
+
44
+ input_cols = ["open", "high", "low", "close", "volume"]
45
+
46
+ # 2. 准备时间戳
47
+ x_timestamp = df["timestamps"]
48
+ pred_steps = 3
49
+ last_time = x_timestamp.iloc[-1]
50
+ y_timestamp = pd.Series([
51
+ last_time + timedelta(minutes=5 * (i + 1))
52
+ for i in range(pred_steps)
53
+ ])
54
+
55
+ # 3. 推理
56
+ with torch.no_grad():
57
+ forecast = predictor.predict(
58
+ df=df[input_cols],
59
+ x_timestamp=x_timestamp,
60
+ y_timestamp=y_timestamp,
61
+ pred_len=pred_steps,
62
+ T=0.8,
63
+ top_p=0.9,
64
+ sample_count=10 # CPU如果太慢,可以把这里改为 3 或 5
65
+ )
66
+
67
+ # 4. 返回结果
68
+ raw_pred = forecast["close"].mean()
69
+ return {"prediction": float(raw_pred)}
70
+
71
+ except Exception as e:
72
+ print(f"Error: {e}")
73
+ raise HTTPException(status_code=500, detail=str(e))