Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI prediction service for the Grid Risk Platform. | |
| Usage | |
| ----- | |
| uvicorn src.api:app --host 0.0.0.0 --port 8000 | |
| Endpoints | |
| POST /predict β score a single outage record | |
| GET /health β liveness check | |
| GET /model-info β model version and feature list | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| from contextlib import asynccontextmanager | |
| from typing import Any, Dict, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, Field | |
| from src.config import ARTIFACTS_DIR, FEATURE_NAMES_FILE, METRICS_FILE, MODEL_VERSION | |
| from src.predict import GridRiskPredictor | |
| logger = logging.getLogger(__name__) | |
| _predictor: Optional[GridRiskPredictor] = None | |
| async def lifespan(app: FastAPI): | |
| global _predictor | |
| try: | |
| _predictor = GridRiskPredictor() | |
| logger.info("Model loaded β version %s", MODEL_VERSION) | |
| except FileNotFoundError as e: | |
| logger.error("Artifacts missing β run training pipeline first: %s", e) | |
| yield | |
| _predictor = None | |
| app = FastAPI( | |
| title="Grid Risk & Reliability API", | |
| version=MODEL_VERSION, | |
| description="Predict probability of high-impact power outage events.", | |
| lifespan=lifespan, | |
| ) | |
| # ββ Request / Response schemas ββββββββββββββββββββββββββββββββββββββββββββ | |
| class OutageInput(BaseModel): | |
| """Input features for a single outage risk prediction. | |
| All fields are optional; the preprocessor handles missing values. | |
| Provide as many as available for best accuracy. | |
| """ | |
| ANOMALY_LEVEL: Optional[float] = Field(None, description="Oceanic NiΓ±o Index anomaly level") | |
| DEMAND_LOSS_MW: Optional[float] = Field(None, ge=0, description="Demand loss in MW") | |
| RES_PRICE: Optional[float] = Field(None, ge=0, description="Residential electricity price (cents/kWh)") | |
| COM_PRICE: Optional[float] = Field(None, ge=0, description="Commercial electricity price") | |
| IND_PRICE: Optional[float] = Field(None, ge=0, description="Industrial electricity price") | |
| TOTAL_PRICE: Optional[float] = Field(None, ge=0, description="Average total electricity price") | |
| TOTAL_SALES: Optional[float] = Field(None, ge=0, description="Total electricity sales (MWh)") | |
| TOTAL_CUSTOMERS: Optional[float] = Field(None, ge=0, description="Total utility customers") | |
| PC_REALGSP_STATE: Optional[float] = Field(None, description="Per capita real GSP of the state") | |
| PC_REALGSP_REL: Optional[float] = Field(None, description="Relative per capita real GSP") | |
| PC_REALGSP_CHANGE: Optional[float] = Field(None, description="% change in per capita real GSP") | |
| UTIL_REALGSP: Optional[float] = Field(None, description="Real GSP contributed by utility sector") | |
| UTIL_CONTRI: Optional[float] = Field(None, description="Utility sector contribution (%)") | |
| POPULATION: Optional[float] = Field(None, ge=0, description="State population") | |
| POPPCT_URBAN: Optional[float] = Field(None, ge=0, le=100, description="Urban population %") | |
| POPDEN_URBAN: Optional[float] = Field(None, ge=0, description="Urban population density") | |
| POPDEN_RURAL: Optional[float] = Field(None, ge=0, description="Rural population density") | |
| AREAPCT_URBAN: Optional[float] = Field(None, ge=0, le=100, description="Urban area %") | |
| PCT_LAND: Optional[float] = Field(None, ge=0, le=100, description="Land area %") | |
| PCT_WATER_TOT: Optional[float] = Field(None, ge=0, le=100, description="Water area %") | |
| CLIMATE_REGION: Optional[str] = Field(None, description="U.S. climate region") | |
| CLIMATE_CATEGORY: Optional[str] = Field(None, description="Climate episode category (warm/cold/normal)") | |
| CAUSE_CATEGORY: Optional[str] = Field(None, description="Outage cause category") | |
| NERC_REGION: Optional[str] = Field(None, description="NERC reliability region") | |
| MONTH: Optional[int] = Field(None, ge=1, le=12, description="Month of event (1-12)") | |
| RES_SALES: Optional[float] = Field(None, ge=0, description="Residential electricity sales (MWh)") | |
| class Config: | |
| json_schema_extra = { | |
| "example": { | |
| "ANOMALY_LEVEL": -0.3, | |
| "DEMAND_LOSS_MW": 250.0, | |
| "RES_PRICE": 11.6, | |
| "COM_PRICE": 9.5, | |
| "IND_PRICE": 6.7, | |
| "TOTAL_PRICE": 9.3, | |
| "TOTAL_CUSTOMERS": 2500000, | |
| "POPULATION": 5800000, | |
| "POPPCT_URBAN": 73.0, | |
| "CLIMATE_REGION": "East North Central", | |
| "CAUSE_CATEGORY": "severe weather", | |
| "NERC_REGION": "RFC", | |
| "MONTH": 7, | |
| } | |
| } | |
| class PredictionResponse(BaseModel): | |
| probability: float = Field(description="P(high-impact outage)") | |
| prediction: int = Field(description="Binary label (0 or 1)") | |
| risk_tier: str = Field(description="LOW / MODERATE / HIGH / CRITICAL") | |
| model_version: str | |
| # ββ Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _input_to_record(inp: OutageInput) -> Dict[str, Any]: | |
| """Convert Pydantic model to the dict format expected by the predictor.""" | |
| raw = inp.model_dump(exclude_none=True) | |
| # Map underscored API field names back to dotted dataset column names | |
| mapped: Dict[str, Any] = {} | |
| for key, val in raw.items(): | |
| col_name = key.replace("_", ".") | |
| mapped[col_name] = val | |
| return mapped | |
| # ββ Endpoints βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def health() -> Dict[str, str]: | |
| status = "ok" if _predictor is not None else "model_not_loaded" | |
| return {"status": status, "version": MODEL_VERSION} | |
| async def model_info() -> Dict[str, Any]: | |
| if _predictor is None: | |
| raise HTTPException(503, "Model not loaded β run training pipeline first.") | |
| metrics_path = ARTIFACTS_DIR / METRICS_FILE | |
| metrics = {} | |
| if metrics_path.exists(): | |
| with open(metrics_path) as f: | |
| metrics = json.load(f) | |
| return { | |
| "version": MODEL_VERSION, | |
| "features": _predictor.feature_names, | |
| "metrics": metrics, | |
| } | |
| async def predict(inp: OutageInput) -> PredictionResponse: | |
| if _predictor is None: | |
| raise HTTPException(503, "Model not loaded β run training pipeline first.") | |
| try: | |
| record = _input_to_record(inp) | |
| result = _predictor.predict_single(record) | |
| except Exception as e: | |
| logger.exception("Prediction failed") | |
| raise HTTPException(422, f"Prediction error: {e}") | |
| return PredictionResponse( | |
| probability=round(result["probability"], 4), | |
| prediction=result["prediction"], | |
| risk_tier=result["risk_tier"], | |
| model_version=MODEL_VERSION, | |
| ) | |