"""visualize.py — Visualization utilities for model predictions.

Generates:
    - Grad-CAM overlays on satellite images (side-by-side comparison PNGs)
    - AQI forecast Plotly charts with AQI category threshold lines
    - Folium maps with multiple toggleable layers:
        Layer 1: Fire risk heatmap overlay
        Layer 2: Grad-CAM overlay as ImageOverlay
        Layer 3: Active fire point markers with confidence popups
        Layer 4: AQI station markers color-coded by level
"""

import base64
import io
import logging
from html import escape
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np

from src.training.config import (
    OUTPUTS_DIR,
    PATCH_SIZE,
    AQI_FORECAST_HOURS,
    AQI_CATEGORIES,
)

logger = logging.getLogger(__name__)

# Geographic extent (degrees) over which the risk map and the Grad-CAM image
# are stretched, centred on the analysis location.
_MAP_LAT_RANGE = 0.5
_MAP_LON_RANGE = 0.5

# Upper AQI bound -> display color (US EPA AQI category colors).
_AQI_COLOR_BREAKPOINTS = (
    (50, "#00E676"),   # Good - green
    (100, "#FFEB3B"),  # Moderate - yellow
    (150, "#FF9800"),  # Unhealthy for Sensitive Groups - orange
    (200, "#F44336"),  # Unhealthy - red
    (300, "#9C27B0"),  # Very Unhealthy - purple
)
_AQI_HAZARDOUS_COLOR = "#7E0023"  # Hazardous - maroon (dark red)


def create_aqi_forecast_chart(
    aqi_forecast: List[float],
    forecast_hours: Optional[List[int]] = None,
):
    """
    Create an AQI forecast Plotly line chart with AQI category threshold lines.

    Dashed horizontal lines mark the upper bound of these AQI categories
    (US EPA AQI breakpoints):
        - Good: 0–50 (green)
        - Moderate: 51–100 (yellow)
        - Unhealthy (Sensitive): 101–150 (orange)

    Args:
        aqi_forecast: Forecast AQI values, one per forecast step. May be empty.
        forecast_hours: X-axis value for each forecast step. Defaults to
            ``1 .. len(aqi_forecast)``.

    Returns:
        A ``plotly.graph_objects.Figure``. Call ``.to_html()`` on it when an
        HTML string is needed.
    """
    import plotly.graph_objects as go

    if forecast_hours is None:
        forecast_hours = list(range(1, len(aqi_forecast) + 1))

    fig = go.Figure()

    # Main forecast line
    fig.add_trace(go.Scatter(
        x=forecast_hours,
        y=aqi_forecast,
        mode="lines+markers",
        name="AQI Forecast",
        line=dict(color="#FF6B35", width=3),
        marker=dict(size=4),
        fill="tozeroy",
        fillcolor="rgba(255, 107, 53, 0.1)",
    ))

    # AQI category upper bounds (US EPA breakpoints)
    thresholds = [
        (50, "Good", "#00E676"),
        (100, "Moderate", "#FFEB3B"),
        (150, "Unhealthy (Sensitive)", "#FF9800"),
    ]
    for threshold, label, color in thresholds:
        fig.add_hline(
            y=threshold,
            line_dash="dash",
            line_color=color,
            line_width=2,
            annotation_text=f"{label} ({threshold})",
            annotation_position="right",
            annotation_font_size=10,
            annotation_font_color=color,
        )

    # `default` keeps an empty forecast from raising in max().
    peak_aqi = max(aqi_forecast, default=0)

    xaxis = dict(gridcolor="rgba(255,255,255,0.1)")
    if len(forecast_hours) > 0:
        # Follow the actual x values so custom forecast_hours are not clipped.
        xaxis["range"] = [min(forecast_hours), max(forecast_hours)]

    # Styling
    fig.update_layout(
        title=dict(
            text=f"{len(aqi_forecast)}-Hour AQI Forecast (PM2.5)",
            font=dict(size=16, color="#E0E0E0"),
        ),
        xaxis_title="Forecast Hour",
        yaxis_title="AQI (PM2.5)",
        template="plotly_dark",
        plot_bgcolor="rgba(17, 17, 17, 0.8)",
        paper_bgcolor="rgba(17, 17, 17, 0.9)",
        xaxis=xaxis,
        yaxis=dict(
            gridcolor="rgba(255,255,255,0.1)",
            # Always show at least up to 200 so the threshold lines are visible.
            range=[0, max(peak_aqi * 1.2, 200)],
        ),
        height=400,
        margin=dict(l=50, r=50, t=50, b=50),
        showlegend=True,
        legend=dict(
            bgcolor="rgba(0,0,0,0.5)",
            font=dict(color="#E0E0E0"),
        ),
    )

    return fig


