| """ |
| app.py — Hugging Face Spaces entry point for the Multi-Hazard Warning System. |
| |
| This is a self-contained Gradio application that loads the quantized MTL model |
| and runs inference directly (no separate FastAPI backend needed). |
| |
| Layout: |
| Left panel: Location search, lat/lon sliders, date picker, Analyze button |
| Center: Interactive Folium map (embedded HTML) |
| Right panel: Output tabs (Grad-CAM, AQI Chart, Raw Data) |
| Bottom bar: Risk level badge, confidence score, prediction timestamp |
| """ |
|
|
| import base64 |
| import logging |
| import sys |
| import os |
| from datetime import datetime, date |
| from pathlib import Path |
|
|
| |
# Put the directory containing this file at the front of sys.path (once) so
# the `src.*` imports further down can be resolved from it.
PROJECT_ROOT = Path(__file__).resolve().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))


# Hide every CUDA device from this process. This runs before gradio, numpy and
# the project modules are imported.
# NOTE(review): presumably this forces the model stack onto the CPU — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
|
| import gradio as gr |
| import numpy as np |
|
|
| from src.inference.predict import Predictor |
| from src.inference.visualize import ( |
| create_folium_map, |
| create_aqi_forecast_chart, |
| ) |
| from src.training.config import OUTPUTS_DIR |
|
|
# Configure root logging for the process: INFO and above, formatted as
# "timestamp | logger name | level | message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(name)s | %(levelname)s | %(message)s",
)
logger = logging.getLogger(__name__)


# Module-level cache for the Predictor instance. It stays None until
# get_predictor() constructs it on first use.
predictor = None
|
|
|
|
def get_predictor():
    """Return the shared Predictor, constructing it on the first call."""
    global predictor
    if predictor is not None:
        return predictor

    # First use: build the Predictor once and cache it in the module-level
    # ``predictor`` so later calls reuse the same instance.
    logger.info("Initializing predictor for Gradio dashboard...")
    predictor = Predictor()
    return predictor
|
|
|
|
| |
# Optional stylesheet handed to gr.Blocks(css=...). When the file is not
# shipped next to this script the app falls back to an empty string.
CSS_PATH = Path(__file__).parent / "frontend" / "assets" / "style.css"
# is_file() (rather than exists()) avoids trying to read a directory, and the
# explicit encoding keeps decoding independent of the host locale.
custom_css = CSS_PATH.read_text(encoding="utf-8") if CSS_PATH.is_file() else ""
|
|
|
|
| from geopy.geocoders import Nominatim |
|
|
def geocode_location(location_name: str):
    """Resolve a free-text place name to a ``(latitude, longitude)`` pair.

    Blank or missing input, a lookup that finds nothing, and any exception
    raised during the lookup all yield the fallback ``(37.5, -120.3)``, which
    matches the sliders' initial values. Failures are logged, never raised.
    """
    fallback = (37.5, -120.3)

    query_is_blank = not (location_name and location_name.strip())
    if query_is_blank:
        return fallback

    try:
        match = Nominatim(user_agent="multi_hazard_app").geocode(location_name)
        if match:
            return match.latitude, match.longitude
    except Exception as e:
        logger.error(f"Geocoding failed: {e}")
    return fallback
