# Hugging Face Spaces status captured with this listing: Sleeping
"""
Gradio interface for the Grid Risk & Reliability Platform.

Usage
-----
    python -m src.app                       # launches on localhost:7860
    GEMINI_API_KEY=... python -m src.app    # enables plain-language explanations

Features
--------
• Input form with key outage-related features
• Risk score + tier badge
• SHAP top contributing factors table
• Optional Gemini-powered plain-language summary
"""
from __future__ import annotations

import json
import logging
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
# Configure the root logger at import time so INFO-and-above records are emitted,
# then create this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Lazy-loaded singletons (avoids loading model on import)
# ---------------------------------------------------------------------------
@lru_cache(maxsize=1)
def _get_predictor():
    """Return the shared ``GridRiskPredictor``, constructing it on first use.

    The result is memoised so that every UI callback reuses one predictor
    instead of building a fresh one per click. ``lru_cache`` does not cache
    exceptions, so a failed construction is retried on the next call.

    Returns:
        The single ``GridRiskPredictor`` instance for this process.
    """
    # Imported inside the function so that importing this module does not pull
    # in the prediction stack.
    from src.predict import GridRiskPredictor

    return GridRiskPredictor()
def _get_feature_names() -> List[str]:
    """Load the feature-name list stored as JSON in the artifacts directory.

    Returns:
        The decoded JSON content of ``ARTIFACTS_DIR / FEATURE_NAMES_FILE``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Imported lazily to keep module import free of project configuration work.
    from src.config import ARTIFACTS_DIR, FEATURE_NAMES_FILE

    # JSON is UTF-8 by specification; do not rely on the platform's locale default.
    with open(ARTIFACTS_DIR / FEATURE_NAMES_FILE, encoding="utf-8") as fh:
        return json.load(fh)
def _get_gemini_model():
    """Return a Gemini ``GenerativeModel``, or ``None`` when it is unavailable.

    ``None`` is returned when the ``GEMINI_API_KEY`` environment variable is
    unset or empty, and also when importing or configuring the client fails
    (the failure is logged as a warning rather than raised).
    """
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        return None

    try:
        # Optional dependency: imported only when a key is actually configured.
        import google.generativeai as genai

        genai.configure(api_key=api_key)
        return genai.GenerativeModel("gemini-2.5-flash")
    except Exception as exc:
        # Best-effort feature: the explanation panel degrades instead of crashing.
        logger.warning("Gemini init failed: %s", exc)
        return None
# ---------------------------------------------------------------------------
# Core logic
# ---------------------------------------------------------------------------
# Badge emoji placed in front of the risk tier in the summary heading.
TIER_COLORS = {
    "CRITICAL": "🔴",
    "HIGH": "🟠",
    "MODERATE": "🟡",
    "LOW": "🟢",
}
def _build_record(
    anomaly_level: float,
    demand_loss_mw: float,
    res_price: float,
    com_price: float,
    ind_price: float,
    total_price: float,
    total_sales: float,
    total_customers: float,
    population: float,
    poppct_urban: float,
    popden_urban: float,
    popden_rural: float,
    climate_region: str,
    climate_category: str,
    cause_category: str,
    nerc_region: str,
    month: int,
) -> Dict[str, Any]:
    """Pack UI inputs into the dict format expected by the predictor.

    Each argument is stored unchanged under its dataset-style column name
    (for example ``anomaly_level`` becomes ``"ANOMALY.LEVEL"``).

    Returns:
        A dict with one entry per argument, in argument order.
    """
    return {
        "ANOMALY.LEVEL": anomaly_level,
        "DEMAND.LOSS.MW": demand_loss_mw,
        "RES.PRICE": res_price,
        "COM.PRICE": com_price,
        "IND.PRICE": ind_price,
        "TOTAL.PRICE": total_price,
        "TOTAL.SALES": total_sales,
        "TOTAL.CUSTOMERS": total_customers,
        "POPULATION": population,
        "POPPCT_URBAN": poppct_urban,
        "POPDEN_URBAN": popden_urban,
        "POPDEN_RURAL": popden_rural,
        "CLIMATE.REGION": climate_region,
        "CLIMATE.CATEGORY": climate_category,
        "CAUSE.CATEGORY": cause_category,
        "NERC.REGION": nerc_region,
        "MONTH": month,
    }
def predict_risk(
    anomaly_level, demand_loss_mw, res_price, com_price, ind_price,
    total_price, total_sales, total_customers, population, poppct_urban,
    popden_urban, popden_rural, climate_region, climate_category,
    cause_category, nerc_region, month,
) -> Tuple[str, str]:
    """Run prediction and return (risk_summary, shap_table_markdown).

    The arguments are the raw form values, in the same order as
    ``_build_record``; ``month`` is coerced with ``int()`` before use.

    Returns:
        A 2-tuple of Markdown strings: the risk-tier summary and a table of
        the top 8 contributing factors.
    """
    # Deferred so that importing this module does not load the project stack.
    from src.explain import explain_prediction
    from src.features import engineer_features

    predictor = _get_predictor()
    record = _build_record(
        anomaly_level, demand_loss_mw, res_price, com_price, ind_price,
        total_price, total_sales, total_customers, population, poppct_urban,
        popden_urban, popden_rural, climate_region, climate_category,
        cause_category, nerc_region, int(month),
    )
    result = predictor.predict_single(record)
    prob = result["probability"]
    tier = result["risk_tier"]
    # Unknown tiers fall back to a neutral white badge.
    icon = TIER_COLORS.get(tier, "⚪")

    # SHAP explanation: rebuild the transformed feature row for this record.
    df = engineer_features(pd.DataFrame([record]))
    # Ensure all columns expected by the preprocessor are present (fill with NaN if missing)
    expected_cols = getattr(predictor.preprocessor, "feature_names_in_", [])
    for col in expected_cols:
        if col not in df.columns:
            df[col] = np.nan
    X = predictor.preprocessor.transform(df)
    shap_factors = explain_prediction(
        X, model=predictor.model, feature_names=predictor.feature_names, top_k=8
    )

    # Format outputs
    summary = (
        f"## {icon} Risk Tier: **{tier}**\n\n"
        f"**Probability of high-impact outage:** {prob:.1%}\n\n"
        f"Threshold: ≥50% → positive prediction"
    )
    table_rows = ["| Feature | SHAP Value | Direction |", "|---|---|---|"]
    table_rows.extend(
        f"| {factor['feature']} | {factor['shap_value']:+.4f} | {factor['direction']} |"
        for factor in shap_factors
    )
    shap_table = "\n".join(table_rows)
    return summary, shap_table
def generate_gemini_explanation(
    anomaly_level, demand_loss_mw, res_price, com_price, ind_price,
    total_price, total_sales, total_customers, population, poppct_urban,
    popden_urban, popden_rural, climate_region, climate_category,
    cause_category, nerc_region, month,
) -> str:
    """Call Gemini to produce a plain-language explanation of the risk.

    The arguments are the raw form values, in the same order as
    ``_build_record``. The prediction and its top 5 contributing factors are
    computed first and embedded in the prompt.

    Returns:
        Markdown text: the explanation, a notice that the API key is not
        configured, or an error message if the Gemini call fails.
    """
    model = _get_gemini_model()
    if model is None:
        return (
            "⚠️ Gemini API key not configured.\n\n"
            "Set the `GEMINI_API_KEY` environment variable and restart the app."
        )

    # Deferred so that importing this module does not load the project stack.
    from src.explain import explain_prediction
    from src.features import engineer_features

    # Run the prediction first to get the numbers
    predictor = _get_predictor()
    record = _build_record(
        anomaly_level, demand_loss_mw, res_price, com_price, ind_price,
        total_price, total_sales, total_customers, population, poppct_urban,
        popden_urban, popden_rural, climate_region, climate_category,
        cause_category, nerc_region, int(month),
    )
    result = predictor.predict_single(record)

    # Get SHAP factors
    df = engineer_features(pd.DataFrame([record]))
    # Ensure all columns expected by the preprocessor are present (fill with NaN if missing)
    expected_cols = getattr(predictor.preprocessor, "feature_names_in_", [])
    for col in expected_cols:
        if col not in df.columns:
            df[col] = np.nan
    X = predictor.preprocessor.transform(df)
    shap_factors = explain_prediction(
        X, model=predictor.model, feature_names=predictor.feature_names, top_k=5
    )

    # A cleared number field arrives as None, and None cannot take the ",.0f"
    # format spec; substitute a placeholder instead of raising TypeError.
    population_text = "unknown" if population is None else f"{population:,.0f}"

    prompt = f"""You are a grid reliability analyst explaining an AI risk prediction to a
