| """ |
| visualize.py — Visualization utilities for predictions. |
| |
| Generates: |
| - Grad-CAM overlays on satellite images |
| - AQI forecast Plotly charts with WHO threshold lines |
| - Folium maps with multiple toggleable layers: |
| Layer 1: Fire risk heatmap overlay |
| Layer 2: Grad-CAM overlay as ImageOverlay |
| Layer 3: Active fire point markers with confidence popups |
| Layer 4: AQI station markers color-coded by level |
| """ |
|
|
| import base64 |
| import io |
| import logging |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import cv2 |
| import numpy as np |
|
|
| from src.training.config import ( |
| OUTPUTS_DIR, PATCH_SIZE, AQI_FORECAST_HOURS, AQI_CATEGORIES, |
| ) |
|
|
# Module-wide logger, named after this module so its level/handlers can be
# configured independently of other modules.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
def create_aqi_forecast_chart(
    aqi_forecast: List[float],
    forecast_hours: Optional[List[int]] = None,
):
    """
    Create an AQI forecast Plotly line chart with dashed category threshold lines.

    Threshold lines (drawn at the upper bound of each category):
        - Good: 50 (green)
        - Moderate: 100 (yellow)
        - Unhealthy (Sensitive): 150 (orange)

    NOTE(review): the module docstring calls these "WHO threshold lines", but
    50/100/150 are AQI-index category boundaries rather than WHO PM2.5
    concentration guidelines -- confirm the intended wording.

    Args:
        aqi_forecast: Forecast AQI values, one per plotted point. May be empty.
        forecast_hours: x-axis value for each point; defaults to 1..N.

    Returns:
        A ``plotly.graph_objects.Figure`` (not HTML -- call ``fig.to_html()``
        to serialise it).
    """
    import plotly.graph_objects as go

    values = list(aqi_forecast)
    hours = (
        list(range(1, len(values) + 1))
        if forecast_hours is None
        else list(forecast_hours)
    )

    fig = go.Figure()

    # Main forecast trace, filled down to zero for emphasis.
    fig.add_trace(go.Scatter(
        x=hours,
        y=values,
        mode="lines+markers",
        name="AQI Forecast",
        line=dict(color="#FF6B35", width=3),
        marker=dict(size=4),
        fill="tozeroy",
        fillcolor="rgba(255, 107, 53, 0.1)",
    ))

    thresholds = [
        (50, "Good", "#00E676"),
        (100, "Moderate", "#FFEB3B"),
        (150, "Unhealthy (Sensitive)", "#FF9800"),
    ]
    for threshold, label, color in thresholds:
        fig.add_hline(
            y=threshold,
            line_dash="dash",
            line_color=color,
            line_width=2,
            annotation_text=f"{label} ({threshold})",
            annotation_position="right",
            annotation_font_size=10,
            annotation_font_color=color,
        )

    # Title reflects the real horizon instead of a hard-coded 72 hours.
    title_text = (
        f"{max(hours):g}-Hour AQI Forecast (PM2.5)"
        if hours
        else "AQI Forecast (PM2.5)"
    )

    # Span the x-axis over the hours actually plotted (the old range assumed
    # the default 1..N even when custom forecast_hours were passed).
    xaxis = dict(gridcolor="rgba(255,255,255,0.1)")
    if hours:
        xaxis["range"] = [min(hours), max(hours)]

    # Always show at least up to 200 so every threshold line is visible;
    # guard max() so an empty forecast does not raise.
    y_top = max(max(values) * 1.2, 200) if values else 200

    fig.update_layout(
        title=dict(
            text=title_text,
            font=dict(size=16, color="#E0E0E0"),
        ),
        xaxis_title="Forecast Hour",
        yaxis_title="AQI (PM2.5)",
        template="plotly_dark",
        plot_bgcolor="rgba(17, 17, 17, 0.8)",
        paper_bgcolor="rgba(17, 17, 17, 0.9)",
        xaxis=xaxis,
        yaxis=dict(
            gridcolor="rgba(255,255,255,0.1)",
            range=[0, y_top],
        ),
        height=400,
        margin=dict(l=50, r=50, t=50, b=50),
        showlegend=True,
        legend=dict(
            bgcolor="rgba(0,0,0,0.5)",
            font=dict(color="#E0E0E0"),
        ),
    )

    return fig
