jamnif commited on
Commit
826dae2
·
verified ·
1 Parent(s): 5f6fad5

Upload server.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server.py +184 -0
server.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chronos-2 Zero-Shot Demo - FastAPI Backend
3
+
4
+ Standalone version for HF Spaces deployment.
5
+ Run locally: uvicorn server:app --reload --port 7860
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from pathlib import Path
11
+
12
+ import pandas as pd
13
+ import torch
14
+
15
+ from fastapi import FastAPI, HTTPException
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from fastapi.staticfiles import StaticFiles
18
+ from fastapi.responses import FileResponse
19
+ from pydantic import BaseModel, Field
20
+
21
+ # Chronos-2 imports
22
+ try:
23
+ from chronos import Chronos2Pipeline
24
+ except ImportError:
25
+ raise ImportError(
26
+ "Please install chronos-forecasting>=2.0: pip install 'chronos-forecasting[scripts]>=2.0'"
27
+ )
28
+
29
+ DEMO_DIR = Path(__file__).resolve().parent
30
+ STATIC_DIR = DEMO_DIR / "static"
31
+
32
+ # Model configuration
33
+ MODEL_NAME = "amazon/chronos-2"
34
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
+
36
+
37
+ # =============================================================================
38
+ # Chronos-2 Forecaster (standalone)
39
+ # =============================================================================
40
+
41
+ class Chronos2Forecaster:
42
+ """Wrapper for Chronos-2 time series forecasting."""
43
+
44
+ def __init__(self, model_name: str = MODEL_NAME, device: str = DEVICE):
45
+ self.model_name = model_name
46
+ self.device = device
47
+ self.pipeline = None
48
+
49
+ def load_model(self) -> None:
50
+ """Load the Chronos-2 model pipeline."""
51
+ print(f"Loading Chronos-2 model: {self.model_name}")
52
+ print(f"Device: {self.device}")
53
+ self.pipeline = Chronos2Pipeline.from_pretrained(
54
+ self.model_name,
55
+ device_map=self.device,
56
+ )
57
+ print("Model loaded successfully!")
58
+
59
+ def forecast(
60
+ self,
61
+ context_df: pd.DataFrame,
62
+ prediction_length: int = 12,
63
+ quantile_levels: list[float] | None = None,
64
+ ) -> dict:
65
+ """Generate probabilistic forecasts."""
66
+ if self.pipeline is None:
67
+ self.load_model()
68
+
69
+ if quantile_levels is None:
70
+ quantile_levels = [0.1, 0.5, 0.9]
71
+
72
+ pred_df = self.pipeline.predict_df(
73
+ context_df,
74
+ prediction_length=prediction_length,
75
+ quantile_levels=quantile_levels,
76
+ id_column="item_id",
77
+ timestamp_column="timestamp",
78
+ target="target",
79
+ )
80
+
81
+ return {
82
+ "median": pred_df["0.5"].values,
83
+ "low": pred_df["0.1"].values,
84
+ "high": pred_df["0.9"].values,
85
+ "pred_df": pred_df,
86
+ }
87
+
88
+
89
+ def to_chronos2_context(
90
+ df: pd.DataFrame,
91
+ target_col: str = "sale_qty",
92
+ item_id: str = "gfk_sales",
93
+ ) -> pd.DataFrame:
94
+ """Convert DataFrame to Chronos-2 long-format context."""
95
+ context = df[["period", target_col]].copy()
96
+ context = context.rename(columns={"period": "timestamp", target_col: "target"})
97
+ context["item_id"] = item_id
98
+ return context[["item_id", "timestamp", "target"]]
99
+
100
+
101
+ # =============================================================================
102
+ # FastAPI App
103
+ # =============================================================================
104
+
105
+ _forecaster: Chronos2Forecaster | None = None
106
+
107
+
108
+ def get_forecaster() -> Chronos2Forecaster:
109
+ global _forecaster
110
+ if _forecaster is None:
111
+ _forecaster = Chronos2Forecaster()
112
+ _forecaster.load_model()
113
+ return _forecaster
114
+
115
+
116
+ class ForecastRequest(BaseModel):
117
+ values: list[float] = Field(..., description="Time series values")
118
+ prediction_length: int = Field(1, ge=1, le=24, description="Steps to forecast")
119
+
120
+
121
+ class ForecastPoint(BaseModel):
122
+ index: int
123
+ median: float
124
+ low: float
125
+ high: float
126
+
127
+
128
+ class ForecastResponse(BaseModel):
129
+ historical: list[dict]
130
+ forecast: list[ForecastPoint]
131
+
132
+
133
+ def values_to_context_df(values: list[float]) -> pd.DataFrame:
134
+ if not values:
135
+ raise ValueError("values cannot be empty")
136
+ n = len(values)
137
+ periods = pd.date_range(start="2020-01-01", periods=n, freq="MS")
138
+ df = pd.DataFrame({"period": periods, "sale_qty": values})
139
+ return df
140
+
141
+
142
+ app = FastAPI(title="Chronos-2 Zero-Shot Demo", version="1.0.0")
143
+ app.add_middleware(
144
+ CORSMiddleware,
145
+ allow_origins=["*"],
146
+ allow_credentials=True,
147
+ allow_methods=["*"],
148
+ allow_headers=["*"],
149
+ )
150
+
151
+
152
+ @app.post("/api/forecast", response_model=ForecastResponse)
153
+ def forecast(req: ForecastRequest) -> ForecastResponse:
154
+ if not req.values:
155
+ raise HTTPException(status_code=400, detail="values cannot be empty")
156
+ try:
157
+ df = values_to_context_df(req.values)
158
+ except Exception as e:
159
+ raise HTTPException(status_code=400, detail=str(e))
160
+ context_df = to_chronos2_context(df, target_col="sale_qty", item_id="ts1")
161
+ forecaster = get_forecaster()
162
+ result = forecaster.forecast(context_df=context_df, prediction_length=req.prediction_length)
163
+ historical = [{"index": i, "value": float(v)} for i, v in enumerate(req.values)]
164
+ forecast_points = [
165
+ ForecastPoint(
166
+ index=len(req.values) + i,
167
+ median=float(result["median"][i]),
168
+ low=float(result["low"][i]),
169
+ high=float(result["high"][i]),
170
+ )
171
+ for i in range(req.prediction_length)
172
+ ]
173
+ return ForecastResponse(historical=historical, forecast=forecast_points)
174
+
175
+
176
+ @app.get("/")
177
+ def index():
178
+ index_path = STATIC_DIR / "index.html"
179
+ if not index_path.exists():
180
+ raise HTTPException(status_code=404, detail="index.html not found")
181
+ return FileResponse(index_path)
182
+
183
+
184
+ app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")