"""
Chronos-2 Zero-Shot Demo - FastAPI Backend

Standalone version for HF Spaces deployment.

Run locally:
    uvicorn server:app --reload --port 7860
"""

from __future__ import annotations

import threading
from pathlib import Path

import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field

# Chronos-2 imports
try:
    from chronos import Chronos2Pipeline
except ImportError as err:
    # Chain the original error so the real cause (e.g. a missing transitive
    # dependency rather than chronos itself) is not hidden.
    raise ImportError(
        "Please install chronos-forecasting>=2.0: pip install 'chronos-forecasting[scripts]>=2.0'"
    ) from err

DEMO_DIR = Path(__file__).resolve().parent
STATIC_DIR = DEMO_DIR / "static"

# Model configuration
MODEL_NAME = "amazon/chronos-2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Quantile levels whose prediction columns Chronos2Forecaster.forecast always
# reads to build the median / low / high band.
_BAND_QUANTILES = (0.1, 0.5, 0.9)


# =============================================================================
# Chronos-2 Forecaster (standalone)
# =============================================================================


class Chronos2Forecaster:
    """Wrapper for Chronos-2 time series forecasting.

    The pipeline is loaded lazily: either explicitly via ``load_model`` or on
    the first call to ``forecast``.
    """

    def __init__(self, model_name: str = MODEL_NAME, device: str = DEVICE):
        self.model_name = model_name
        self.device = device
        self.pipeline = None  # set by load_model()

    def load_model(self) -> None:
        """Load the Chronos-2 model pipeline onto ``self.device``."""
        print(f"Loading Chronos-2 model: {self.model_name}")
        print(f"Device: {self.device}")
        self.pipeline = Chronos2Pipeline.from_pretrained(
            self.model_name,
            device_map=self.device,
        )
        print("Model loaded successfully!")

    def forecast(
        self,
        context_df: pd.DataFrame,
        prediction_length: int = 12,
        quantile_levels: list[float] | None = None,
    ) -> dict:
        """Generate probabilistic forecasts.

        Args:
            context_df: Long-format context with ``item_id``, ``timestamp``
                and ``target`` columns (as built by ``to_chronos2_context``).
            prediction_length: Number of future steps to forecast.
            quantile_levels: Quantile levels to request. Defaults to
                ``[0.1, 0.5, 0.9]``. The 0.1 / 0.5 / 0.9 levels are always
                requested in addition to these, because the ``low`` /
                ``median`` / ``high`` entries of the result are read from them.

        Returns:
            Dict with ``median``, ``low`` and ``high`` value arrays plus the
            raw ``pred_df`` returned by ``predict_df`` (which also carries any
            extra quantile columns that were requested).
        """
        if self.pipeline is None:
            self.load_model()
        if quantile_levels is None:
            quantile_levels = [0.1, 0.5, 0.9]
        # Keep the caller's order, but make sure the band quantiles are
        # present; otherwise the column lookups below would raise KeyError
        # for a custom quantile list.
        levels = list(quantile_levels) + [
            q for q in _BAND_QUANTILES if q not in quantile_levels
        ]
        pred_df = self.pipeline.predict_df(
            context_df,
            prediction_length=prediction_length,
            quantile_levels=levels,
            id_column="item_id",
            timestamp_column="timestamp",
            target="target",
        )
        return {
            "median": pred_df["0.5"].values,
            "low": pred_df["0.1"].values,
            "high": pred_df["0.9"].values,
            "pred_df": pred_df,
        }


def to_chronos2_context(
    df: pd.DataFrame,
    target_col: str = "sale_qty",
    item_id: str = "gfk_sales",
) -> pd.DataFrame:
    """Convert a DataFrame to Chronos-2 long-format context.

    Args:
        df: Frame with a ``period`` column and a ``target_col`` column.
        target_col: Name of the column holding the series values.
        item_id: Constant series identifier written to every row.

    Returns:
        A new frame with columns ``item_id``, ``timestamp``, ``target``.
    """
    context = df[["period", target_col]].copy()
    context = context.rename(columns={"period": "timestamp", target_col: "target"})
    context["item_id"] = item_id
    return context[["item_id", "timestamp", "target"]]


