File size: 7,143 Bytes
992aa4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
FastAPI prediction service for the Grid Risk Platform.

Usage
-----
    uvicorn src.api:app --host 0.0.0.0 --port 8000

Endpoints
    POST /predict    β€” score a single outage record
    GET  /health     β€” liveness check
    GET  /model-info β€” model version and feature list
"""
from __future__ import annotations

import json
import logging
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from src.config import ARTIFACTS_DIR, FEATURE_NAMES_FILE, METRICS_FILE, MODEL_VERSION
from src.predict import GridRiskPredictor

logger = logging.getLogger(__name__)

_predictor: Optional[GridRiskPredictor] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global _predictor
    try:
        _predictor = GridRiskPredictor()
        logger.info("Model loaded β€” version %s", MODEL_VERSION)
    except FileNotFoundError as e:
        logger.error("Artifacts missing β€” run training pipeline first: %s", e)
    yield
    _predictor = None


app = FastAPI(
    title="Grid Risk & Reliability API",
    version=MODEL_VERSION,
    description="Predict probability of high-impact power outage events.",
    lifespan=lifespan,
)


# ── Request / Response schemas ────────────────────────────────────────────

class OutageInput(BaseModel):
    """Input features for a single outage risk prediction.

    All fields are optional; the preprocessor handles missing values.
    Provide as many as available for best accuracy.
    """
    ANOMALY_LEVEL: Optional[float] = Field(None, description="Oceanic NiΓ±o Index anomaly level")
    DEMAND_LOSS_MW: Optional[float] = Field(None, ge=0, description="Demand loss in MW")
    RES_PRICE: Optional[float] = Field(None, ge=0, description="Residential electricity price (cents/kWh)")
    COM_PRICE: Optional[float] = Field(None, ge=0, description="Commercial electricity price")
    IND_PRICE: Optional[float] = Field(None, ge=0, description="Industrial electricity price")
    TOTAL_PRICE: Optional[float] = Field(None, ge=0, description="Average total electricity price")
    TOTAL_SALES: Optional[float] = Field(None, ge=0, description="Total electricity sales (MWh)")
    TOTAL_CUSTOMERS: Optional[float] = Field(None, ge=0, description="Total utility customers")
    PC_REALGSP_STATE: Optional[float] = Field(None, description="Per capita real GSP of the state")
    PC_REALGSP_REL: Optional[float] = Field(None, description="Relative per capita real GSP")
    PC_REALGSP_CHANGE: Optional[float] = Field(None, description="% change in per capita real GSP")
    UTIL_REALGSP: Optional[float] = Field(None, description="Real GSP contributed by utility sector")
    UTIL_CONTRI: Optional[float] = Field(None, description="Utility sector contribution (%)")
    POPULATION: Optional[float] = Field(None, ge=0, description="State population")
    POPPCT_URBAN: Optional[float] = Field(None, ge=0, le=100, description="Urban population %")
    POPDEN_URBAN: Optional[float] = Field(None, ge=0, description="Urban population density")
    POPDEN_RURAL: Optional[float] = Field(None, ge=0, description="Rural population density")
    AREAPCT_URBAN: Optional[float] = Field(None, ge=0, le=100, description="Urban area %")
    PCT_LAND: Optional[float] = Field(None, ge=0, le=100, description="Land area %")
    PCT_WATER_TOT: Optional[float] = Field(None, ge=0, le=100, description="Water area %")
    CLIMATE_REGION: Optional[str] = Field(None, description="U.S. climate region")
    CLIMATE_CATEGORY: Optional[str] = Field(None, description="Climate episode category (warm/cold/normal)")
    CAUSE_CATEGORY: Optional[str] = Field(None, description="Outage cause category")
    NERC_REGION: Optional[str] = Field(None, description="NERC reliability region")
    MONTH: Optional[int] = Field(None, ge=1, le=12, description="Month of event (1-12)")
    RES_SALES: Optional[float] = Field(None, ge=0, description="Residential electricity sales (MWh)")

    class Config:
        json_schema_extra = {
            "example": {
                "ANOMALY_LEVEL": -0.3,
                "DEMAND_LOSS_MW": 250.0,
                "RES_PRICE": 11.6,
                "COM_PRICE": 9.5,
                "IND_PRICE": 6.7,
                "TOTAL_PRICE": 9.3,
                "TOTAL_CUSTOMERS": 2500000,
                "POPULATION": 5800000,
                "POPPCT_URBAN": 73.0,
                "CLIMATE_REGION": "East North Central",
                "CAUSE_CATEGORY": "severe weather",
                "NERC_REGION": "RFC",
                "MONTH": 7,
            }
        }


class PredictionResponse(BaseModel):
    probability: float = Field(description="P(high-impact outage)")
    prediction: int = Field(description="Binary label (0 or 1)")
    risk_tier: str = Field(description="LOW / MODERATE / HIGH / CRITICAL")
    model_version: str


# ── Helpers ───────────────────────────────────────────────────────────────

def _input_to_record(inp: OutageInput) -> Dict[str, Any]:
    """Convert Pydantic model to the dict format expected by the predictor."""
    raw = inp.model_dump(exclude_none=True)
    # Map underscored API field names back to dotted dataset column names
    mapped: Dict[str, Any] = {}
    for key, val in raw.items():
        col_name = key.replace("_", ".")
        mapped[col_name] = val
    return mapped


# ── Endpoints ─────────────────────────────────────────────────────────────

@app.get("/health")
async def health() -> Dict[str, str]:
    status = "ok" if _predictor is not None else "model_not_loaded"
    return {"status": status, "version": MODEL_VERSION}


@app.get("/model-info")
async def model_info() -> Dict[str, Any]:
    if _predictor is None:
        raise HTTPException(503, "Model not loaded β€” run training pipeline first.")
    metrics_path = ARTIFACTS_DIR / METRICS_FILE
    metrics = {}
    if metrics_path.exists():
        with open(metrics_path) as f:
            metrics = json.load(f)
    return {
        "version": MODEL_VERSION,
        "features": _predictor.feature_names,
        "metrics": metrics,
    }


@app.post("/predict", response_model=PredictionResponse)
async def predict(inp: OutageInput) -> PredictionResponse:
    if _predictor is None:
        raise HTTPException(503, "Model not loaded β€” run training pipeline first.")

    try:
        record = _input_to_record(inp)
        result = _predictor.predict_single(record)
    except Exception as e:
        logger.exception("Prediction failed")
        raise HTTPException(422, f"Prediction error: {e}")

    return PredictionResponse(
        probability=round(result["probability"], 4),
        prediction=result["prediction"],
        risk_tier=result["risk_tier"],
        model_version=MODEL_VERSION,
    )