Nashid-Noor
Change Gemini model to gemini-2.5-flash
7d35da3
"""
Gradio interface for the Grid Risk & Reliability Platform.
Usage
-----
python -m src.app # launches on localhost:7860
GEMINI_API_KEY=... python -m src.app # enables plain-language explanations
Features
--------
β€’ Input form with key outage-related features
β€’ Risk score + tier badge
β€’ SHAP top contributing factors table
β€’ Optional Gemini-powered plain-language summary
"""
from __future__ import annotations
import json
import logging
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import numpy as np
import pandas as pd
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Lazy-loaded singletons (avoids loading model on import)
# ---------------------------------------------------------------------------
@lru_cache(maxsize=1)
def _get_predictor():
from src.predict import GridRiskPredictor
return GridRiskPredictor()
@lru_cache(maxsize=1)
def _get_feature_names() -> List[str]:
from src.config import ARTIFACTS_DIR, FEATURE_NAMES_FILE
with open(ARTIFACTS_DIR / FEATURE_NAMES_FILE) as f:
return json.load(f)
def _get_gemini_model():
"""Return a Gemini GenerativeModel or None if API key is absent."""
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
return None
try:
import google.generativeai as genai
genai.configure(api_key=api_key)
return genai.GenerativeModel("gemini-2.5-flash")
except Exception as e:
logger.warning("Gemini init failed: %s", e)
return None
# ---------------------------------------------------------------------------
# Core logic
# ---------------------------------------------------------------------------
TIER_COLORS = {
"CRITICAL": "πŸ”΄",
"HIGH": "🟠",
"MODERATE": "🟑",
"LOW": "🟒",
}
def _build_record(
anomaly_level: float,
demand_loss_mw: float,
res_price: float,
com_price: float,
ind_price: float,
total_price: float,
total_sales: float,
total_customers: float,
population: float,
poppct_urban: float,
popden_urban: float,
popden_rural: float,
climate_region: str,
climate_category: str,
cause_category: str,
nerc_region: str,
month: int,
) -> Dict[str, Any]:
"""Pack UI inputs into the dict format expected by the predictor."""
return {
"ANOMALY.LEVEL": anomaly_level,
"DEMAND.LOSS.MW": demand_loss_mw,
"RES.PRICE": res_price,
"COM.PRICE": com_price,
"IND.PRICE": ind_price,
"TOTAL.PRICE": total_price,
"TOTAL.SALES": total_sales,
"TOTAL.CUSTOMERS": total_customers,
"POPULATION": population,
"POPPCT_URBAN": poppct_urban,
"POPDEN_URBAN": popden_urban,
"POPDEN_RURAL": popden_rural,
"CLIMATE.REGION": climate_region,
"CLIMATE.CATEGORY": climate_category,
"CAUSE.CATEGORY": cause_category,
"NERC.REGION": nerc_region,
"MONTH": month,
}
def predict_risk(
anomaly_level, demand_loss_mw, res_price, com_price, ind_price,
total_price, total_sales, total_customers, population, poppct_urban,
popden_urban, popden_rural, climate_region, climate_category,
cause_category, nerc_region, month,
) -> Tuple[str, str]:
"""Run prediction and return (risk_summary, shap_table_markdown)."""
predictor = _get_predictor()
record = _build_record(
anomaly_level, demand_loss_mw, res_price, com_price, ind_price,
total_price, total_sales, total_customers, population, poppct_urban,
popden_urban, popden_rural, climate_region, climate_category,
cause_category, nerc_region, int(month),
)
result = predictor.predict_single(record)
prob = result["probability"]
tier = result["risk_tier"]
icon = TIER_COLORS.get(tier, "βšͺ")
# SHAP explanation
from src.explain import explain_prediction
df = pd.DataFrame([record])
from src.features import engineer_features
df = engineer_features(df)
# Ensure all columns expected by the preprocessor are present (fill with NaN if missing)
expected_cols = getattr(predictor.preprocessor, "feature_names_in_", [])
for col in expected_cols:
if col not in df.columns:
df[col] = np.nan
X = predictor.preprocessor.transform(df)
shap_factors = explain_prediction(
X, model=predictor.model, feature_names=predictor.feature_names, top_k=8
)
# Format outputs
summary = (
f"## {icon} Risk Tier: **{tier}**\n\n"
f"**Probability of high-impact outage:** {prob:.1%}\n\n"
f"Threshold: β‰₯50% β†’ positive prediction"
)
table_rows = ["| Feature | SHAP Value | Direction |", "|---|---|---|"]
for f in shap_factors:
table_rows.append(f"| {f['feature']} | {f['shap_value']:+.4f} | {f['direction']} |")
shap_table = "\n".join(table_rows)
return summary, shap_table
def generate_gemini_explanation(
anomaly_level, demand_loss_mw, res_price, com_price, ind_price,
total_price, total_sales, total_customers, population, poppct_urban,
popden_urban, popden_rural, climate_region, climate_category,
cause_category, nerc_region, month,
) -> str:
"""Call Gemini to produce a plain-language explanation of the risk."""
model = _get_gemini_model()
if model is None:
return (
"⚠️ Gemini API key not configured.\n\n"
"Set the `GEMINI_API_KEY` environment variable and restart the app."
)
# Run the prediction first to get the numbers
predictor = _get_predictor()
record = _build_record(
anomaly_level, demand_loss_mw, res_price, com_price, ind_price,
total_price, total_sales, total_customers, population, poppct_urban,
popden_urban, popden_rural, climate_region, climate_category,
cause_category, nerc_region, int(month),
)
result = predictor.predict_single(record)
# Get SHAP factors
from src.explain import explain_prediction
from src.features import engineer_features
df = pd.DataFrame([record])
df = engineer_features(df)
# Ensure all columns expected by the preprocessor are present (fill with NaN if missing)
expected_cols = getattr(predictor.preprocessor, "feature_names_in_", [])
for col in expected_cols:
if col not in df.columns:
df[col] = np.nan
X = predictor.preprocessor.transform(df)
shap_factors = explain_prediction(
X, model=predictor.model, feature_names=predictor.feature_names, top_k=5
)
prompt = f"""You are a grid reliability analyst explaining an AI risk prediction to a
non-technical utility operations manager. Be concise (3-4 sentences max).
Prediction: {result['probability']:.1%} probability of high-impact outage β†’ {result['risk_tier']} tier.
Top contributing factors:
{json.dumps(shap_factors, indent=2)}
Input context:
- Climate region: {climate_region}
- Cause category: {cause_category}
- Month: {month}
- Population: {population:,.0f}
- Anomaly level (ONI): {anomaly_level}
Explain what is driving this risk score and what the operations team should watch for.
Do NOT mention SHAP, ML models, or technical jargon."""
try:
response = model.generate_content(prompt)
return f"### πŸ’‘ AI Explanation\n\n{response.text}"
except Exception as e:
return f"⚠️ Gemini API error: {e}"
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
CLIMATE_REGIONS = [
"East North Central", "Central", "Northeast", "Northwest",
"South", "Southeast", "Southwest", "West", "West North Central",
]
CLIMATE_CATEGORIES = ["normal", "warm", "cold"]
CAUSE_CATEGORIES = [
"severe weather", "intentional attack", "system operability disruption",
"public appeal", "equipment failure", "fuel supply emergency", "islanding",
]
NERC_REGIONS = ["RFC", "SERC", "WECC", "TRE", "NPCC", "MRO", "SPP", "FRCC", "ECAR", "HECO"]
def build_app() -> gr.Blocks:
with gr.Blocks(
title="Grid Risk & Reliability Platform",
theme=gr.themes.Soft(),
) as app:
gr.Markdown(
"# ⚑ Grid Risk & Reliability Platform\n"
"Predict the probability of a high-impact power outage event. "
"All inputs are optional β€” the model handles missing values."
)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Scenario Inputs")
anomaly_level = gr.Number(label="Anomaly Level (ONI)", value=-0.3)
demand_loss_mw = gr.Number(label="Demand Loss (MW)", value=250.0)
month = gr.Slider(1, 12, step=1, value=7, label="Month")
with gr.Row():
res_price = gr.Number(label="Res. Price", value=11.6)
com_price = gr.Number(label="Com. Price", value=9.5)
ind_price = gr.Number(label="Ind. Price", value=6.7)
total_price = gr.Number(label="Total Price", value=9.3)
total_sales = gr.Number(label="Total Sales (MWh)", value=6.5e7)
total_customers = gr.Number(label="Total Customers", value=2.5e6)
population = gr.Number(label="State Population", value=5.8e6)
poppct_urban = gr.Number(label="Urban Pop %", value=73.0)
popden_urban = gr.Number(label="Urban Density", value=2200.0)
popden_rural = gr.Number(label="Rural Density", value=18.0)
climate_region = gr.Dropdown(CLIMATE_REGIONS, label="Climate Region", value="East North Central")
climate_category = gr.Dropdown(CLIMATE_CATEGORIES, label="Climate Category", value="normal")
cause_category = gr.Dropdown(CAUSE_CATEGORIES, label="Cause Category", value="severe weather")
nerc_region = gr.Dropdown(NERC_REGIONS, label="NERC Region", value="RFC")
with gr.Column(scale=1):
gr.Markdown("### Prediction Results")
risk_output = gr.Markdown(label="Risk Score")
shap_output = gr.Markdown(label="Top Risk Factors (SHAP)")
predict_btn = gr.Button("πŸ” Predict Risk", variant="primary", size="lg")
gr.Markdown("---")
gr.Markdown("### Plain-Language Explanation (Gemini)")
gemini_output = gr.Markdown(label="AI Explanation")
gemini_btn = gr.Button("πŸ’¬ Generate Explanation", variant="secondary")
all_inputs = [
anomaly_level, demand_loss_mw, res_price, com_price, ind_price,
total_price, total_sales, total_customers, population, poppct_urban,
popden_urban, popden_rural, climate_region, climate_category,
cause_category, nerc_region, month,
]
predict_btn.click(fn=predict_risk, inputs=all_inputs, outputs=[risk_output, shap_output])
gemini_btn.click(fn=generate_gemini_explanation, inputs=all_inputs, outputs=gemini_output)
return app
if __name__ == "__main__":
app = build_app()
app.launch(server_name="0.0.0.0", server_port=7860)