|
|
|
|
def create_folium_map(
    latitude: float,
    longitude: float,
    risk_map: np.ndarray,
    risk_score: float,
    aqi_forecast: List[float],
    gradcam_b64: Optional[str] = None,
    fire_points: Optional[List[Dict]] = None,
    *,
    simulate_missing: bool = True,
) -> str:
    """
    Create an interactive Folium map with multiple toggleable layers.

    Layers (added in this order, each a FeatureGroup listed in the LayerControl):
        1. Fire risk heatmap overlay (HeatMap plugin)
        2. Grad-CAM overlay (ImageOverlay) -- only when ``gradcam_b64`` is given
        3. Active fire point markers with confidence popups
        4. AQI station markers color-coded by level

    The risk map and the Grad-CAM image are both stretched over a
    0.5 x 0.5 degree box centred on (latitude, longitude).

    Args:
        latitude, longitude: Map centre in decimal degrees.
        risk_map: 2-D array of risk values (leading singleton axes such as
            (1, H, W) are dropped). Pixels above 0.2 on a 4-pixel grid become
            heatmap points.
        risk_score: Overall risk; shown on the centre marker and used to size
            the placeholder fire markers.
        aqi_forecast: AQI values; element 0 is shown as the current AQI and
            the maximum as the peak. May be empty (current AQI defaults to 50).
        gradcam_b64: Optional base64-encoded image (any format OpenCV can
            decode) to drape as the Grad-CAM layer.
        fire_points: Optional list of dicts with "latitude", "longitude" and
            an optional "confidence" key.
        simulate_missing: When True (the default, matching earlier behaviour),
            randomly generated placeholder fire markers are drawn if
            ``fire_points`` is empty/None, and three placeholder AQI stations
            are drawn around the centre. These are NOT observed data; pass
            False to omit them.

    Returns:
        HTML string from ``folium.Map._repr_html_()``.
    """
    import folium

    m = folium.Map(
        location=[latitude, longitude],
        zoom_start=9,
        tiles="OpenStreetMap",
        control_scale=True,
    )
    folium.TileLayer(
        tiles="https://server.arcgisonline.com/ArcGIS/rest/services/World_Topo_Map/MapServer/tile/{z}/{y}/{x}",
        attr="Esri",
        name="Topographic",
    ).add_to(m)

    # Geographic footprint (degrees) that the raster layers are stretched over.
    lat_range = 0.5
    lon_range = 0.5

    _add_risk_heatmap_layer(m, latitude, longitude, risk_map, lat_range, lon_range)
    if gradcam_b64:
        _add_gradcam_layer(m, latitude, longitude, gradcam_b64, lat_range, lon_range)
    _add_fire_layer(m, latitude, longitude, risk_score, fire_points, simulate_missing)
    _add_aqi_layer(m, latitude, longitude, aqi_forecast, simulate_missing)

    # Centre marker summarising the analysis point.
    folium.Marker(
        location=[latitude, longitude],
        icon=folium.Icon(color="blue", icon="info-sign"),
        popup=f"<b>Analysis Center</b><br>"
        f"Risk: {risk_score:.2f}<br>"
        f"Lat: {latitude}, Lon: {longitude}",
    ).add_to(m)

    folium.LayerControl(collapsed=False).add_to(m)

    # NOTE(review): _repr_html_ is folium's notebook (iframe-embedded)
    # representation; if callers need a standalone page,
    # m.get_root().render() may fit better -- confirm with callers.
    return m._repr_html_()


def _add_risk_heatmap_layer(
    m,
    latitude: float,
    longitude: float,
    risk_map: np.ndarray,
    lat_range: float,
    lon_range: float,
    threshold: float = 0.2,
    stride: int = 4,
) -> None:
    """Add the fire-risk HeatMap layer built from a strided sample of ``risk_map``."""
    import folium
    from folium.plugins import HeatMap

    risk_layer = folium.FeatureGroup(name="🔥 Fire Risk Heatmap")

    grid = np.asarray(risk_map)
    # Tolerate leading singleton axes, e.g. a (1, H, W) model output.
    while grid.ndim > 2 and grid.shape[0] == 1:
        grid = grid[0]
    if grid.ndim != 2:
        raise ValueError(f"risk_map must be 2-D, got shape {np.shape(risk_map)}")
    h, w = grid.shape

    # Every `stride`-th pixel above `threshold`, in row-major order
    # (same order as a nested row/column loop).
    rows, cols = np.nonzero(grid[::stride, ::stride] > threshold)
    rows = rows * stride
    cols = cols * stride
    # Row 0 maps to the northern edge, column 0 to the western edge.
    lat_points = latitude + lat_range * (0.5 - rows / h)
    lon_points = longitude + lon_range * (cols / w - 0.5)
    heat_data = [
        [float(lat), float(lon), float(val)]
        for lat, lon, val in zip(lat_points, lon_points, grid[rows, cols])
    ]

    if heat_data:
        HeatMap(
            heat_data,
            radius=15,
            blur=10,
            max_zoom=13,
            gradient={0.2: "blue", 0.4: "lime", 0.6: "yellow", 0.8: "orange", 1.0: "red"},
        ).add_to(risk_layer)

    risk_layer.add_to(m)


