Taylor1998 commited on
Commit
c81d774
·
verified ·
1 Parent(s): bb0bd37

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import timesfm
4
+
5
+ app = FastAPI()
6
+
7
+ # 1. 初始化并加载 TimesFM 模型 (Space 启动时执行)
8
+ tfm = timesfm.TimesFm(
9
+ context_len=512, # 根据你的 K 线序列长度调整
10
+ horizon_len=24, # 预测步长
11
+ input_patch_len=32,
12
+ output_patch_len=128,
13
+ num_layers=20,
14
+ model_dims=1280,
15
+ backend="cpu" # 免费层使用 CPU
16
+ )
17
+ tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")
18
+
19
+ # 2. 定义 API 数据结构
20
+ class PredictRequest(BaseModel):
21
+ history: list[float]
22
+ horizon: int = 24
23
+
24
+ # 3. 提供预测接口
25
+ @app.post("/predict")
26
+ def predict(req: PredictRequest):
27
+ # 将前端传来的 K 线数据喂给模型
28
+ forecast = tfm.forecast([req.history])
29
+
30
+ return {
31
+ "status": "success",
32
+ "forecast": forecast[0].tolist()
33
+ }