def create_folium_map(
    latitude: float,
    longitude: float,
    risk_map: np.ndarray,
    risk_score: float,
    aqi_forecast: List[float],
    gradcam_b64: Optional[str] = None,
    fire_points: Optional[List[Dict]] = None,
) -> str:
    """
    Create an interactive Folium map with multiple toggleable layers.

    Layers:
        1. Fire risk heatmap overlay (HeatMap plugin)
        2. Grad-CAM overlay (ImageOverlay), only when ``gradcam_b64`` is given
        3. Active fire point markers with confidence popups (synthetic demo
           points when ``fire_points`` is falsy)
        4. AQI station markers color-coded by level

    Args:
        latitude, longitude: Map centre / analysis location.
        risk_map: 2-D array of per-pixel risk values, spread over a
            ``_MAP_LAT_RANGE`` x ``_MAP_LON_RANGE`` degree box around the centre.
        risk_score: Overall risk score shown in the centre popup; it also
            scales the number of synthetic fire points.
        aqi_forecast: Forecast AQI values; the first one is the "current" AQI.
        gradcam_b64: Optional base64-encoded image for the Grad-CAM overlay.
        fire_points: Optional dicts with ``latitude``, ``longitude`` and an
            optional ``confidence`` key.

    Returns:
        HTML string for the Folium map, as produced by ``Map._repr_html_()``.
    """
    import folium

    # Create base map centered on location
    m = folium.Map(
        location=[latitude, longitude],
        zoom_start=9,
        tiles="OpenStreetMap",
        control_scale=True,
    )

    # Esri World Topo Map as an alternative base layer
    folium.TileLayer(
        tiles="https://server.arcgisonline.com/ArcGIS/rest/services/World_Topo_Map/MapServer/tile/{z}/{y}/{x}",
        attr="Esri",
        name="Topographic",
    ).add_to(m)

    _add_risk_heatmap_layer(m, latitude, longitude, risk_map)
    if gradcam_b64:
        _add_gradcam_layer(m, latitude, longitude, gradcam_b64)
    _add_fire_layer(m, latitude, longitude, risk_score, fire_points)
    _add_aqi_layer(m, latitude, longitude, aqi_forecast)

    # --- Center marker ---
    folium.Marker(
        location=[latitude, longitude],
        icon=folium.Icon(color="blue", icon="info-sign"),
        popup=(
            f"Analysis Center<br>"
            f"Risk: {risk_score:.2f}<br>"
            f"Lat: {latitude}, Lon: {longitude}"
        ),
    ).add_to(m)

    # --- Layer control ---
    folium.LayerControl(collapsed=False).add_to(m)

    return m._repr_html_()