def _add_gradcam_layer(
    m,
    latitude: float,
    longitude: float,
    gradcam_b64: str,
    lat_range: float,
    lon_range: float,
) -> None:
    """Decode ``gradcam_b64`` and drape it over the analysis box as an ImageOverlay.

    Best effort: failures are logged and an empty layer is still added.
    """
    import folium

    gradcam_layer = folium.FeatureGroup(name="🧠 Grad-CAM Overlay")
    try:
        gradcam_bytes = base64.b64decode(gradcam_b64)
        gradcam_arr = np.frombuffer(gradcam_bytes, dtype=np.uint8)
        gradcam_img = cv2.imdecode(gradcam_arr, cv2.IMREAD_COLOR)

        if gradcam_img is None:
            logger.warning("Grad-CAM overlay could not be decoded as an image; skipping")
        else:
            # Re-encode so the payload always matches the image/png MIME type
            # declared in the data URI, whatever the input format was.
            ok, buffer = cv2.imencode(".png", gradcam_img)
            if not ok:
                raise ValueError("PNG re-encoding failed")
            img_b64 = base64.b64encode(buffer).decode("utf-8")
            data_uri = f"data:image/png;base64,{img_b64}"

            bounds = [
                [latitude - lat_range / 2, longitude - lon_range / 2],
                [latitude + lat_range / 2, longitude + lon_range / 2],
            ]
            folium.raster_layers.ImageOverlay(
                image=data_uri,
                bounds=bounds,
                opacity=0.5,
                name="Grad-CAM",
            ).add_to(gradcam_layer)
    except Exception as e:  # overlay is optional; never fail the whole map
        logger.warning("Failed to add Grad-CAM overlay: %s", e)

    gradcam_layer.add_to(m)


def _add_fire_layer(
    m,
    latitude: float,
    longitude: float,
    risk_score: float,
    fire_points: Optional[List[Dict]],
    simulate_missing: bool,
) -> None:
    """Add fire-detection markers: real ``fire_points`` or seeded placeholders."""
    import html as _html

    import folium

    fire_layer = folium.FeatureGroup(name="📍 Active Fire Detections")

    if fire_points:
        for pt in fire_points:
            confidence = pt.get("confidence")
            # Escape caller-supplied text before embedding it in popup HTML,
            # and avoid rendering a missing value as "N/A%".
            conf_text = "N/A" if confidence is None else f"{_html.escape(str(confidence))}%"
            folium.CircleMarker(
                location=[pt["latitude"], pt["longitude"]],
                radius=6,
                color="red",
                fill=True,
                fillColor="red",
                fillOpacity=0.8,
                popup=folium.Popup(
                    f"<b>Fire Detection</b><br>"
                    f"Confidence: {conf_text}<br>"
                    f"Lat: {pt['latitude']:.4f}<br>"
                    f"Lon: {pt['longitude']:.4f}",
                    max_width=200,
                ),
                tooltip=f"Fire ({conf_text})",
            ).add_to(fire_layer)
    elif simulate_missing:
        # Placeholder markers. A local seeded RandomState yields the same
        # sequence as np.random.seed(42) did, without resetting the
        # process-global RNG for every other caller.
        rng = np.random.RandomState(42)
        num_points = int(risk_score * 10) + 2
        for _ in range(num_points):
            pt_lat = latitude + rng.uniform(-0.2, 0.2)
            pt_lon = longitude + rng.uniform(-0.2, 0.2)
            conf = rng.randint(60, 100)

            folium.CircleMarker(
                location=[pt_lat, pt_lon],
                radius=6,
                color="red",
                fill=True,
                fillColor="#FF4444",
                fillOpacity=0.8,
                popup=folium.Popup(
                    f"<b>🔥 Fire Detection</b><br>"
                    f"Confidence: {conf}%<br>"
                    f"Lat: {pt_lat:.4f}<br>"
                    f"Lon: {pt_lon:.4f}",
                    max_width=200,
                ),
                tooltip=f"Fire ({conf}%)",
            ).add_to(fire_layer)

    fire_layer.add_to(m)