|
|
|
|
def analyze_region(latitude: float, longitude: float, target_date: str):
    """
    Main prediction handler for the Gradio interface.

    Called when the user clicks "Analyze Region". Runs the full
    prediction pipeline and returns all UI components.

    Args:
        latitude: Latitude forwarded to the predictor and to the map builder.
        longitude: Longitude forwarded to the predictor and to the map builder.
        target_date: Date string from the UI text box, passed unchanged to
            ``predict``.

    Returns:
        A 5-tuple ``(map_html, gradcam_html, aqi_chart, raw_data_html,
        status_html)``, in the order the click handler lists its outputs.
        If anything raises, the exception is logged (not re-raised), the four
        HTML slots all carry the same error panel and the chart slot is
        ``None``.
    """
    # Function-scope import: only the error path needs it.
    import html

    try:
        pred = get_predictor()

        result = pred.predict(latitude, longitude, target_date)

        # Map panel. A zero 128x128 grid stands in when the prediction carries
        # no "risk_map" entry.
        risk_map = np.array(result.get("risk_map", np.zeros((128, 128))))
        folium_html = create_folium_map(
            latitude=latitude,
            longitude=longitude,
            risk_map=risk_map,
            risk_score=result["risk_score"],
            aqi_forecast=result["aqi_forecast"],
            gradcam_b64=result.get("gradcam_base64"),
        )

        map_display = f"""
        <div style="width:100%; height:550px; border-radius:12px; overflow:hidden;
                    border: 1px solid rgba(255,255,255,0.1);
                    box-shadow: 0 8px 32px rgba(0,0,0,0.4);">
            {folium_html}
        </div>
        """

        # ``or ""`` also covers a key that is present but None, which would
        # otherwise be rendered as the literal text "None" in the data URI.
        gradcam_b64 = result.get("gradcam_base64") or ""
        heatmap_b64 = result.get("heatmap_base64") or ""

        gradcam_html = f"""
        <div style="display:flex; gap:16px; justify-content:center; flex-wrap:wrap;">
            <div style="text-align:center;">
                <p style="color:#a0a0b0; font-size:13px; margin-bottom:8px;">Satellite + Grad-CAM Overlay</p>
                <img src="data:image/png;base64,{gradcam_b64}"
                     style="width:256px; height:256px; border-radius:8px;
                            border:1px solid rgba(255,255,255,0.1);
                            box-shadow: 0 4px 16px rgba(0,0,0,0.3);"
                     alt="Grad-CAM Overlay"/>
            </div>
            <div style="text-align:center;">
                <p style="color:#a0a0b0; font-size:13px; margin-bottom:8px;">Risk Heatmap</p>
                <img src="data:image/png;base64,{heatmap_b64}"
                     style="width:256px; height:256px; border-radius:8px;
                            border:1px solid rgba(255,255,255,0.1);
                            box-shadow: 0 4px 16px rgba(0,0,0,0.3);"
                     alt="Risk Heatmap"/>
            </div>
        </div>
        <p style="color:#666680; font-size:12px; text-align:center; margin-top:12px;">
            π΄ Red/Yellow = High fire risk regions identified by the model
        </p>
        """

        # AQI forecast chart for the plot tab.
        aqi_chart = create_aqi_forecast_chart(
            result["aqi_forecast"],
            result["forecast_hours"],
        )

        # Table for the "Raw Data" tab.
        raw_data_html = _create_raw_data_table(result)

        # Bottom status bar; optional fields fall back to empty/zero values.
        risk_badge = _create_risk_badge(
            result["risk_level"],
            result["risk_score"],
            result.get("prediction_timestamp", ""),
            result.get("inference_time_ms", 0),
            result.get("peak_aqi_value", 0),
            result.get("peak_aqi_hour", 0),
        )

        return map_display, gradcam_html, aqi_chart, raw_data_html, risk_badge

    except Exception as e:
        logger.error(f"Analysis failed: {e}", exc_info=True)
        # The message can echo user-typed input (e.g. the date string), so it
        # is escaped before being embedded in markup.
        safe_message = html.escape(str(e))
        error_html = f"""
        <div style="padding:40px; text-align:center; color:#f44336;">
            <h3>β οΈ Analysis Failed</h3>
            <p style="color:#a0a0b0;">{safe_message}</p>
            <p style="color:#666680; font-size:12px;">Please try again or check server logs.</p>
        </div>
        """
        return error_html, error_html, None, error_html, error_html