# =============================================================================
# FastAPI App
# =============================================================================

_forecaster: Chronos2Forecaster | None = None
# FastAPI runs plain ``def`` endpoints in a worker thread pool, so several
# first requests can race into get_forecaster(); the lock ensures the model
# is loaded only once.
_forecaster_lock = threading.Lock()


def get_forecaster() -> Chronos2Forecaster:
    """Return the process-wide forecaster, loading the model on first use."""
    global _forecaster
    if _forecaster is None:
        with _forecaster_lock:
            # Re-check: another thread may have finished loading while we
            # waited for the lock.
            if _forecaster is None:
                forecaster = Chronos2Forecaster()
                forecaster.load_model()
                # Publish only a fully loaded instance so other threads never
                # observe a forecaster whose pipeline is still None, and a
                # failed load is simply retried on the next request.
                _forecaster = forecaster
    return _forecaster


class ForecastRequest(BaseModel):
    values: list[float] = Field(..., description="Time series values")
    prediction_length: int = Field(1, ge=1, le=24, description="Steps to forecast")


class ForecastPoint(BaseModel):
    index: int
    median: float
    low: float
    high: float


class ForecastResponse(BaseModel):
    historical: list[dict]
    forecast: list[ForecastPoint]


def values_to_context_df(values: list[float]) -> pd.DataFrame:
    """Wrap raw values in a frame with synthetic month-start timestamps.

    The timestamps (starting 2020-01-01, ``freq="MS"``) are placeholders;
    only the order of ``values`` carries information.

    Raises:
        ValueError: If ``values`` is empty.
    """
    if not values:
        raise ValueError("values cannot be empty")
    n = len(values)
    periods = pd.date_range(start="2020-01-01", periods=n, freq="MS")
    df = pd.DataFrame({"period": periods, "sale_qty": values})
    return df


app = FastAPI(title="Chronos-2 Zero-Shot Demo", version="1.0.0")

# NOTE(review): the CORS spec forbids a wildcard origin on credentialed
# requests. The bundled frontend is served from this same app, so credentials
# are presumably unnecessary; consider allow_credentials=False or an explicit
# origin list — confirm no cross-origin client relies on cookies first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/api/forecast", response_model=ForecastResponse)
def forecast(req: ForecastRequest) -> ForecastResponse:
    """Forecast ``req.prediction_length`` steps beyond ``req.values``.

    Historical points are indexed ``0..len(values)-1`` and forecast points
    continue from ``len(values)``.

    Raises:
        HTTPException: 400 when the input series is empty or cannot be
            converted into a context frame.
    """
    if not req.values:
        raise HTTPException(status_code=400, detail="values cannot be empty")
    try:
        df = values_to_context_df(req.values)
    except Exception as e:
        # API boundary: report input-conversion failures as client errors,
        # keeping the original exception chained for server-side debugging.
        raise HTTPException(status_code=400, detail=str(e)) from e

    context_df = to_chronos2_context(df, target_col="sale_qty", item_id="ts1")
    forecaster = get_forecaster()
    result = forecaster.forecast(context_df=context_df, prediction_length=req.prediction_length)

    historical = [{"index": i, "value": float(v)} for i, v in enumerate(req.values)]
    start = len(req.values)
    forecast_points = [
        ForecastPoint(
            index=start + i,
            median=float(result["median"][i]),
            low=float(result["low"][i]),
            high=float(result["high"][i]),
        )
        for i in range(req.prediction_length)
    ]
    return ForecastResponse(historical=historical, forecast=forecast_points)


@app.get("/")
def index():
    """Serve the demo's static ``index.html`` (404 if it is missing)."""
    index_path = STATIC_DIR / "index.html"
    if not index_path.exists():
        raise HTTPException(status_code=404, detail="index.html not found")
    return FileResponse(index_path)


app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")