def _add_aqi_layer(
    m,
    latitude: float,
    longitude: float,
    aqi_forecast: List[float],
    simulate_missing: bool,
) -> None:
    """Add the AQI layer: a labelled centre marker plus optional placeholder stations."""
    import folium

    aqi_layer = folium.FeatureGroup(name="🌬️ AQI Stations")

    values = [] if aqi_forecast is None else list(aqi_forecast)
    current_aqi = values[0] if values else 50
    # Guard max() as well: an empty forecast used to crash here despite the
    # default applied just above.
    peak_aqi = max(values) if values else current_aqi
    aqi_color = _get_aqi_color(current_aqi)
    aqi_label = _get_aqi_label(current_aqi)

    # Circular badge showing the current AQI, coloured by level.
    folium.Marker(
        location=[latitude, longitude],
        icon=folium.DivIcon(
            html=f"""
            <div style="
                background: {aqi_color};
                color: white;
                border-radius: 50%;
                width: 36px;
                height: 36px;
                display: flex;
                align-items: center;
                justify-content: center;
                font-weight: bold;
                font-size: 11px;
                border: 2px solid white;
                box-shadow: 0 2px 6px rgba(0,0,0,0.5);
            ">{int(current_aqi)}</div>
            """,
            icon_size=(36, 36),
            icon_anchor=(18, 18),
        ),
        popup=folium.Popup(
            f"<b>🌬️ AQI Station</b><br>"
            f"Current AQI: {current_aqi:.0f}<br>"
            f"Level: {aqi_label}<br>"
            # NOTE(review): "72h" is a fixed label, not derived from len(aqi_forecast).
            f"Peak (72h): {peak_aqi:.0f}",
            max_width=200,
        ),
        tooltip=f"AQI: {current_aqi:.0f} ({aqi_label})",
    ).add_to(aqi_layer)

    if simulate_missing:
        # Placeholder stations; local RNG reproduces np.random.seed(123)'s
        # sequence without touching global state.
        rng = np.random.RandomState(123)
        for _ in range(3):
            st_lat = latitude + rng.uniform(-0.3, 0.3)
            st_lon = longitude + rng.uniform(-0.3, 0.3)
            st_aqi = current_aqi + rng.uniform(-20, 30)
            st_color = _get_aqi_color(st_aqi)

            folium.CircleMarker(
                location=[st_lat, st_lon],
                radius=10,
                color=st_color,
                fill=True,
                fillColor=st_color,
                fillOpacity=0.7,
                popup=f"AQI: {st_aqi:.0f}",
                tooltip=f"AQI: {st_aqi:.0f}",
            ).add_to(aqi_layer)

    aqi_layer.add_to(m)
|
|
|
|
def _get_aqi_color(aqi: float) -> str:
    """Return the hex display colour for an AQI reading.

    Bands are upper-inclusive: <=50 green, <=100 yellow, <=150 orange,
    <=200 red, <=300 purple. Anything else, including values that compare
    false against every bound (e.g. NaN), gets dark purple.
    """
    bands = (
        (50, "#00E676"),
        (100, "#FFEB3B"),
        (150, "#FF9800"),
        (200, "#F44336"),
        (300, "#9C27B0"),
    )
    for upper, colour in bands:
        if aqi <= upper:
            return colour
    return "#7B1FA2"
|
|
|
|
def _get_aqi_label(
    aqi: float,
    categories: Optional[Dict[str, Tuple[float, float]]] = None,
) -> str:
    """
    Map an AQI value to its category label.

    Categories are checked in ascending order of their upper bound, and the
    first one whose upper bound is >= ``aqi`` wins. A fractional value that
    falls in the gap between integer-bounded ranges (e.g. 50.5 between
    (0, 50) and (51, 100)) is therefore assigned to the next category up,
    instead of dropping through to "Hazardous". This mirrors the
    upper-inclusive comparisons in ``_get_aqi_color``. Values below every
    range get the lowest category.

    Args:
        aqi: AQI value to classify.
        categories: Mapping of label -> (low, high). Defaults to
            ``AQI_CATEGORIES`` from the training config.

    Returns:
        The matching label, or "Hazardous" when ``aqi`` exceeds every upper
        bound (or is NaN).
    """
    if categories is None:
        categories = AQI_CATEGORIES
    for label, (_low, high) in sorted(categories.items(), key=lambda item: item[1][1]):
        if aqi <= high:
            return label
    return "Hazardous"