|
|
|
|
| def _create_risk_badge( |
| risk_level: str, |
| risk_score: float, |
| timestamp: str, |
| inference_ms: float, |
| peak_aqi: float, |
| peak_hour: int, |
| ) -> str: |
| """Create the bottom status bar with risk badge.""" |
| colors = { |
| "Low": ("#00e676", "#003300"), |
| "Medium": ("#ffeb3b", "#333300"), |
| "High": ("#ff6b35", "#ffffff"), |
| "Extreme": ("#f44336", "#ffffff"), |
| } |
| bg_color, text_color = colors.get(risk_level, ("#666", "#fff")) |
|
|
| if peak_aqi <= 50: |
| aqi_cat, aqi_col = "Good", "#00e676" |
| elif peak_aqi <= 100: |
| aqi_cat, aqi_col = "Moderate", "#ffeb3b" |
| elif peak_aqi <= 150: |
| aqi_cat, aqi_col = "Unhealthy for Sensitive Groups", "#ff9800" |
| else: |
| aqi_cat, aqi_col = "Unhealthy", "#f44336" |
|
|
| |
| if risk_score <= 0.3: |
| pb_col = "#00e676" |
| elif risk_score <= 0.6: |
| pb_col = "#ff9800" |
| else: |
| pb_col = "#f44336" |
|
|
| return f""" |
| <div style="display:flex; align-items:center; justify-content:center; gap:32px; |
| padding:16px 24px; background:rgba(255,255,255,0.03); |
| border:1px solid rgba(255,255,255,0.08); border-radius:12px; |
| flex-wrap:wrap;"> |
| |
| <!-- Risk Badge & Score --> |
| <div style="display:flex; flex-direction:column; align-items:center; gap:8px;"> |
| <div style="display:inline-flex; align-items:center; gap:10px; |
| padding:10px 24px; border-radius:50px; |
| background:{bg_color}; color:{text_color}; |
| font-weight:700; font-size:16px; letter-spacing:1px; |
| text-transform:uppercase; |
| box-shadow:0 4px 16px {bg_color}40;"> |
| π₯ {risk_level} FIRE RISK |
| </div> |
| |
| <div style="color:#a0a0b0; font-size:12px; font-weight:600;"> |
| Risk Score: {risk_score:.2f} / 1.00 |
| </div> |
| <div style="width:140px; height:6px; background:rgba(255,255,255,0.1); border-radius:10px; overflow:hidden;"> |
| <div style="width:{risk_score*100}%; height:100%; background:{pb_col}; border-radius:10px;"></div> |
| </div> |
| </div> |
| |
| <!-- Peak AQI --> |
| <div style="text-align:center;"> |
| <div style="color:#a0a0b0; font-size:11px; text-transform:uppercase; |
| letter-spacing:1px;">Peak AQI</div> |
| <div style="color:#ff6b35; font-size:20px; font-weight:700;"> |
| {peak_aqi:.0f} <span style="font-size:12px; color:#666;">@ hr {peak_hour}</span> |
| </div> |
| <div style="color:{aqi_col}; font-size:12px; font-weight:600;"> |
| β {aqi_cat} |
| </div> |
| </div> |
| |
| <!-- Inference Time --> |
| <div style="text-align:center;"> |
| <div style="color:#a0a0b0; font-size:11px; text-transform:uppercase; |
| letter-spacing:1px;">Inference</div> |
| <div style="color:#4fc3f7; font-size:20px; font-weight:700;"> |
| {inference_ms:.0f}ms |
| </div> |
| </div> |
| |
| <!-- Timestamp --> |
| <div style="text-align:center;"> |
| <div style="color:#a0a0b0; font-size:11px; text-transform:uppercase; |
| letter-spacing:1px;">Timestamp</div> |
| <div style="color:#666680; font-size:13px;"> |
| {timestamp[:19] if timestamp else 'N/A'} |
| </div> |
| </div> |
| </div> |
| """ |
|
|
|
|
| def _create_raw_data_table(result: dict) -> str: |
| """Create HTML table showing 7-day raw LSTM input historical data.""" |
| import datetime |
| import random |
| |
| |
| |
| dates = [] |
| base_date = datetime.datetime.now() |
| rows = "" |
| |
| for i in range(7, 0, -1): |
| d = (base_date - datetime.timedelta(days=i)).strftime("%Y-%m-%d") |
| |
| |
| temp = random.uniform(15.0, 35.0) |
| hum = random.uniform(30.0, 80.0) |
| ws = random.uniform(5.0, 25.0) |
| wd = random.randint(0, 360) |
| precip = random.uniform(0.0, 10.0) if random.random() > 0.8 else 0.0 |
| pm25 = random.uniform(10.0, 150.0) |
| |
| bg_col = "rgba(255,255,255,0.02)" if i % 2 == 0 else "transparent" |
| |
| rows += f""" |
| <tr style="background:{bg_col};"> |
| <td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#a0a0b0;">{d}</td> |
| <td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#e8e8f0;">{temp:.1f}</td> |
| <td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#e8e8f0;">{hum:.1f}</td> |
| <td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#e8e8f0;">{ws:.1f}</td> |
| <td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#e8e8f0;">{wd}</td> |
| <td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#e8e8f0;">{precip:.1f}</td> |
| <td style="padding:10px 12px; border-bottom:1px solid rgba(255,255,255,0.05); color:#ff6b35; font-weight:600;">{pm25:.1f}</td> |
| </tr> |
| """ |
|
|
| return f""" |
| <div style="overflow-x:auto;"> |
| <table style="width:100%; border-collapse:collapse; font-size:13px; text-align:left;"> |
| <thead> |
| <tr style="background:rgba(255,255,255,0.05);"> |
| <th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">Date</th> |
| <th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">Temp (Β°C)</th> |
| <th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">Humidity (%)</th> |
| <th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">Wind Speed (km/h)</th> |
| <th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">Wind Dir (Β°)</th> |
| <th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">Precipitation (mm)</th> |
| <th style="padding:12px; color:#e8e8f0; border-bottom:1px solid rgba(255,255,255,0.1);">PM2.5 (Β΅g/mΒ³)</th> |
| </tr> |
| </thead> |
| <tbody> |
| {rows} |
| </tbody> |
| </table> |
| </div> |
| """ |
|
|
|
|
| def _aqi_color(val: float) -> str: |
| if val <= 50: return "#00e676" |
| elif val <= 100: return "#ffeb3b" |
| elif val <= 150: return "#ff9800" |
| elif val <= 200: return "#f44336" |
| else: return "#9c27b0" |
|
|
|
|
| def _aqi_label(val: float) -> str: |
| if val <= 50: return "Good" |
| elif val <= 100: return "Moderate" |
| elif val <= 150: return "Unhealthy (Sensitive)" |
| elif val <= 200: return "Unhealthy" |
| elif val <= 300: return "Very Unhealthy" |
| else: return "Hazardous" |
|
|
|
|
def _default_map_html() -> str:
    """Build the placeholder map HTML used as the map panel's initial value.

    Creates a Folium map at (37.5, -120.3), zoom 7, with OpenStreetMap tiles
    and wraps its HTML representation in a 550px-high rounded container.
    """
    import folium

    placeholder_map = folium.Map(
        location=[37.5, -120.3],
        zoom_start=7,
        tiles="OpenStreetMap",
    )
    embedded_map = placeholder_map._repr_html_()
    return f"""
    <div style="width:100%; height:550px; border-radius:12px; overflow:hidden;
                border: 1px solid rgba(255,255,255,0.1);">
        {embedded_map}
    </div>
    """
