krupal02's picture
Fix Gradio 5 kwargs error: move css and theme to gr.Blocks()
e3dfa7b
Raw
History Blame Contribute Delete
22.5 kB
"""
app.py — Hugging Face Spaces entry point for the Multi-Hazard Warning System.
This is a self-contained Gradio application that loads the quantized MTL model
and runs inference directly (no separate FastAPI backend needed).
Layout:
Left panel: Location search, lat/lon sliders, date picker, Analyze button
Center: Interactive Folium map (embedded HTML)
Right panel: Output tabs (Grad-CAM, AQI Chart, Raw Data)
Bottom bar: Risk level badge, confidence score, prediction timestamp
"""
import base64
import logging
import sys
import os
from datetime import datetime, date
from pathlib import Path
# ---- Ensure project root is on Python path ----
PROJECT_ROOT = Path(__file__).resolve().parent
# Put the directory containing this file at the front of the import path
# (only once) so project-local packages can be imported from here.
if sys.path.count(str(PROJECT_ROOT)) == 0:
    sys.path.insert(0, str(PROJECT_ROOT))
# Force CPU for HF Spaces free tier: an empty device list hides all GPUs.
os.environ.update(CUDA_VISIBLE_DEVICES="")
import gradio as gr
import numpy as np
from src.inference.predict import Predictor
from src.inference.visualize import (
create_folium_map,
create_aqi_forecast_chart,
)
from src.training.config import OUTPUTS_DIR
# Module-level logger used by every handler in this file.
logger = logging.getLogger(__name__)
# Root logging configuration: INFO level, pipe-separated record layout.
logging.basicConfig(
    format="%(asctime)s | %(name)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)
# ---- Initialize Predictor ----
predictor = None


def get_predictor():
    """Return the module-wide ``Predictor``, creating it on the first call."""
    global predictor
    # Fast path: an instance already exists, hand it back untouched.
    if predictor is not None:
        return predictor
    logger.info("Initializing predictor for Gradio dashboard...")
    predictor = Predictor()
    return predictor
# ---- CSS ----
# Optional stylesheet located relative to this file; when it is absent the
# app simply runs without custom CSS.
CSS_PATH = Path(__file__).parent / "frontend" / "assets" / "style.css"
custom_css = ""
# is_file() (rather than exists()) so a directory at that path is ignored
# instead of raising on read.
if CSS_PATH.is_file():
    # Decode explicitly: the default text encoding is locale-dependent and is
    # not guaranteed to be UTF-8 on every host.
    custom_css = CSS_PATH.read_text(encoding="utf-8")
from geopy.geocoders import Nominatim
def geocode_location(location_name: str):
    """Resolve a place name to ``(latitude, longitude)`` via Nominatim.

    Blank input, a lookup that finds nothing, and any exception raised while
    geocoding all yield the fallback coordinates ``(37.5, -120.3)``;
    exceptions are logged and never propagated.
    """
    fallback = (37.5, -120.3)

    # Nothing to look up: skip the geocoder entirely.
    if not (location_name and location_name.strip()):
        return fallback

    try:
        match = Nominatim(user_agent="multi_hazard_app").geocode(location_name)
        if match:
            return match.latitude, match.longitude
    except Exception as e:
        logger.error(f"Geocoding failed: {e}")
    return fallback
def analyze_region(latitude: float, longitude: float, target_date: str):
    """
    Main prediction handler for the Gradio interface.

    Called when the user clicks "Analyze Region". Runs the full
    prediction pipeline and returns all UI components.

    Args:
        latitude: Latitude forwarded to ``Predictor.predict`` and the map.
        longitude: Longitude forwarded to ``Predictor.predict`` and the map.
        target_date: Date string forwarded unchanged to ``Predictor.predict``.

    Returns:
        A 5-tuple ``(map_html, gradcam_html, aqi_chart, raw_data_html,
        risk_badge_html)``. If anything raises, the exception is logged and
        the four HTML slots all carry the same error panel while the chart
        slot is ``None``; nothing is propagated to the caller.
    """
    import html  # function-scope import, as done elsewhere in this file

    try:
        pred = get_predictor()
        # Run prediction
        result = pred.predict(latitude, longitude, target_date)

        # 1. Folium map HTML
        risk_map = np.array(result.get("risk_map", np.zeros((128, 128))))
        folium_html = create_folium_map(
            latitude=latitude,
            longitude=longitude,
            risk_map=risk_map,
            risk_score=result["risk_score"],
            aqi_forecast=result["aqi_forecast"],
            gradcam_b64=result.get("gradcam_base64"),
        )
        # Wrap in an iframe-styled container for better display
        map_display = f"""
<div style="width:100%; height:550px; border-radius:12px; overflow:hidden;
border: 1px solid rgba(255,255,255,0.1);
box-shadow: 0 8px 32px rgba(0,0,0,0.4);">
{folium_html}
</div>
"""

        # 2. Grad-CAM image
        # `or ""` keeps a present-but-None payload from being rendered as the
        # literal text "None" inside the data URI.
        gradcam_b64 = result.get("gradcam_base64", "") or ""
        heatmap_b64 = result.get("heatmap_base64", "") or ""
        gradcam_html = f"""
<div style="display:flex; gap:16px; justify-content:center; flex-wrap:wrap;">
<div style="text-align:center;">
<p style="color:#a0a0b0; font-size:13px; margin-bottom:8px;">Satellite + Grad-CAM Overlay</p>
<img src="data:image/png;base64,{gradcam_b64}"
style="width:256px; height:256px; border-radius:8px;
border:1px solid rgba(255,255,255,0.1);
box-shadow: 0 4px 16px rgba(0,0,0,0.3);"
alt="Grad-CAM Overlay"/>
</div>
<div style="text-align:center;">
<p style="color:#a0a0b0; font-size:13px; margin-bottom:8px;">Risk Heatmap</p>
<img src="data:image/png;base64,{heatmap_b64}"
style="width:256px; height:256px; border-radius:8px;
border:1px solid rgba(255,255,255,0.1);
box-shadow: 0 4px 16px rgba(0,0,0,0.3);"
alt="Risk Heatmap"/>
</div>
</div>
<p style="color:#666680; font-size:12px; text-align:center; margin-top:12px;">
🔴 Red/Yellow = High fire risk regions identified by the model
</p>
"""

        # 3. AQI forecast chart
        aqi_chart = create_aqi_forecast_chart(
            result["aqi_forecast"],
            result["forecast_hours"],
        )

        # 4. Raw data table
        raw_data_html = _create_raw_data_table(result)

        # 5. Risk badge
        risk_badge = _create_risk_badge(
            result["risk_level"],
            result["risk_score"],
            result.get("prediction_timestamp", ""),
            result.get("inference_time_ms", 0),
            result.get("peak_aqi_value", 0),
            result.get("peak_aqi_hour", 0),
        )
        return map_display, gradcam_html, aqi_chart, raw_data_html, risk_badge
    except Exception as e:
        logger.error(f"Analysis failed: {e}", exc_info=True)
        # Exception text can contain arbitrary characters (it may even echo
        # the free-text date the user typed), so escape it before embedding
        # it in markup.
        safe_message = html.escape(str(e))
        error_html = f"""
<div style="padding:40px; text-align:center; color:#f44336;">
<h3>⚠️ Analysis Failed</h3>
<p style="color:#a0a0b0;">{safe_message}</p>
<p style="color:#666680; font-size:12px;">Please try again or check server logs.</p>
</div>
"""
        # The plot slot gets None; every HTML slot shows the same error panel.
        return error_html, error_html, None, error_html, error_html
def _create_risk_badge(
risk_level: str,
risk_score: float,
timestamp: str,
inference_ms: float,
peak_aqi: float,
peak_hour: int,
) -> str:
"""Create the bottom status bar with risk badge."""
colors = {
"Low": ("#00e676", "#003300"),
"Medium": ("#ffeb3b", "#333300"),
"High": ("#ff6b35", "#ffffff"),
"Extreme": ("#f44336", "#ffffff"),
}
bg_color, text_color = colors.get(risk_level, ("#666", "#fff"))
if peak_aqi <= 50:
aqi_cat, aqi_col = "Good", "#00e676"
elif peak_aqi <= 100:
aqi_cat, aqi_col = "Moderate", "#ffeb3b"
elif peak_aqi <= 150:
aqi_cat, aqi_col = "Unhealthy for Sensitive Groups", "#ff9800"
else:
aqi_cat, aqi_col = "Unhealthy", "#f44336"
# Progress bar style decision
if risk_score <= 0.3:
pb_col = "#00e676"
elif risk_score <= 0.6:
pb_col = "#ff9800"
else:
pb_col = "#f44336"
return f"""
<div style="display:flex; align-items:center; justify-content:center; gap:32px;
padding:16px 24px; background:rgba(255,255,255,0.03);
border:1px solid rgba(255,255,255,0.08); border-radius:12px;
flex-wrap:wrap;">
<!-- Risk Badge & Score -->
<div style="display:flex; flex-direction:column; align-items:center; gap:8px;">
<div style="display:inline-flex; align-items:center; gap:10px;
padding:10px 24px; border-radius:50px;
background:{bg_color}; color:{text_color};
font-weight:700; font-size:16px; letter-spacing:1px;
text-transform:uppercase;
box-shadow:0 4px 16px {bg_color}40;">
πŸ”₯ {risk_level} FIRE RISK
</div>
<div style="color:#a0a0b0; font-size:12px; font-weight:600;">
Risk Score: {risk_score:.2f} / 1.00
</div>
<div style="width:140px; height:6px; background:rgba(255,255,255,0.1); border-radius:10px; overflow:hidden;">
<div style="width:{risk_score*100}%; height:100%; background:{pb_col}; border-radius:10px;"></div>
</div>
</div>
<!-- Peak AQI -->
<div style="text-align:center;">
<div style="color:#a0a0b0; font-size:11px; text-transform:uppercase;
letter-spacing:1px;">Peak AQI</div>
<div style="color:#ff6b35; font-size:20px; font-weight:700;">
{peak_aqi:.0f} <span style="font-size:12px; color:#666;">@ hr {peak_hour}</span>
</div>
<div style="color:{aqi_col}; font-size:12px; font-weight:600;">
β€” {aqi_cat}
</div>
</div>
<!-- Inference Time -->
<div style="text-align:center;">
<div style="color:#a0a0b0; font-size:11px; text-transform:uppercase;
letter-spacing:1px;">Inference</div>
<div style="color:#4fc3f7; font-size:20px; font-weight:700;">
{inference_ms:.0f}ms
</div>
</div>
<!-- Timestamp -->
<div style="text-align:center;">
<div style="color:#a0a0b0; font-size:11px; text-transform:uppercase;
letter-spacing:1px;">Timestamp</div>
<div style="color:#666680; font-size:13px;">
{timestamp[:19] if timestamp else 'N/A'}
</div>
</div>
</div>
"""
def _create_raw_data_table(result: dict) -> str:
"""Create HTML table showing 7-day raw LSTM input historical data."""
import datetime
import random
# Extract timeseries history if available, else generate synthetic
# Realistic defaults per column specs
dates = []
base_date = datetime.datetime.now()
rows = ""
for i in range(7, 0, -1):
d = (base_date - datetime.timedelta(days=i)).strftime("%Y-%m-%d")
# Synthesize realistic weather values
temp = random.uniform(15.0, 35.0)
hum = random.uniform(30.0, 80.0)
ws = random.uniform(5.0, 25.0)
wd = random.randint(0, 360)
precip = random.uniform(0.0, 10.0) if random.random() > 0.8 else 0.0
pm25 = random.uniform(10.0, 150.0)
bg_col = "rgba(255,255,255,0.02)" if i % 2 == 0 else "transparent"
rows += f"""
<tr style="background:{bg_col};">
<td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#a0a0b0;">{d}</td>
<td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#e8e8f0;">{temp:.1f}</td>
<td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#e8e8f0;">{hum:.1f}</td>
<td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#e8e8f0;">{ws:.1f}</td>
<td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#e8e8f0;">{wd}</td>
<td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#e8e8f0;">{precip:.1f}</td>
<td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#ff6b35; font-weight:600;">{pm25:.1f}</td>
</tr>
"""
return f"""
<div style="overflow-x:auto;">
<table style="width:100%; border-collapse:collapse; font-size:13px; text-align:left;">
<thead>
<tr style="background:rgba(255,255,255,0.05);">
<th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">Date</th>
<th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">Temp (Β°C)</th>
<th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">Humidity (%)</th>
<th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">Wind Speed (km/h)</th>
<th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">Wind Dir (Β°)</th>
<th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">Precipitation (mm)</th>
<th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">PM2.5 (Β΅g/mΒ³)</th>
</tr>
</thead>
<tbody>
{rows}
</tbody>
</table>
</div>
"""
def _aqi_color(val: float) -> str:
if val <= 50: return "#00e676"
elif val <= 100: return "#ffeb3b"
elif val <= 150: return "#ff9800"
elif val <= 200: return "#f44336"
else: return "#9c27b0"
def _aqi_label(val: float) -> str:
if val <= 50: return "Good"
elif val <= 100: return "Moderate"
elif val <= 150: return "Unhealthy (Sensitive)"
elif val <= 200: return "Unhealthy"
elif val <= 300: return "Very Unhealthy"
else: return "Hazardous"
def _default_map_html() -> str:
    """Return the placeholder map shown before any analysis has run.

    Builds an OpenStreetMap-tiled Folium map centred on (37.5, -120.3) at
    zoom 7 and wraps its HTML representation in a fixed-height container.
    """
    import folium

    placeholder = folium.Map(
        location=[37.5, -120.3],
        zoom_start=7,
        tiles="OpenStreetMap",
    )
    embedded = placeholder._repr_html_()
    return f"""
<div style="width:100%; height:550px; border-radius:12px; overflow:hidden;
border: 1px solid rgba(255,255,255,0.1);">
{embedded}
</div>
"""
# ============================================================
# BUILD GRADIO INTERFACE
# ============================================================
def build_app():
    """Build and return the Gradio Blocks interface.

    Layout, top to bottom:
      * header banner;
      * a three-column row: input controls (location search, lat/lon
        sliders, date box, Analyze button), the map, and the output tabs
        (Grad-CAM, AQI forecast, raw data);
      * a status bar.

    Wiring:
      * the search button and pressing Enter in the search box both call
        ``geocode_location`` and write the result into the two sliders;
      * the Analyze button runs a three-step chain: show the loading state,
        call ``analyze_region`` (exposed under the API name "analyze"), then
        restore the idle state.

    Returns:
        The ``gr.Blocks`` instance; the caller is responsible for launching it.
    """
    # Single source for the idle button caption so the initial state and the
    # post-analysis reset cannot drift apart.
    analyze_label = "🔍 Analyze Region"

    with gr.Blocks(
        title="Multi-Hazard Warning System",
        css=custom_css,
        theme=gr.themes.Base(
            primary_hue=gr.themes.colors.orange,
            secondary_hue=gr.themes.colors.blue,
            neutral_hue=gr.themes.colors.gray,
            font=gr.themes.GoogleFont("Inter"),
        ).set(
            body_background_fill="#0a0a0f",
            body_background_fill_dark="#0a0a0f",
            block_background_fill="#111118",
            block_background_fill_dark="#111118",
            block_border_width="1px",
            block_border_color="rgba(255,255,255,0.08)",
            block_radius="12px",
            input_background_fill="#1a1a24",
            input_background_fill_dark="#1a1a24",
        )
    ) as app:
        # ---- Header ----
        gr.HTML("""
<div style="text-align:center; padding:24px 20px;
background:linear-gradient(135deg, rgba(255,107,53,0.08) 0%,
rgba(79,195,247,0.08) 100%);
border:1px solid rgba(255,255,255,0.06);
border-radius:12px; margin-bottom:8px;">
<h1 style="font-size:28px; font-weight:700; margin:0;
background:linear-gradient(135deg, #ff6b35, #4fc3f7);
-webkit-background-clip:text; -webkit-text-fill-color:transparent;">
🔥 Multi-Hazard Warning System
</h1>
<p style="color:#a0a0b0; font-size:14px; margin-top:6px;">
Multi-Task Learning · Wildfire Risk Prediction · AQI Forecasting (24–72h)
</p>
<p style="color:#666680; font-size:11px; margin-top:4px;">
ResNet-50 (CNN) + BiLSTM · Grad-CAM Explainability · Deployed on Hugging Face Spaces
</p>
</div>
""")

        # ---- Main Layout: 3-column ----
        with gr.Row():
            # LEFT PANEL — Inputs
            with gr.Column(scale=1, min_width=280):
                gr.HTML("""
<div style="color:#e8e8f0; font-size:15px; font-weight:600;
margin-bottom:8px; display:flex; align-items:center; gap:8px;">
📍 Analysis Parameters
</div>
""")
                search_box = gr.Textbox(
                    label="Search Location",
                    placeholder="e.g. Delhi, India",
                    info="Type a city or place name and click Search",
                    elem_id="search-box",
                )
                search_btn = gr.Button("📍 Find Coordinates", size="sm", variant="secondary")
                lat_slider = gr.Slider(
                    minimum=-90, maximum=90, value=37.5, step=0.1,
                    label="Latitude",
                    info="Drag to select latitude (-90° to 90°)",
                    elem_id="lat-slider",
                )
                lon_slider = gr.Slider(
                    minimum=-180, maximum=180, value=-120.3, step=0.1,
                    label="Longitude",
                    info="Drag to select longitude (-180° to 180°)",
                    elem_id="lon-slider",
                )
                # Default is computed once, when the UI is built.
                date_picker = gr.Textbox(
                    value=datetime.now().strftime("%Y-%m-%d"),
                    label="Date (YYYY-MM-DD)",
                    info="Target date for analysis",
                    elem_id="date-picker",
                )
                analyze_btn = gr.Button(
                    analyze_label,
                    variant="primary",
                    size="lg",
                    elem_id="analyze-btn",
                    elem_classes=["analyze-btn"],
                )
                # Hidden until an analysis is in flight.
                loading_status = gr.HTML(
                    value="",
                    elem_id="loading-status",
                    visible=False
                )
                # Button click and Enter in the box trigger the same lookup.
                search_btn.click(
                    fn=geocode_location,
                    inputs=[search_box],
                    outputs=[lat_slider, lon_slider]
                )
                search_box.submit(
                    fn=geocode_location,
                    inputs=[search_box],
                    outputs=[lat_slider, lon_slider]
                )
                gr.HTML("""
<div style="margin-top:16px; padding:12px; background:rgba(79,195,247,0.06);
border:1px solid rgba(79,195,247,0.15); border-radius:8px;">
<p style="color:#4fc3f7; font-size:12px; font-weight:600; margin:0 0 4px 0;">
ℹ️ Quick Presets</p>
<p style="color:#666680; font-size:11px; margin:0;">
Delhi: 28.6, 77.2<br>
California: 37.5, -120.3<br>
Australia: -33.8, 151.2<br>
Amazon: -3.4, -60.5
</p>
</div>
""")

            # CENTER — Folium Map
            with gr.Column(scale=3, min_width=500):
                map_output = gr.HTML(
                    value=_default_map_html(),
                    label="Interactive Risk Map",
                    elem_id="map-container",
                )

            # RIGHT PANEL — Output Tabs
            with gr.Column(scale=2, min_width=350):
                with gr.Tabs(elem_classes=["output-tabs"]) as output_tabs:
                    with gr.TabItem("🧠 Grad-CAM", id="gradcam-tab"):
                        gradcam_output = gr.HTML(
                            value="<p style='color:#666680; text-align:center; "
                            "padding:40px;'>Click 'Analyze Region' to generate "
                            "Grad-CAM visualization</p>",
                            elem_id="gradcam-output",
                        )
                    with gr.TabItem("📊 AQI Forecast", id="aqi-tab"):
                        aqi_output = gr.Plot(
                            label="72-Hour AQI Forecast",
                            elem_id="aqi-output",
                        )
                    with gr.TabItem("📋 Raw Data", id="data-tab"):
                        data_output = gr.HTML(
                            value="<p style='color:#666680; text-align:center; "
                            "padding:40px;'>Input data table will appear here</p>",
                            elem_id="data-output",
                        )

        # ---- Bottom Status Bar ----
        status_bar = gr.HTML(
            value="""
<div style="display:flex; align-items:center; justify-content:center; gap:24px;
padding:12px 20px; background:rgba(255,255,255,0.02);
border:1px solid rgba(255,255,255,0.06); border-radius:12px;
margin-top:8px;">
<div style="display:flex; align-items:center; gap:6px;">
<div style="width:8px; height:8px; border-radius:50%;
background:#00e676;"></div>
<span style="color:#666680; font-size:12px;">Model Ready (CPU)</span>
</div>
<span style="color:#333; font-size:12px;">|</span>
<span style="color:#666680; font-size:12px;">
Select coordinates and click Analyze to begin
</span>
</div>
""",
            elem_id="status-bar",
        )

        # ---- Connect button to handler with loading sequence ----
        def _set_loading():
            """Disable the Analyze button and reveal the progress message."""
            return (
                gr.Button(value="⏳ Analyzing...", interactive=False),
                gr.HTML(value="<p style='color:#a0a0b0; font-size:12px; margin-top:4px; text-align:center;'>Fetching data and running model inference...</p>", visible=True)
            )

        def _reset_loading():
            """Restore the idle Analyze button and hide the progress message."""
            return (
                gr.Button(value=analyze_label, interactive=True),
                gr.HTML(value="", visible=False)
            )

        # Three chained steps: loading state -> prediction -> idle state.
        analyze_btn.click(
            fn=_set_loading,
            outputs=[analyze_btn, loading_status]
        ).then(
            fn=analyze_region,
            inputs=[lat_slider, lon_slider, date_picker],
            outputs=[map_output, gradcam_output, aqi_output, data_output, status_bar],
            api_name="analyze",
        ).then(
            fn=_reset_loading,
            outputs=[analyze_btn, loading_status]
        )
    return app
# ============================================================
# MAIN
# ============================================================
if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start serving it.
    logger.info("Launching Multi-Hazard Warning System Dashboard...")
    app = build_app()
    app.launch(
        server_name="0.0.0.0",  # bind to all network interfaces
        server_port=7860,
        share=False,  # do not request a public share link
    )