def _add_risk_heatmap_layer(m, latitude, longitude, risk_map, step=4, min_risk=0.2):
    """Add Layer 1: a HeatMap of sampled risk-map pixels whose value exceeds ``min_risk``."""
    import folium
    from folium.plugins import HeatMap

    risk_layer = folium.FeatureGroup(name="🔥 Fire Risk Heatmap")

    h, w = risk_map.shape
    # Sample every `step`-th pixel in each direction to keep the point count manageable.
    sampled = risk_map[::step, ::step]
    rows, cols = np.nonzero(sampled > min_risk)  # row-major order, like a nested loop
    pixel_rows = rows * step
    pixel_cols = cols * step
    # Row 0 maps to the northern edge, so latitude decreases as the row index grows.
    lats = latitude + _MAP_LAT_RANGE * (0.5 - pixel_rows / h)
    lons = longitude + _MAP_LON_RANGE * (pixel_cols / w - 0.5)
    weights = sampled[rows, cols]
    heat_data = [
        [float(lat), float(lon), float(weight)]
        for lat, lon, weight in zip(lats, lons, weights)
    ]

    if heat_data:
        HeatMap(
            heat_data,
            radius=15,
            blur=10,
            max_zoom=13,
            gradient={0.2: "blue", 0.4: "lime", 0.6: "yellow", 0.8: "orange", 1.0: "red"},
        ).add_to(risk_layer)

    risk_layer.add_to(m)


def _add_gradcam_layer(m, latitude, longitude, gradcam_b64):
    """Add Layer 2: the Grad-CAM image stretched over the same extent as the heatmap."""
    import folium

    gradcam_layer = folium.FeatureGroup(name="🧠 Grad-CAM Overlay")
    try:
        gradcam_arr = np.frombuffer(base64.b64decode(gradcam_b64), dtype=np.uint8)
        gradcam_img = cv2.imdecode(gradcam_arr, cv2.IMREAD_COLOR)
        if gradcam_img is None:
            logger.warning("Grad-CAM payload could not be decoded as an image; overlay skipped")
        else:
            # Re-encode as PNG so the data URI is valid whatever the input format was.
            _, buffer = cv2.imencode(".png", gradcam_img)
            img_b64 = base64.b64encode(buffer).decode("utf-8")
            bounds = [
                [latitude - _MAP_LAT_RANGE / 2, longitude - _MAP_LON_RANGE / 2],
                [latitude + _MAP_LAT_RANGE / 2, longitude + _MAP_LON_RANGE / 2],
            ]
            folium.raster_layers.ImageOverlay(
                image=f"data:image/png;base64,{img_b64}",
                bounds=bounds,
                opacity=0.5,
                name="Grad-CAM",
            ).add_to(gradcam_layer)
    except Exception as e:  # best-effort: a bad overlay must not break the whole map
        logger.warning("Failed to add Grad-CAM overlay: %s", e)
    gradcam_layer.add_to(m)


def _add_fire_layer(m, latitude, longitude, risk_score, fire_points):
    """Add Layer 3: fire detection markers (real points, or deterministic demo points)."""
    import folium

    fire_layer = folium.FeatureGroup(name="📍 Active Fire Detections")

    if fire_points:
        for pt in fire_points:
            # Confidence comes from external detection data; escape it before
            # embedding it in popup/tooltip HTML.
            confidence = escape(str(pt.get("confidence", "N/A")))
            folium.CircleMarker(
                location=[pt["latitude"], pt["longitude"]],
                radius=6,
                color="red",
                fill=True,
                fillColor="red",
                fillOpacity=0.8,
                popup=folium.Popup(
                    f"Fire Detection<br>"
                    f"Confidence: {confidence}%<br>"
                    f"Lat: {pt['latitude']:.4f}<br>"
                    f"Lon: {pt['longitude']:.4f}",
                    max_width=200,
                ),
                tooltip=f"Fire ({confidence}%)",
            ).add_to(fire_layer)
    else:
        # Synthetic fire points for demo.
        # NOTE(review): an empty list also lands here, so "no detections" is
        # rendered as fabricated demo fires — confirm this is intended.
        # A local RandomState(42) yields the same values as the former
        # np.random.seed(42) without resetting the global RNG for callers.
        rng = np.random.RandomState(42)
        num_points = int(risk_score * 10) + 2
        for _ in range(num_points):
            pt_lat = latitude + rng.uniform(-0.2, 0.2)
            pt_lon = longitude + rng.uniform(-0.2, 0.2)
            conf = rng.randint(60, 100)
            folium.CircleMarker(
                location=[pt_lat, pt_lon],
                radius=6,
                color="red",
                fill=True,
                fillColor="#FF4444",
                fillOpacity=0.8,
                popup=folium.Popup(
                    f"🔥 Fire Detection<br>"
                    f"Confidence: {conf}%<br>"
                    f"Lat: {pt_lat:.4f}<br>"
                    f"Lon: {pt_lon:.4f}",
                    max_width=200,
                ),
                tooltip=f"Fire ({conf}%)",
            ).add_to(fire_layer)

    fire_layer.add_to(m)