|
|
|
|
def save_folium_map(
    html: str,
    filename: str = "prediction_map.html",
    output_dir: Optional[Path] = None,
) -> str:
    """
    Save Folium map HTML to disk.

    Args:
        html: HTML text to write (UTF-8).
        filename: File name inside the output directory.
        output_dir: Target directory; defaults to ``OUTPUTS_DIR``. Created
            (including parents) if it does not exist yet -- previously a
            missing directory made the write fail.

    Returns:
        The written file's path as a string.
    """
    directory = Path(OUTPUTS_DIR if output_dir is None else output_dir)
    directory.mkdir(parents=True, exist_ok=True)
    path = directory / filename
    path.write_text(html, encoding="utf-8")
    logger.info("Folium map saved: %s", path)
    return str(path)
|
|
|
|
def create_satellite_comparison(
    original_image: np.ndarray,
    gradcam_overlay: np.ndarray,
) -> str:
    """
    Create a side-by-side image: satellite (left) | grey separator | Grad-CAM (right).

    Args:
        original_image: Satellite image. Accepted as 2-D single-band (H, W),
            channel-first (C, H, W) when the leading axis is <= 4, or
            channel-last (H, W, C). The first three channels are treated as
            RGB (a single band is replicated to grey RGB) and min-max scaled
            to 0-255.
        gradcam_overlay: Overlay image stacked unchanged next to the
            RGB->BGR converted satellite image, so presumably already BGR
            uint8 with 3 channels -- confirm with callers. Resized to the
            satellite image's height/width if they differ.

    Returns:
        Base64-encoded PNG of the comparison image.

    Raises:
        RuntimeError: if OpenCV reports that PNG encoding failed.
    """
    img = np.asarray(original_image)
    if img.ndim == 2:
        # Single-band image: give it an explicit channel axis.
        img_rgb = img[:, :, np.newaxis]
    elif img.shape[0] <= 4:
        # Heuristic: a leading axis of at most 4 is taken to be the channel axis.
        img_rgb = np.transpose(img[:3], (1, 2, 0))
    else:
        img_rgb = img[:, :, :3]

    img_min, img_max = img_rgb.min(), img_rgb.max()
    if img_max - img_min > 0:
        img_rgb = (img_rgb - img_min) / (img_max - img_min)
    else:
        # Constant image: nothing to stretch. Clamp to [0, 1] so the x255
        # scaling below cannot overflow uint8 (it used to wrap around for
        # flat images with values > 1).
        img_rgb = np.clip(img_rgb, 0, 1)
    img_uint8 = np.uint8(255 * img_rgb)

    # Single-band (or 1-channel) input: replicate to 3 channels for RGB->BGR.
    if img_uint8.shape[2] == 1:
        img_uint8 = np.repeat(img_uint8, 3, axis=2)

    h, w = img_uint8.shape[:2]
    if gradcam_overlay.shape[:2] != (h, w):
        gradcam_overlay = cv2.resize(gradcam_overlay, (w, h))

    # 4-pixel mid-grey divider between the two panels.
    separator = np.ones((h, 4, 3), dtype=np.uint8) * 128
    composite = np.hstack([
        cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR),
        separator,
        gradcam_overlay,
    ])

    # Panel captions; the right one is offset by the left width + separator.
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(composite, "Satellite", (10, 20), font, 0.5, (255, 255, 255), 1)
    cv2.putText(composite, "Grad-CAM", (w + 14, 20), font, 0.5, (255, 255, 255), 1)

    ok, buffer = cv2.imencode(".png", composite)
    if not ok:
        raise RuntimeError("cv2.imencode failed to encode the comparison image as PNG")
    return base64.b64encode(buffer).decode("utf-8")
|
|
|
|
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Seeded generator so the demo output is reproducible between runs.
    demo_rng = np.random.default_rng(0)

    # Demo AQI forecast: 72 hourly values in [30, 130).
    forecast = list(demo_rng.random(72) * 100 + 30)
    # create_aqi_forecast_chart returns a plotly Figure, not an HTML string;
    # serialise it before measuring its length.
    chart_html = create_aqi_forecast_chart(forecast).to_html(include_plotlyjs="cdn")
    print(f"AQI chart HTML length: {len(chart_html)}")

    # Demo risk map: uniform noise on a 128x128 grid.
    demo_risk_map = demo_rng.random((128, 128), dtype=np.float32)
    map_html = create_folium_map(
        latitude=37.5,
        longitude=-120.3,
        risk_map=demo_risk_map,
        risk_score=0.65,
        aqi_forecast=forecast,
    )
    print(f"Folium map HTML length: {len(map_html)}")
    save_folium_map(map_html)
|
|