"""
visualize.py — Visualization utilities for predictions.
Generates:
- Grad-CAM overlays on satellite images
- AQI forecast Plotly charts with WHO threshold lines
- Folium maps with multiple toggleable layers:
Layer 1: Fire risk heatmap overlay
Layer 2: Grad-CAM overlay as ImageOverlay
Layer 3: Active fire point markers with confidence popups
Layer 4: AQI station markers color-coded by level
"""
import base64
import io
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import cv2
import numpy as np
from src.training.config import (
OUTPUTS_DIR, PATCH_SIZE, AQI_FORECAST_HOURS, AQI_CATEGORIES,
)
# Shared module logger; create_folium_map routes Grad-CAM overlay failures
# through it as warnings instead of raising.
logger: logging.Logger = logging.getLogger(name=__name__)
def create_aqi_forecast_chart(
    aqi_forecast: List[float],
    forecast_hours: Optional[List[int]] = None,
) -> "plotly.graph_objects.Figure":
    """
    Create an AQI forecast Plotly line chart with dashed AQI threshold lines.

    Thresholds (dashed horizontal lines):
    - Good: 0–50 (green)
    - Moderate: 51–100 (yellow)
    - Unhealthy (Sensitive): 101–150 (orange)

    NOTE(review): these are described as WHO thresholds, but the 50/100/150
    breakpoints look like US EPA AQI category boundaries (WHO guidelines are
    PM2.5 concentrations, not AQI values) — confirm the intended source.

    Args:
        aqi_forecast: Forecast AQI values, one per forecast step.
        forecast_hours: X-axis value for each step; defaults to
            1..len(aqi_forecast). Should be the same length as aqi_forecast.

    Returns:
        A ``plotly.graph_objects.Figure`` (not an HTML string); rendering is
        left to the caller. An empty forecast yields an empty chart instead
        of raising.
    """
    import plotly.graph_objects as go

    if forecast_hours is None:
        forecast_hours = list(range(1, len(aqi_forecast) + 1))

    fig = go.Figure()

    # Main forecast line
    fig.add_trace(go.Scatter(
        x=forecast_hours,
        y=aqi_forecast,
        mode="lines+markers",
        name="AQI Forecast",
        line=dict(color="#FF6B35", width=3),
        marker=dict(size=4),
        fill="tozeroy",
        fillcolor="rgba(255, 107, 53, 0.1)",
    ))

    # Threshold lines: (AQI value, category label, line colour)
    thresholds = [
        (50, "Good", "#00E676"),
        (100, "Moderate", "#FFEB3B"),
        (150, "Unhealthy (Sensitive)", "#FF9800"),
    ]
    for threshold, label, color in thresholds:
        fig.add_hline(
            y=threshold,
            line_dash="dash",
            line_color=color,
            line_width=2,
            annotation_text=f"{label} ({threshold})",
            annotation_position="right",
            annotation_font_size=10,
            annotation_font_color=color,
        )

    # Derive the x range from the hours actually plotted so custom
    # forecast_hours (e.g. 0, 6, 12, ...) are not clipped. len() checks are
    # used instead of truthiness so numpy arrays are accepted as well.
    if len(forecast_hours) > 0:
        x_min, x_max = min(forecast_hours), max(forecast_hours)
        # A single point would give a zero-width range; let Plotly autorange.
        x_range = [x_min, x_max] if x_min < x_max else None
    else:
        x_range = None

    # Keep at least 0–200 visible so every threshold line is on screen, with
    # 20% headroom above the forecast peak. max() on an empty list raises.
    if len(aqi_forecast) > 0:
        y_upper = max(max(aqi_forecast) * 1.2, 200)
    else:
        y_upper = 200

    # Styling
    fig.update_layout(
        title=dict(
            # NOTE(review): the "72-Hour" label is fixed text and does not
            # follow the length of aqi_forecast — confirm callers always pass
            # a 72-hour horizon.
            text="72-Hour AQI Forecast (PM2.5)",
            font=dict(size=16, color="#E0E0E0"),
        ),
        xaxis_title="Forecast Hour",
        yaxis_title="AQI (PM2.5)",
        template="plotly_dark",
        plot_bgcolor="rgba(17, 17, 17, 0.8)",
        paper_bgcolor="rgba(17, 17, 17, 0.9)",
        xaxis=dict(
            gridcolor="rgba(255,255,255,0.1)",
            range=x_range,
        ),
        yaxis=dict(
            gridcolor="rgba(255,255,255,0.1)",
            range=[0, y_upper],
        ),
        height=400,
        margin=dict(l=50, r=50, t=50, b=50),
        showlegend=True,
        legend=dict(
            bgcolor="rgba(0,0,0,0.5)",
            font=dict(color="#E0E0E0"),
        ),
    )
    return fig
