Nashid-Noor
Initial commit for HF Spaces without binaries
992aa4f
"""
FastAPI prediction service for the Grid Risk Platform.
Usage
-----
uvicorn src.api:app --host 0.0.0.0 --port 8000
Endpoints
POST /predict β€” score a single outage record
GET /health β€” liveness check
GET /model-info β€” model version and feature list
"""
from __future__ import annotations
import json
import logging
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from src.config import ARTIFACTS_DIR, FEATURE_NAMES_FILE, METRICS_FILE, MODEL_VERSION
from src.predict import GridRiskPredictor
logger = logging.getLogger(__name__)
_predictor: Optional[GridRiskPredictor] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global _predictor
try:
_predictor = GridRiskPredictor()
logger.info("Model loaded β€” version %s", MODEL_VERSION)
except FileNotFoundError as e:
logger.error("Artifacts missing β€” run training pipeline first: %s", e)
yield
_predictor = None
app = FastAPI(
title="Grid Risk & Reliability API",
version=MODEL_VERSION,
description="Predict probability of high-impact power outage events.",
lifespan=lifespan,
)
# ── Request / Response schemas ────────────────────────────────────────────
class OutageInput(BaseModel):
"""Input features for a single outage risk prediction.
All fields are optional; the preprocessor handles missing values.
Provide as many as available for best accuracy.
"""
ANOMALY_LEVEL: Optional[float] = Field(None, description="Oceanic NiΓ±o Index anomaly level")
DEMAND_LOSS_MW: Optional[float] = Field(None, ge=0, description="Demand loss in MW")
RES_PRICE: Optional[float] = Field(None, ge=0, description="Residential electricity price (cents/kWh)")
COM_PRICE: Optional[float] = Field(None, ge=0, description="Commercial electricity price")
IND_PRICE: Optional[float] = Field(None, ge=0, description="Industrial electricity price")
TOTAL_PRICE: Optional[float] = Field(None, ge=0, description="Average total electricity price")
TOTAL_SALES: Optional[float] = Field(None, ge=0, description="Total electricity sales (MWh)")
TOTAL_CUSTOMERS: Optional[float] = Field(None, ge=0, description="Total utility customers")
PC_REALGSP_STATE: Optional[float] = Field(None, description="Per capita real GSP of the state")
PC_REALGSP_REL: Optional[float] = Field(None, description="Relative per capita real GSP")
PC_REALGSP_CHANGE: Optional[float] = Field(None, description="% change in per capita real GSP")
UTIL_REALGSP: Optional[float] = Field(None, description="Real GSP contributed by utility sector")
UTIL_CONTRI: Optional[float] = Field(None, description="Utility sector contribution (%)")
POPULATION: Optional[float] = Field(None, ge=0, description="State population")
POPPCT_URBAN: Optional[float] = Field(None, ge=0, le=100, description="Urban population %")
POPDEN_URBAN: Optional[float] = Field(None, ge=0, description="Urban population density")
POPDEN_RURAL: Optional[float] = Field(None, ge=0, description="Rural population density")
AREAPCT_URBAN: Optional[float] = Field(None, ge=0, le=100, description="Urban area %")
PCT_LAND: Optional[float] = Field(None, ge=0, le=100, description="Land area %")
PCT_WATER_TOT: Optional[float] = Field(None, ge=0, le=100, description="Water area %")
CLIMATE_REGION: Optional[str] = Field(None, description="U.S. climate region")
CLIMATE_CATEGORY: Optional[str] = Field(None, description="Climate episode category (warm/cold/normal)")
CAUSE_CATEGORY: Optional[str] = Field(None, description="Outage cause category")
NERC_REGION: Optional[str] = Field(None, description="NERC reliability region")
MONTH: Optional[int] = Field(None, ge=1, le=12, description="Month of event (1-12)")
RES_SALES: Optional[float] = Field(None, ge=0, description="Residential electricity sales (MWh)")
class Config:
json_schema_extra = {
"example": {
"ANOMALY_LEVEL": -0.3,
"DEMAND_LOSS_MW": 250.0,
"RES_PRICE": 11.6,
"COM_PRICE": 9.5,
"IND_PRICE": 6.7,
"TOTAL_PRICE": 9.3,
"TOTAL_CUSTOMERS": 2500000,
"POPULATION": 5800000,
"POPPCT_URBAN": 73.0,
"CLIMATE_REGION": "East North Central",
"CAUSE_CATEGORY": "severe weather",
"NERC_REGION": "RFC",
"MONTH": 7,
}
}
class PredictionResponse(BaseModel):
probability: float = Field(description="P(high-impact outage)")
prediction: int = Field(description="Binary label (0 or 1)")
risk_tier: str = Field(description="LOW / MODERATE / HIGH / CRITICAL")
model_version: str
# ── Helpers ───────────────────────────────────────────────────────────────
def _input_to_record(inp: OutageInput) -> Dict[str, Any]:
"""Convert Pydantic model to the dict format expected by the predictor."""
raw = inp.model_dump(exclude_none=True)
# Map underscored API field names back to dotted dataset column names
mapped: Dict[str, Any] = {}
for key, val in raw.items():
col_name = key.replace("_", ".")
mapped[col_name] = val
return mapped
# ── Endpoints ─────────────────────────────────────────────────────────────
@app.get("/health")
async def health() -> Dict[str, str]:
status = "ok" if _predictor is not None else "model_not_loaded"
return {"status": status, "version": MODEL_VERSION}
@app.get("/model-info")
async def model_info() -> Dict[str, Any]:
if _predictor is None:
raise HTTPException(503, "Model not loaded β€” run training pipeline first.")
metrics_path = ARTIFACTS_DIR / METRICS_FILE
metrics = {}
if metrics_path.exists():
with open(metrics_path) as f:
metrics = json.load(f)
return {
"version": MODEL_VERSION,
"features": _predictor.feature_names,
"metrics": metrics,
}
@app.post("/predict", response_model=PredictionResponse)
async def predict(inp: OutageInput) -> PredictionResponse:
if _predictor is None:
raise HTTPException(503, "Model not loaded β€” run training pipeline first.")
try:
record = _input_to_record(inp)
result = _predictor.predict_single(record)
except Exception as e:
logger.exception("Prediction failed")
raise HTTPException(422, f"Prediction error: {e}")
return PredictionResponse(
probability=round(result["probability"], 4),
prediction=result["prediction"],
risk_tier=result["risk_tier"],
model_version=MODEL_VERSION,
)