def _add_aqi_layer(m, latitude, longitude, aqi_forecast):
    """Add Layer 4: the main AQI station badge plus three synthetic nearby stations."""
    import folium

    aqi_layer = folium.FeatureGroup(name="🌬️ AQI Stations")

    # len() rather than truthiness so NumPy arrays are accepted as well as lists.
    current_aqi = aqi_forecast[0] if len(aqi_forecast) > 0 else 50
    peak_aqi = max(aqi_forecast, default=current_aqi)
    aqi_color = _get_aqi_color(current_aqi)
    aqi_label = _get_aqi_label(current_aqi)

    # Main station marker: a round badge filled with the AQI category color.
    badge_html = (
        f'<div style="background-color: {aqi_color}; color: #000; '
        'border: 2px solid #fff; border-radius: 50%; box-sizing: border-box; '
        'width: 36px; height: 36px; line-height: 32px; text-align: center; '
        'font-weight: bold; font-size: 12px;">'
        f"{int(current_aqi)}</div>"
    )
    folium.Marker(
        location=[latitude, longitude],
        icon=folium.DivIcon(
            html=badge_html,
            icon_size=(36, 36),
            icon_anchor=(18, 18),
        ),
        popup=folium.Popup(
            f"🌬️ AQI Station<br>"
            f"Current AQI: {current_aqi:.0f}<br>"
            f"Level: {aqi_label}<br>"
            f"Peak ({len(aqi_forecast)}h): {peak_aqi:.0f}",
            max_width=200,
        ),
        tooltip=f"AQI: {current_aqi:.0f} ({aqi_label})",
    ).add_to(aqi_layer)

    # Nearby synthetic AQI stations (local RNG: same values as the former
    # np.random.seed(123), without touching global RNG state).
    rng = np.random.RandomState(123)
    for _ in range(3):
        st_lat = latitude + rng.uniform(-0.3, 0.3)
        st_lon = longitude + rng.uniform(-0.3, 0.3)
        st_aqi = current_aqi + rng.uniform(-20, 30)
        st_color = _get_aqi_color(st_aqi)
        folium.CircleMarker(
            location=[st_lat, st_lon],
            radius=10,
            color=st_color,
            fill=True,
            fillColor=st_color,
            fillOpacity=0.7,
            popup=f"AQI: {st_aqi:.0f}",
            tooltip=f"AQI: {st_aqi:.0f}",
        ).add_to(aqi_layer)

    aqi_layer.add_to(m)


def _get_aqi_color(aqi: float) -> str:
    """Map an AQI value to its display color (first breakpoint with ``aqi <= bound``)."""
    for upper_bound, color in _AQI_COLOR_BREAKPOINTS:
        if aqi <= upper_bound:
            return color
    return _AQI_HAZARDOUS_COLOR


