Spaces:
Running
Running
| """ | |
| Chronos-2 Zero-Shot Demo - FastAPI Backend | |
| Standalone version for HF Spaces deployment. | |
| Run locally: uvicorn server:app --reload --port 7860 | |
| """ | |
from __future__ import annotations

import math
import threading
from pathlib import Path

import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
| # Chronos-2 imports | |
| try: | |
| from chronos import Chronos2Pipeline | |
| except ImportError: | |
| raise ImportError( | |
| "Please install chronos-forecasting>=2.0: pip install 'chronos-forecasting[scripts]>=2.0'" | |
| ) | |
# Directory containing this file; index.html and other assets live in its
# "static" subfolder.
DEMO_DIR = Path(__file__).resolve().parent
STATIC_DIR = DEMO_DIR.joinpath("static")

# Model configuration: model id handed to Chronos2Pipeline.from_pretrained,
# and the device used as its device_map (GPU when torch can see one).
MODEL_NAME = "amazon/chronos-2"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
| # ============================================================================= | |
| # Chronos-2 Forecaster (standalone) | |
| # ============================================================================= | |
class Chronos2Forecaster:
    """Lazy-loading wrapper around a Chronos-2 ``Chronos2Pipeline``.

    The pipeline is not created in ``__init__``; it is built by
    :meth:`load_model`, which :meth:`forecast` calls on first use.
    """

    # Quantile levels whose columns ``forecast`` always reads to build the
    # "low" / "median" / "high" arrays; they are always requested.
    _BAND_LEVELS = (0.1, 0.5, 0.9)

    def __init__(self, model_name: str = MODEL_NAME, device: str = DEVICE):
        """Store the model id and device; no model is loaded yet.

        Args:
            model_name: Identifier passed to ``Chronos2Pipeline.from_pretrained``.
            device: Value passed as ``device_map`` (e.g. ``"cuda"`` or ``"cpu"``).
        """
        self.model_name = model_name
        self.device = device
        self.pipeline = None

    def load_model(self) -> None:
        """Load the Chronos-2 pipeline for ``self.model_name`` onto ``self.device``."""
        print(f"Loading Chronos-2 model: {self.model_name}")
        print(f"Device: {self.device}")
        self.pipeline = Chronos2Pipeline.from_pretrained(
            self.model_name,
            device_map=self.device,
        )
        print("Model loaded successfully!")

    def forecast(
        self,
        context_df: pd.DataFrame,
        prediction_length: int = 12,
        quantile_levels: list[float] | None = None,
    ) -> dict:
        """Generate probabilistic forecasts for ``context_df``.

        Loads the model first if it has not been loaded yet.

        Args:
            context_df: Long-format context with ``item_id``, ``timestamp`` and
                ``target`` columns (the shape ``to_chronos2_context`` builds).
            prediction_length: Number of future steps to predict.
            quantile_levels: Quantile levels to request. Defaults to
                ``[0.1, 0.5, 0.9]``. Those three levels are always added,
                because the returned arrays are read from them.

        Returns:
            dict with ``"median"`` / ``"low"`` / ``"high"`` arrays taken from
            the ``"0.5"`` / ``"0.1"`` / ``"0.9"`` columns of the prediction
            frame, plus ``"pred_df"``, the full frame from ``predict_df``
            (including any extra requested quantile columns).
        """
        if self.pipeline is None:
            self.load_model()
        # The column lookups below are fixed to "0.1"/"0.5"/"0.9". A
        # caller-supplied set lacking them used to raise KeyError, so always
        # merge them in. With the default (None) this is exactly
        # [0.1, 0.5, 0.9].
        levels = sorted({*(quantile_levels or ()), *self._BAND_LEVELS})
        pred_df = self.pipeline.predict_df(
            context_df,
            prediction_length=prediction_length,
            quantile_levels=levels,
            id_column="item_id",
            timestamp_column="timestamp",
            target="target",
        )
        return {
            "median": pred_df["0.5"].to_numpy(),
            "low": pred_df["0.1"].to_numpy(),
            "high": pred_df["0.9"].to_numpy(),
            "pred_df": pred_df,
        }
def to_chronos2_context(
    df: pd.DataFrame,
    target_col: str = "sale_qty",
    item_id: str = "gfk_sales",
) -> pd.DataFrame:
    """Reshape ``df`` into the long format consumed by ``Chronos2Forecaster``.

    Keeps only the ``period`` column and ``target_col``, renames them to
    ``timestamp`` and ``target``, and tags every row with ``item_id``. The
    result has columns ``item_id``, ``timestamp``, ``target`` in that order.
    ``df`` itself is not modified.
    """
    long_df = df[["period", target_col]].rename(
        columns={"period": "timestamp", target_col: "target"}
    )
    return long_df.assign(item_id=item_id)[["item_id", "timestamp", "target"]]
| # ============================================================================= | |
| # FastAPI App | |
| # ============================================================================= | |
_forecaster: Chronos2Forecaster | None = None
# Serializes first-time model loading across concurrent request threads.
_forecaster_lock = threading.Lock()