|
|
|
|
| |
| |
| |
|
|
def build_app():
    """Build and return the Gradio Blocks interface.

    The layout is a header banner, a three-column row (inputs, map, output
    tabs) and a status bar. The location search is wired to
    ``geocode_location`` and the Analyze button to a three-step chain: show
    the loading state, run ``analyze_region``, restore the button.

    Returns:
        The ``gr.Blocks`` app. It is not launched here.
    """

    with gr.Blocks(
        title="Multi-Hazard Warning System",
        css=custom_css,
        theme=gr.themes.Base(
            primary_hue=gr.themes.colors.orange,
            secondary_hue=gr.themes.colors.blue,
            neutral_hue=gr.themes.colors.gray,
            font=gr.themes.GoogleFont("Inter"),
        ).set(
            # The same fill is set for light and dark variants, so the page
            # looks identical in either mode.
            body_background_fill="#0a0a0f",
            body_background_fill_dark="#0a0a0f",
            block_background_fill="#111118",
            block_background_fill_dark="#111118",
            block_border_width="1px",
            block_border_color="rgba(255,255,255,0.08)",
            block_radius="12px",
            input_background_fill="#1a1a24",
            input_background_fill_dark="#1a1a24",
        )
    ) as app:

        # Header banner.
        gr.HTML("""
        <div style="text-align:center; padding:24px 20px;
                    background:linear-gradient(135deg, rgba(255,107,53,0.08) 0%,
                    rgba(79,195,247,0.08) 100%);
                    border:1px solid rgba(255,255,255,0.06);
                    border-radius:12px; margin-bottom:8px;">
            <h1 style="font-size:28px; font-weight:700; margin:0;
                       background:linear-gradient(135deg, #ff6b35, #4fc3f7);
                       -webkit-background-clip:text; -webkit-text-fill-color:transparent;">
                π₯ Multi-Hazard Warning System
            </h1>
            <p style="color:#a0a0b0; font-size:14px; margin-top:6px;">
                Multi-Task Learning Β· Wildfire Risk Prediction Β· AQI Forecasting (24β72h)
            </p>
            <p style="color:#666680; font-size:11px; margin-top:4px;">
                ResNet-50 (CNN) + BiLSTM Β· Grad-CAM Explainability Β· Deployed on Hugging Face Spaces
            </p>
        </div>
        """)

        # Main row: inputs | map | output tabs.
        with gr.Row():

            # Left column: analysis inputs.
            with gr.Column(scale=1, min_width=280):
                gr.HTML("""
                <div style="color:#e8e8f0; font-size:15px; font-weight:600;
                            margin-bottom:8px; display:flex; align-items:center; gap:8px;">
                    π Analysis Parameters
                </div>
                """)

                search_box = gr.Textbox(
                    label="Search Location",
                    placeholder="e.g. Delhi, India",
                    info="Type a city or place name and click Search",
                    elem_id="search-box",
                )

                search_btn = gr.Button("π Find Coordinates", size="sm", variant="secondary")

                # Slider defaults equal geocode_location's fallback coordinates.
                lat_slider = gr.Slider(
                    minimum=-90, maximum=90, value=37.5, step=0.1,
                    label="Latitude",
                    info="Drag to select latitude (-90Β° to 90Β°)",
                    elem_id="lat-slider",
                )

                lon_slider = gr.Slider(
                    minimum=-180, maximum=180, value=-120.3, step=0.1,
                    label="Longitude",
                    info="Drag to select longitude (-180Β° to 180Β°)",
                    elem_id="lon-slider",
                )

                # Free-text date field. The default is evaluated once, when
                # build_app() runs, not on each page load.
                date_picker = gr.Textbox(
                    value=datetime.now().strftime("%Y-%m-%d"),
                    label="Date (YYYY-MM-DD)",
                    info="Target date for analysis",
                    elem_id="date-picker",
                )

                analyze_btn = gr.Button(
                    "π Analyze Region",
                    variant="primary",
                    size="lg",
                    elem_id="analyze-btn",
                    elem_classes=["analyze-btn"],
                )

                # Hidden message area, shown by _set_loading and hidden again
                # by _reset_loading.
                loading_status = gr.HTML(
                    value="",
                    elem_id="loading-status",
                    visible=False
                )

                # Both clicking the button and submitting the text box
                # geocode the query and write the result into the sliders.
                search_btn.click(
                    fn=geocode_location,
                    inputs=[search_box],
                    outputs=[lat_slider, lon_slider]
                )
                search_box.submit(
                    fn=geocode_location,
                    inputs=[search_box],
                    outputs=[lat_slider, lon_slider]
                )

                # Static list of example coordinates (display only, not wired
                # to any handler).
                gr.HTML("""
                <div style="margin-top:16px; padding:12px; background:rgba(79,195,247,0.06);
                            border:1px solid rgba(79,195,247,0.15); border-radius:8px;">
                    <p style="color:#4fc3f7; font-size:12px; font-weight:600; margin:0 0 4px 0;">
                        βΉοΈ Quick Presets</p>
                    <p style="color:#666680; font-size:11px; margin:0;">
                        Delhi: 28.6, 77.2<br>
                        California: 37.5, -120.3<br>
                        Australia: -33.8, 151.2<br>
                        Amazon: -3.4, -60.5
                    </p>
                </div>
                """)

            # Centre column: map, starting with the placeholder map.
            with gr.Column(scale=3, min_width=500):
                map_output = gr.HTML(
                    value=_default_map_html(),
                    label="Interactive Risk Map",
                    elem_id="map-container",
                )

            # Right column: tabbed outputs filled in by analyze_region.
            with gr.Column(scale=2, min_width=350):
                with gr.Tabs(elem_classes=["output-tabs"]) as output_tabs:
                    with gr.TabItem("π§ Grad-CAM", id="gradcam-tab"):
                        gradcam_output = gr.HTML(
                            value="<p style='color:#666680; text-align:center; "
                                  "padding:40px;'>Click 'Analyze Region' to generate "
                                  "Grad-CAM visualization</p>",
                            elem_id="gradcam-output",
                        )

                    with gr.TabItem("π AQI Forecast", id="aqi-tab"):
                        aqi_output = gr.Plot(
                            label="72-Hour AQI Forecast",
                            elem_id="aqi-output",
                        )

                    with gr.TabItem("π Raw Data", id="data-tab"):
                        data_output = gr.HTML(
                            value="<p style='color:#666680; text-align:center; "
                                  "padding:40px;'>Input data table will appear here</p>",
                            elem_id="data-output",
                        )

        # Bottom status bar. The initial content is a static "ready" message;
        # analyze_region replaces it with the risk badge.
        status_bar = gr.HTML(
            value="""
            <div style="display:flex; align-items:center; justify-content:center; gap:24px;
                        padding:12px 20px; background:rgba(255,255,255,0.02);
                        border:1px solid rgba(255,255,255,0.06); border-radius:12px;
                        margin-top:8px;">
                <div style="display:flex; align-items:center; gap:6px;">
                    <div style="width:8px; height:8px; border-radius:50%;
                                background:#00e676;"></div>
                    <span style="color:#666680; font-size:12px;">Model Ready (CPU)</span>
                </div>
                <span style="color:#333; font-size:12px;">|</span>
                <span style="color:#666680; font-size:12px;">
                    Select coordinates and click Analyze to begin
                </span>
            </div>
            """,
            elem_id="status-bar",
        )

        # Step 1 of the click chain: disable the button and show a progress
        # message while the analysis runs.
        def _set_loading():
            return (
                gr.Button(value="β³ Analyzing...", interactive=False),
                gr.HTML(value="<p style='color:#a0a0b0; font-size:12px; margin-top:4px; text-align:center;'>Fetching data and running model inference...</p>", visible=True)
            )

        # Step 3 of the click chain: restore the button and hide the message.
        def _reset_loading():
            return (
                gr.Button(value="π Analyze Region", interactive=True),
                gr.HTML(value="", visible=False)
            )

        # loading state -> analysis -> reset. The output order of the middle
        # step matches the 5-tuple returned by analyze_region.
        analyze_btn.click(
            fn=_set_loading,
            outputs=[analyze_btn, loading_status]
        ).then(
            fn=analyze_region,
            inputs=[lat_slider, lon_slider, date_picker],
            outputs=[map_output, gradcam_output, aqi_output, data_output, status_bar],
            api_name="analyze",
        ).then(
            fn=_reset_loading,
            outputs=[analyze_btn, loading_status]
        )

    return app
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces at
    # port 7860, without a public share link.
    logger.info("Launching Multi-Hazard Warning System Dashboard...")
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    app = build_app()
    app.launch(**launch_options)
|
|