def create_folium_map(
latitude: float,
longitude: float,
risk_map: np.ndarray,
risk_score: float,
aqi_forecast: List[float],
gradcam_b64: Optional[str] = None,
fire_points: Optional[List[Dict]] = None,
) -> str:
"""
Create an interactive Folium map with multiple toggleable layers.
Layers:
1. Fire risk heatmap overlay (HeatMap plugin)
2. Grad-CAM overlay (ImageOverlay)
3. Active fire point markers with confidence popups
4. AQI station markers color-coded by level
Returns:
Complete HTML string for the Folium map.
"""
import folium
from folium.plugins import HeatMap
# Create base map centered on location
m = folium.Map(
location=[latitude, longitude],
zoom_start=9,
tiles="OpenStreetMap",
control_scale=True,
)
# Add Stamen Terrain as alternative base layer
folium.TileLayer(
tiles="https://server.arcgisonline.com/ArcGIS/rest/services/World_Topo_Map/MapServer/tile/{z}/{y}/{x}",
attr="Esri",
name="Topographic",
).add_to(m)
# --- Layer 1: Fire Risk Heatmap ---
risk_layer = folium.FeatureGroup(name="🔥 Fire Risk Heatmap")
# Generate heatmap points from risk map
heat_data = []
h, w = risk_map.shape
lat_range = 0.5 # degrees of lat/lon to cover
lon_range = 0.5
for i in range(0, h, 4): # Sample every 4th pixel for performance
for j in range(0, w, 4):
if risk_map[i, j] > 0.2: # Only show significant risk
lat_point = latitude + lat_range * (0.5 - i / h)
lon_point = longitude + lon_range * (j / w - 0.5)
heat_data.append([lat_point, lon_point, float(risk_map[i, j])])
if heat_data:
HeatMap(
heat_data,
radius=15,
blur=10,
max_zoom=13,
gradient={0.2: "blue", 0.4: "lime", 0.6: "yellow", 0.8: "orange", 1.0: "red"},
).add_to(risk_layer)
risk_layer.add_to(m)
# --- Layer 2: Grad-CAM Overlay ---
if gradcam_b64:
gradcam_layer = folium.FeatureGroup(name="🧠 Grad-CAM Overlay")
# Decode base64 to temporary image
try:
gradcam_bytes = base64.b64decode(gradcam_b64)
gradcam_arr = np.frombuffer(gradcam_bytes, dtype=np.uint8)
gradcam_img = cv2.imdecode(gradcam_arr, cv2.IMREAD_COLOR)
if gradcam_img is not None:
# Convert to PNG data URI
_, buffer = cv2.imencode(".png", gradcam_img)
img_b64 = base64.b64encode(buffer).decode("utf-8")
data_uri = f"data:image/png;base64,{img_b64}"
bounds = [
[latitude - lat_range / 2, longitude - lon_range / 2],
[latitude + lat_range / 2, longitude + lon_range / 2],
]
folium.raster_layers.ImageOverlay(
image=data_uri,
bounds=bounds,
opacity=0.5,
name="Grad-CAM",
).add_to(gradcam_layer)
except Exception as e:
logger.warning(f"Failed to add Grad-CAM overlay: {e}")
gradcam_layer.add_to(m)
# --- Layer 3: Active Fire Point Markers ---
fire_layer = folium.FeatureGroup(name="📍 Active Fire Detections")
if fire_points:
for pt in fire_points:
confidence = pt.get("confidence", "N/A")
folium.CircleMarker(
location=[pt["latitude"], pt["longitude"]],
radius=6,
color="red",
fill=True,
fillColor="red",
fillOpacity=0.8,
popup=folium.Popup(
f"Fire Detection
"
f"Confidence: {confidence}%
"
f"Lat: {pt['latitude']:.4f}
"
f"Lon: {pt['longitude']:.4f}",
max_width=200,
),
).add_to(fire_layer)
else:
# Generate synthetic fire points for demo
np.random.seed(42)
num_points = int(risk_score * 10) + 2
for _ in range(num_points):
pt_lat = latitude + np.random.uniform(-0.2, 0.2)
pt_lon = longitude + np.random.uniform(-0.2, 0.2)
conf = np.random.randint(60, 100)
folium.CircleMarker(
location=[pt_lat, pt_lon],
radius=6,
color="red",
fill=True,
fillColor="#FF4444",
fillOpacity=0.8,
popup=folium.Popup(
f"🔥 Fire Detection
"
f"Confidence: {conf}%
"
f"Lat: {pt_lat:.4f}
"
f"Lon: {pt_lon:.4f}",
max_width=200,
),
tooltip=f"Fire ({conf}%)",
).add_to(fire_layer)
fire_layer.add_to(m)
# --- Layer 4: AQI Station Markers ---
aqi_layer = folium.FeatureGroup(name="🌬️ AQI Stations")
current_aqi = aqi_forecast[0] if aqi_forecast else 50
aqi_color = _get_aqi_color(current_aqi)
aqi_label = _get_aqi_label(current_aqi)
# Main station marker
folium.Marker(
location=[latitude, longitude],
icon=folium.DivIcon(
html=f"""