def get_forecaster() -> Chronos2Forecaster:
    """Return the shared forecaster, loading the model on the first call.

    FastAPI runs plain ``def`` endpoints in a thread pool, so two first
    requests can arrive at once. The lock (with a re-check inside it) makes
    sure the model is loaded only once. The global is assigned only after
    ``load_model`` succeeds, so a failed load is retried on the next call.
    """
    global _forecaster
    if _forecaster is None:
        with _forecaster_lock:
            # Re-check: another thread may have finished loading while we waited.
            if _forecaster is None:
                forecaster = Chronos2Forecaster()
                forecaster.load_model()
                _forecaster = forecaster
    return _forecaster
class ForecastRequest(BaseModel):
    # JSON body accepted by the forecast endpoint.
    # values: the observed series, earliest value first; required (no default).
    values: list[float] = Field(description="Time series values")
    # prediction_length: forecast horizon, validated to 1..24 inclusive.
    prediction_length: int = Field(
        default=1,
        ge=1,
        le=24,
        description="Steps to forecast",
    )
class ForecastPoint(BaseModel):
    # One forecast step in the forecast endpoint's response.
    # index: position on the same axis as the input values; the endpoint sets
    #   it to len(values) + step, so it continues after the last historical index.
    index: int
    # median / low / high: taken from the 0.5 / 0.1 / 0.9 quantile forecasts.
    median: float
    low: float
    high: float
class ForecastResponse(BaseModel):
    # Response body of the forecast endpoint.
    # historical: the input series echoed back as {"index": i, "value": v}
    #   dicts (the endpoint builds that shape; this annotation does not enforce it).
    historical: list[dict]
    # forecast: one entry per predicted step, in order.
    forecast: list[ForecastPoint]
def values_to_context_df(values: list[float]) -> pd.DataFrame:
    """Attach synthetic month-start dates to a bare list of values.

    Value ``i`` is dated ``i`` months after 2020-01-01.

    Returns:
        DataFrame with columns ``period`` (dates) and ``sale_qty`` (the values).

    Raises:
        ValueError: if ``values`` is empty.
    """
    if values:
        return pd.DataFrame(
            {
                "period": pd.date_range(start="2020-01-01", periods=len(values), freq="MS"),
                "sale_qty": values,
            }
        )
    raise ValueError("values cannot be empty")
app = FastAPI(title="Chronos-2 Zero-Shot Demo", version="1.0.0")

# Any origin may call the API. Credentials stay disabled: the endpoints shown
# here use no cookies or auth. FastAPI's CORS docs also state that credentials
# cannot be combined with "*" for origins/methods/headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
def forecast(req: ForecastRequest) -> ForecastResponse:
    """Forecast ``req.prediction_length`` steps past ``req.values`` with Chronos-2.

    The values get synthetic monthly timestamps, are converted to the Chronos-2
    long format under item id ``"ts1"``, and are passed to the shared forecaster.

    Returns:
        ForecastResponse containing:
        - ``historical``: the input as ``{"index": i, "value": v}`` points,
          indices ``0..n-1``.
        - ``forecast``: one ``ForecastPoint`` per step, indices
          ``n..n+prediction_length-1``.

    Raises:
        HTTPException(400): if ``values`` is empty, contains NaN/inf, or cannot
            be turned into a dated series.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible for
    # this function, so as it appears here it is not registered with ``app``.
    # Confirm the intended path and restore the decorator.
    values = req.values
    if not values:
        raise HTTPException(status_code=400, detail="values cannot be empty")
    # NaN/inf can pass request validation. Echoed back in ``historical``, they
    # would make the JSON response fail to serialize (a 500), so reject them.
    if not all(math.isfinite(v) for v in values):
        raise HTTPException(status_code=400, detail="values must be finite numbers")
    try:
        df = values_to_context_df(values)
    except (ValueError, OverflowError) as e:
        # e.g. a series long enough to run past pandas' supported date range.
        raise HTTPException(status_code=400, detail=str(e)) from e

    context_df = to_chronos2_context(df, target_col="sale_qty", item_id="ts1")
    forecaster = get_forecaster()
    result = forecaster.forecast(context_df=context_df, prediction_length=req.prediction_length)

    n_hist = len(values)
    historical = [{"index": i, "value": float(v)} for i, v in enumerate(values)]
    forecast_points = [
        ForecastPoint(
            index=n_hist + step,
            median=float(result["median"][step]),
            low=float(result["low"][step]),
            high=float(result["high"][step]),
        )
        for step in range(req.prediction_length)
    ]
    return ForecastResponse(historical=historical, forecast=forecast_points)
def index():
    """Return ``STATIC_DIR/index.html`` as a file response, or 404 if it is absent."""
    # NOTE(review): no route decorator (e.g. ``@app.get("/")``) is visible for
    # this function, so as it appears here it is not registered with ``app``.
    page = STATIC_DIR / "index.html"
    if page.exists():
        return FileResponse(page)
    raise HTTPException(status_code=404, detail="index.html not found")
# Serve the files under STATIC_DIR at the /static URL prefix.
app.mount(
    path="/static",
    app=StaticFiles(directory=str(STATIC_DIR)),
    name="static",
)