non-technical utility operations manager. Be concise (3-4 sentences max).
Prediction: {result['probability']:.1%} probability of high-impact outage → {result['risk_tier']} tier.
Top contributing factors:
{json.dumps(shap_factors, indent=2)}
Input context:
- Climate region: {climate_region}
- Cause category: {cause_category}
- Month: {month}
- Population: {population_text}
- Anomaly level (ONI): {anomaly_level}
Explain what is driving this risk score and what the operations team should watch for.
Do NOT mention SHAP, ML models, or technical jargon."""

    try:
        response = model.generate_content(prompt)
        return f"### 💡 AI Explanation\n\n{response.text}"
    except Exception as exc:
        # Surface the failure in the UI panel rather than crashing the callback.
        return f"⚠️ Gemini API error: {exc}"
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
# Choices offered by the categorical dropdowns in the input form.
CLIMATE_REGIONS = [
    "East North Central", "Central", "Northeast", "Northwest",
    "South", "Southeast", "Southwest", "West", "West North Central",
]
CLIMATE_CATEGORIES = ["normal", "warm", "cold"]
CAUSE_CATEGORIES = [
    "severe weather", "intentional attack", "system operability disruption",
    "public appeal", "equipment failure", "fuel supply emergency", "islanding",
]
NERC_REGIONS = ["RFC", "SERC", "WECC", "TRE", "NPCC", "MRO", "SPP", "FRCC", "ECAR", "HECO"]
def build_app() -> gr.Blocks:
    """Build the Gradio Blocks UI and wire both buttons to their callbacks.

    The form components are collected in ``all_inputs`` in the positional
    order that ``predict_risk`` and ``generate_gemini_explanation`` expect.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` application.
    """
    # NOTE(review): the nesting of rows/columns below was reconstructed from a
    # listing that had lost its indentation — confirm the grouping against the
    # intended layout.
    with gr.Blocks(
        title="Grid Risk & Reliability Platform",
        theme=gr.themes.Soft(),
    ) as app:
        gr.Markdown(
            "# ⚡ Grid Risk & Reliability Platform\n"
            "Predict the probability of a high-impact power outage event. "
            "All inputs are optional — the model handles missing values."
        )
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Scenario Inputs")
                anomaly_level = gr.Number(label="Anomaly Level (ONI)", value=-0.3)
                demand_loss_mw = gr.Number(label="Demand Loss (MW)", value=250.0)
                month = gr.Slider(1, 12, step=1, value=7, label="Month")
                with gr.Row():
                    res_price = gr.Number(label="Res. Price", value=11.6)
                    com_price = gr.Number(label="Com. Price", value=9.5)
                    ind_price = gr.Number(label="Ind. Price", value=6.7)
                    total_price = gr.Number(label="Total Price", value=9.3)
                total_sales = gr.Number(label="Total Sales (MWh)", value=6.5e7)
                total_customers = gr.Number(label="Total Customers", value=2.5e6)
                population = gr.Number(label="State Population", value=5.8e6)
                poppct_urban = gr.Number(label="Urban Pop %", value=73.0)
                popden_urban = gr.Number(label="Urban Density", value=2200.0)
                popden_rural = gr.Number(label="Rural Density", value=18.0)
                climate_region = gr.Dropdown(CLIMATE_REGIONS, label="Climate Region", value="East North Central")
                climate_category = gr.Dropdown(CLIMATE_CATEGORIES, label="Climate Category", value="normal")
                cause_category = gr.Dropdown(CAUSE_CATEGORIES, label="Cause Category", value="severe weather")
                nerc_region = gr.Dropdown(NERC_REGIONS, label="NERC Region", value="RFC")
            with gr.Column(scale=1):
                gr.Markdown("### Prediction Results")
                risk_output = gr.Markdown(label="Risk Score")
                shap_output = gr.Markdown(label="Top Risk Factors (SHAP)")

        predict_btn = gr.Button("🔮 Predict Risk", variant="primary", size="lg")
        gr.Markdown("---")
        gr.Markdown("### Plain-Language Explanation (Gemini)")
        gemini_output = gr.Markdown(label="AI Explanation")
        gemini_btn = gr.Button("💬 Generate Explanation", variant="secondary")

        # Order matters: it must match the positional parameters of both callbacks.
        all_inputs = [
            anomaly_level, demand_loss_mw, res_price, com_price, ind_price,
            total_price, total_sales, total_customers, population, poppct_urban,
            popden_urban, popden_rural, climate_region, climate_category,
            cause_category, nerc_region, month,
        ]
        predict_btn.click(fn=predict_risk, inputs=all_inputs, outputs=[risk_output, shap_output])
        gemini_btn.click(fn=generate_gemini_explanation, inputs=all_inputs, outputs=gemini_output)
    return app
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on port 7860.
    # "0.0.0.0" binds every network interface rather than loopback only.
    app = build_app()
    app.launch(server_name="0.0.0.0", server_port=7860)