Spaces:
Sleeping
Sleeping
File size: 7,143 Bytes
992aa4f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 | """
FastAPI prediction service for the Grid Risk Platform.
Usage
-----
uvicorn src.api:app --host 0.0.0.0 --port 8000
Endpoints
POST /predict — score a single outage record
GET /health — liveness check
GET /model-info — model version and feature list
"""
from __future__ import annotations
import json
import logging
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from src.config import ARTIFACTS_DIR, FEATURE_NAMES_FILE, METRICS_FILE, MODEL_VERSION
from src.predict import GridRiskPredictor
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Shared predictor instance: assigned by ``lifespan`` at startup and reset to
# None at shutdown. It stays None when loading fails, which the endpoints
# check before using it.
_predictor: Optional[GridRiskPredictor] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the predictor at startup and release it at shutdown.

    A missing-artifact error is logged rather than raised, so the application
    still starts; ``_predictor`` then stays ``None`` and the endpoints can
    report that the model is not loaded.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None, once the startup work has finished.
    """
    global _predictor
    try:
        _predictor = GridRiskPredictor()
    except FileNotFoundError as e:
        logger.error("Artifacts missing — run training pipeline first: %s", e)
    else:
        logger.info("Model loaded — version %s", MODEL_VERSION)
    try:
        yield
    finally:
        # Drop the reference even when the application exits with an error.
        _predictor = None
# ASGI application object; startup and shutdown work is delegated to
# ``lifespan``.
app = FastAPI(
    lifespan=lifespan,
    title="Grid Risk & Reliability API",
    description="Predict probability of high-impact power outage events.",
    version=MODEL_VERSION,
)
# ── Request / Response schemas ────────────────────────────────────────────
class OutageInput(BaseModel):
    """Input features for a single outage risk prediction.

    All fields are optional; the preprocessor handles missing values.
    Provide as many as available for best accuracy.
    """

    # Pydantic v2 reads model configuration from ``model_config``; the nested
    # ``class Config`` form is deprecated there. The example below is attached
    # to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "ANOMALY_LEVEL": -0.3,
                "DEMAND_LOSS_MW": 250.0,
                "RES_PRICE": 11.6,
                "COM_PRICE": 9.5,
                "IND_PRICE": 6.7,
                "TOTAL_PRICE": 9.3,
                "TOTAL_CUSTOMERS": 2500000,
                "POPULATION": 5800000,
                "POPPCT_URBAN": 73.0,
                "CLIMATE_REGION": "East North Central",
                "CAUSE_CATEGORY": "severe weather",
                "NERC_REGION": "RFC",
                "MONTH": 7,
            }
        }
    }

    # Climate and load
    ANOMALY_LEVEL: Optional[float] = Field(None, description="Oceanic Niño Index anomaly level")
    DEMAND_LOSS_MW: Optional[float] = Field(None, ge=0, description="Demand loss in MW")

    # Electricity prices, sales and customers
    RES_PRICE: Optional[float] = Field(None, ge=0, description="Residential electricity price (cents/kWh)")
    COM_PRICE: Optional[float] = Field(None, ge=0, description="Commercial electricity price")
    IND_PRICE: Optional[float] = Field(None, ge=0, description="Industrial electricity price")
    TOTAL_PRICE: Optional[float] = Field(None, ge=0, description="Average total electricity price")
    TOTAL_SALES: Optional[float] = Field(None, ge=0, description="Total electricity sales (MWh)")
    TOTAL_CUSTOMERS: Optional[float] = Field(None, ge=0, description="Total utility customers")

    # Economic indicators
    PC_REALGSP_STATE: Optional[float] = Field(None, description="Per capita real GSP of the state")
    PC_REALGSP_REL: Optional[float] = Field(None, description="Relative per capita real GSP")
    PC_REALGSP_CHANGE: Optional[float] = Field(None, description="% change in per capita real GSP")
    UTIL_REALGSP: Optional[float] = Field(None, description="Real GSP contributed by utility sector")
    UTIL_CONTRI: Optional[float] = Field(None, description="Utility sector contribution (%)")

    # Population and land use; percentage fields are bounded to 0-100
    POPULATION: Optional[float] = Field(None, ge=0, description="State population")
    POPPCT_URBAN: Optional[float] = Field(None, ge=0, le=100, description="Urban population %")
    POPDEN_URBAN: Optional[float] = Field(None, ge=0, description="Urban population density")
    POPDEN_RURAL: Optional[float] = Field(None, ge=0, description="Rural population density")
    AREAPCT_URBAN: Optional[float] = Field(None, ge=0, le=100, description="Urban area %")
    PCT_LAND: Optional[float] = Field(None, ge=0, le=100, description="Land area %")
    PCT_WATER_TOT: Optional[float] = Field(None, ge=0, le=100, description="Water area %")

    # Categorical descriptors and timing
    CLIMATE_REGION: Optional[str] = Field(None, description="U.S. climate region")
    CLIMATE_CATEGORY: Optional[str] = Field(None, description="Climate episode category (warm/cold/normal)")
    CAUSE_CATEGORY: Optional[str] = Field(None, description="Outage cause category")
    NERC_REGION: Optional[str] = Field(None, description="NERC reliability region")
    MONTH: Optional[int] = Field(None, ge=1, le=12, description="Month of event (1-12)")
    RES_SALES: Optional[float] = Field(None, ge=0, description="Residential electricity sales (MWh)")
class PredictionResponse(BaseModel):
    """Response body returned by the ``/predict`` endpoint."""

    # ``model_version`` starts with Pydantic v2's protected ``model_`` prefix;
    # clearing the protected namespaces lets the field be declared without a
    # namespace-conflict warning.
    model_config = {"protected_namespaces": ()}

    probability: float = Field(description="P(high-impact outage)")
    prediction: int = Field(description="Binary label (0 or 1)")
    risk_tier: str = Field(description="LOW / MODERATE / HIGH / CRITICAL")
    model_version: str
# ── Helpers ───────────────────────────────────────────────────────────────
def _input_to_record(inp: OutageInput) -> Dict[str, Any]:
    """Build the record dict handed to the predictor from a request model.

    Fields left unset (``None``) are omitted, and every underscore in a field
    name is replaced by a dot to form the record key.
    """
    # NOTE(review): the replacement hits *every* underscore, so a field such
    # as POPPCT_URBAN becomes "POPPCT.URBAN" — confirm the predictor's column
    # names really use dots throughout.
    provided = inp.model_dump(exclude_none=True)
    return {field.replace("_", "."): value for field, value in provided.items()}
# ── Endpoints ─────────────────────────────────────────────────────────────
@app.get("/health")
async def health() -> Dict[str, str]:
    """Liveness check that also reports whether a predictor is loaded."""
    if _predictor is None:
        return {"status": "model_not_loaded", "version": MODEL_VERSION}
    return {"status": "ok", "version": MODEL_VERSION}
@app.get("/model-info")
async def model_info() -> Dict[str, Any]:
    """Return the model version, feature list and stored metrics.

    Returns:
        A dict with ``version``, ``features`` (the predictor's
        ``feature_names``) and ``metrics`` (parsed from the metrics artifact,
        or ``{}`` when that file is absent or cannot be parsed).

    Raises:
        HTTPException: 503 when no predictor is loaded.
    """
    if _predictor is None:
        raise HTTPException(503, "Model not loaded — run training pipeline first.")

    metrics: Dict[str, Any] = {}
    metrics_path = ARTIFACTS_DIR / METRICS_FILE
    if metrics_path.exists():
        try:
            # Explicit encoding: JSON is UTF-8, independent of the host locale.
            with open(metrics_path, encoding="utf-8") as f:
                metrics = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Metrics are supplementary; a corrupt file should not take the
            # whole endpoint down, so log it and report no metrics.
            logger.warning("Could not parse metrics file %s", metrics_path, exc_info=True)
            metrics = {}

    return {
        "version": MODEL_VERSION,
        "features": _predictor.feature_names,
        "metrics": metrics,
    }
@app.post("/predict", response_model=PredictionResponse)
async def predict(inp: OutageInput) -> PredictionResponse:
    """Score a single outage record.

    Args:
        inp: Validated request body; unset fields are left out of the record
            passed to the predictor.

    Returns:
        The probability (rounded to 4 decimals), the binary prediction, the
        risk tier and the model version.

    Raises:
        HTTPException: 503 when no predictor is loaded; 422 when building the
            record or scoring it raises.
    """
    if _predictor is None:
        raise HTTPException(503, "Model not loaded — run training pipeline first.")

    try:
        record = _input_to_record(inp)
        result = _predictor.predict_single(record)
    except Exception as e:
        logger.exception("Prediction failed")
        # Chain the original error so the traceback keeps its cause.
        raise HTTPException(422, f"Prediction error: {e}") from e

    return PredictionResponse(
        probability=round(result["probability"], 4),
        prediction=result["prediction"],
        risk_tier=result["risk_tier"],
        model_version=MODEL_VERSION,
    )
|