def _get_aqi_label(aqi: float) -> str:
    """
    Map an AQI value to its category label from ``AQI_CATEGORIES``.

    ``AQI_CATEGORIES`` maps label -> (low, high). Categories are checked in
    ascending order of ``low``; the first one whose ``high`` is >= ``aqi``
    wins. This treats bounds as upper thresholds, as ``_get_aqi_color`` does.
    If the configured ranges are integer-inclusive (e.g. (0, 50), (51, 100)),
    fractional values in the gaps are therefore still classified instead of
    falling through. Values above every upper bound return "Hazardous".
    """
    ordered = sorted(AQI_CATEGORIES.items(), key=lambda item: item[1][0])
    for label, (_low, high) in ordered:
        if aqi <= high:
            return label
    return "Hazardous"


def save_folium_map(html: str, filename: str = "prediction_map.html"):
    """Write Folium map HTML to ``OUTPUTS_DIR/filename`` (creating the directory) and return the path."""
    path = Path(OUTPUTS_DIR) / filename
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(html, encoding="utf-8")
    logger.info("Folium map saved: %s", path)
    return str(path)


def create_satellite_comparison(
    original_image: np.ndarray,
    gradcam_overlay: np.ndarray,
) -> str:
    """
    Create a side-by-side image: satellite RGB | grey separator | Grad-CAM overlay.

    Args:
        original_image: Satellite image, either channels-first (leading axis
            <= 4 is taken as channels) or channels-last. The first three bands
            are shown as RGB; a single band is replicated to grey.
        gradcam_overlay: Overlay image; presumably a uint8 BGR array as
            produced by OpenCV (TODO confirm). It is resized to the satellite
            image size if needed.

    Returns:
        Base64-encoded PNG of the comparison image.
    """
    # A leading axis of <= 4 is treated as channels (CHW); otherwise HWC.
    if original_image.shape[0] <= 4:
        img_rgb = np.transpose(original_image[:3], (1, 2, 0))  # CHW -> HWC
    else:
        img_rgb = original_image[:, :, :3]
    # Single-band input: replicate so the RGB->BGR conversion below works.
    if img_rgb.shape[2] == 1:
        img_rgb = np.repeat(img_rgb, 3, axis=2)

    # Min-max normalize to [0, 1] for display.
    img_min, img_max = img_rgb.min(), img_rgb.max()
    if img_max - img_min > 0:
        img_rgb = (img_rgb - img_min) / (img_max - img_min)
    # Clip so a constant image outside [0, 1] saturates instead of wrapping
    # around when cast to uint8.
    img_uint8 = (255 * np.clip(img_rgb, 0, 1)).astype(np.uint8)

    # Ensure same size
    h, w = img_uint8.shape[:2]
    if gradcam_overlay.shape[:2] != (h, w):
        gradcam_overlay = cv2.resize(gradcam_overlay, (w, h))  # cv2 takes (width, height)

    # Create side-by-side composite (OpenCV works in BGR)
    separator = np.ones((h, 4, 3), dtype=np.uint8) * 128
    composite = np.hstack([
        cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR),
        separator,
        gradcam_overlay,
    ])

    # Add labels; the right-hand label is offset past the image and the 4 px separator.
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(composite, "Satellite", (10, 20), font, 0.5, (255, 255, 255), 1)
    cv2.putText(composite, "Grad-CAM", (w + 14, 20), font, 0.5, (255, 255, 255), 1)

    _, buffer = cv2.imencode(".png", composite)
    return base64.b64encode(buffer).decode("utf-8")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Test AQI chart (the function returns a Figure; render it to measure the HTML)
    forecast = list(np.random.rand(72) * 100 + 30)
    chart_fig = create_aqi_forecast_chart(forecast)
    chart_html = chart_fig.to_html(include_plotlyjs="cdn")
    print(f"AQI chart HTML length: {len(chart_html)}")

    # Test Folium map
    risk_map = np.random.rand(128, 128).astype(np.float32)
    map_html = create_folium_map(
        latitude=37.5,
        longitude=-120.3,
        risk_map=risk_map,
        risk_score=0.65,
        aqi_forecast=forecast,
    )
    print(f"Folium map HTML length: {len(map_html)}")
    save_folium_map(map_html)