# Author: krupal02
# Commit d5b0af1: "Deploy Multi-Hazard Warning System - MTL model for wildfire risk + AQI forecasting" (13.2 kB)
"""
visualize.py — Visualization utilities for predictions.
Generates:
- Grad-CAM overlays on satellite images
- AQI forecast Plotly charts with AQI category threshold lines (50 / 100 / 150)
- Folium maps with multiple toggleable layers:
Layer 1: Fire risk heatmap overlay
Layer 2: Grad-CAM overlay as ImageOverlay
Layer 3: Active fire point markers with confidence popups
Layer 4: AQI station markers color-coded by level
"""
import base64
import io
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import cv2
import numpy as np
from src.training.config import (
OUTPUTS_DIR, PATCH_SIZE, AQI_FORECAST_HOURS, AQI_CATEGORIES,
)
# Module-level logger; used for map-save and Grad-CAM overlay diagnostics.
logger = logging.getLogger(__name__)
def create_aqi_forecast_chart(
    aqi_forecast: List[float],
    forecast_hours: Optional[List[int]] = None,
):
    """
    Create an AQI forecast Plotly line chart with AQI category threshold lines.

    Dashed horizontal lines mark the upper bound of the first three
    categories:
      - Good: 0–50 (green)
      - Moderate: 51–100 (yellow)
      - Unhealthy (Sensitive): 101–150 (orange)

    Args:
        aqi_forecast: Forecast AQI values, one per forecast step. May be empty.
        forecast_hours: X-axis value for each step; defaults to
            ``1..len(aqi_forecast)``.

    Returns:
        A ``plotly.graph_objects.Figure`` (not an HTML string). Call
        ``fig.to_html()`` when HTML is needed.
    """
    import plotly.graph_objects as go

    values = list(aqi_forecast)
    if forecast_hours is None:
        hours = list(range(1, len(values) + 1))
    else:
        hours = list(forecast_hours)

    fig = go.Figure()

    # Main forecast line, filled down to zero.
    fig.add_trace(go.Scatter(
        x=hours,
        y=values,
        mode="lines+markers",
        name="AQI Forecast",
        line=dict(color="#FF6B35", width=3),
        marker=dict(size=4),
        fill="tozeroy",
        fillcolor="rgba(255, 107, 53, 0.1)",
    ))

    # AQI category boundary lines (upper bound of each category).
    thresholds = [
        (50, "Good", "#00E676"),
        (100, "Moderate", "#FFEB3B"),
        (150, "Unhealthy (Sensitive)", "#FF9800"),
    ]
    for threshold, label, color in thresholds:
        fig.add_hline(
            y=threshold,
            line_dash="dash",
            line_color=color,
            line_width=2,
            annotation_text=f"{label} ({threshold})",
            annotation_position="right",
            annotation_font_size=10,
            annotation_font_color=color,
        )

    # The title reflects the actual horizon instead of a hard-coded 72 hours.
    horizon = max(hours) if hours else len(values)
    # Fit the x-axis to the hours actually plotted. A fixed [1, len] range
    # would crop custom forecast_hours. With fewer than two distinct hours,
    # let Plotly autorange.
    if hours and min(hours) != max(hours):
        x_range = [min(hours), max(hours)]
    else:
        x_range = None
    # Always show at least 0–200 so every threshold line is visible.
    # default=0 keeps an empty forecast from crashing max().
    y_top = max(max(values, default=0) * 1.2, 200)

    fig.update_layout(
        title=dict(
            text=f"{horizon:g}-Hour AQI Forecast (PM2.5)",
            font=dict(size=16, color="#E0E0E0"),
        ),
        xaxis_title="Forecast Hour",
        yaxis_title="AQI (PM2.5)",
        template="plotly_dark",
        plot_bgcolor="rgba(17, 17, 17, 0.8)",
        paper_bgcolor="rgba(17, 17, 17, 0.9)",
        xaxis=dict(
            gridcolor="rgba(255,255,255,0.1)",
            range=x_range,
        ),
        yaxis=dict(
            gridcolor="rgba(255,255,255,0.1)",
            range=[0, y_top],
        ),
        height=400,
        margin=dict(l=50, r=50, t=50, b=50),
        showlegend=True,
        legend=dict(
            bgcolor="rgba(0,0,0,0.5)",
            font=dict(color="#E0E0E0"),
        ),
    )
    return fig
# Side length, in degrees of latitude/longitude, of the square centred on the
# analysis point. Both the risk heatmap and the Grad-CAM overlay cover it.
_OVERLAY_SPAN_DEG = 0.5


def create_folium_map(
    latitude: float,
    longitude: float,
    risk_map: np.ndarray,
    risk_score: float,
    aqi_forecast: List[float],
    gradcam_b64: Optional[str] = None,
    fire_points: Optional[List[Dict]] = None,
) -> str:
    """
    Create an interactive Folium map with multiple toggleable layers.

    Layers are added in this order, and each can be toggled in the layer control:
      1. Fire risk heatmap (HeatMap plugin), built from ``risk_map``.
      2. Grad-CAM overlay (ImageOverlay), only when ``gradcam_b64`` is given.
      3. Active fire point markers with confidence popups. When
         ``fire_points`` is empty/None, deterministic synthetic demo points
         are drawn instead.
      4. AQI station markers color-coded by level: one main marker plus three
         synthetic nearby stations.

    Args:
        latitude, longitude: Map centre in degrees.
        risk_map: 2-D array of per-pixel fire risk. Pixels above 0.2 become
            heat points, and row 0 maps to the northern edge. Presumably
            values lie in [0, 1] (the heatmap gradient assumes so); TODO confirm.
        risk_score: Scalar risk shown in the centre popup. It also sets how
            many synthetic fire points are drawn.
        aqi_forecast: Forecast AQI values. The first value is used as the
            current AQI (50 if the list is empty).
        gradcam_b64: Base64-encoded image in any format ``cv2.imdecode`` reads.
        fire_points: Dicts with "latitude", "longitude" and an optional
            "confidence".

    Returns:
        Complete HTML string for the Folium map.
    """
    import folium

    # Base map centred on the analysis location.
    m = folium.Map(
        location=[latitude, longitude],
        zoom_start=9,
        tiles="OpenStreetMap",
        control_scale=True,
    )
    # Esri World Topo Map as an alternative, switchable base layer.
    folium.TileLayer(
        tiles="https://server.arcgisonline.com/ArcGIS/rest/services/World_Topo_Map/MapServer/tile/{z}/{y}/{x}",
        attr="Esri",
        name="Topographic",
    ).add_to(m)

    # Overlay layers, in the order they appear in the layer control.
    _add_risk_heatmap_layer(m, latitude, longitude, risk_map)
    if gradcam_b64:
        _add_gradcam_layer(m, latitude, longitude, gradcam_b64)
    _add_fire_layer(m, latitude, longitude, risk_score, fire_points)
    _add_aqi_layer(m, latitude, longitude, aqi_forecast)

    # Centre marker summarising the analysis.
    folium.Marker(
        location=[latitude, longitude],
        icon=folium.Icon(color="blue", icon="info-sign"),
        popup=f"<b>Analysis Center</b><br>"
        f"Risk: {risk_score:.2f}<br>"
        f"Lat: {latitude}, Lon: {longitude}",
    ).add_to(m)

    folium.LayerControl(collapsed=False).add_to(m)
    return m._repr_html_()


def _add_risk_heatmap_layer(m, latitude, longitude, risk_map, stride=4, min_risk=0.2):
    """Add the fire-risk HeatMap layer built from every ``stride``-th pixel above ``min_risk``."""
    import folium
    from folium.plugins import HeatMap

    risk_layer = folium.FeatureGroup(name="🔥 Fire Risk Heatmap")
    risk = np.asarray(risk_map)
    h, w = risk.shape
    # Vectorised form of visiting pixels (i, j) with step `stride`.
    # np.nonzero returns them in the same row-major order as nested loops.
    rows, cols = np.nonzero(risk[::stride, ::stride] > min_risk)
    rows, cols = rows * stride, cols * stride
    lats = latitude + _OVERLAY_SPAN_DEG * (0.5 - rows / h)  # row 0 -> northern edge
    lons = longitude + _OVERLAY_SPAN_DEG * (cols / w - 0.5)
    heat_data = np.column_stack([lats, lons, risk[rows, cols]]).tolist()
    if heat_data:
        HeatMap(
            heat_data,
            radius=15,
            blur=10,
            max_zoom=13,
            gradient={0.2: "blue", 0.4: "lime", 0.6: "yellow", 0.8: "orange", 1.0: "red"},
        ).add_to(risk_layer)
    risk_layer.add_to(m)


def _add_gradcam_layer(m, latitude, longitude, gradcam_b64):
    """Add the Grad-CAM image as a semi-transparent ImageOverlay over the analysis square."""
    import folium

    gradcam_layer = folium.FeatureGroup(name="🧠 Grad-CAM Overlay")
    try:
        gradcam_bytes = base64.b64decode(gradcam_b64)
        gradcam_arr = np.frombuffer(gradcam_bytes, dtype=np.uint8)
        gradcam_img = cv2.imdecode(gradcam_arr, cv2.IMREAD_COLOR)
        if gradcam_img is None:
            logger.warning("Grad-CAM payload is not a decodable image; overlay skipped")
        else:
            # Re-encode as PNG so the overlay is always embedded as a PNG data URI.
            _, buffer = cv2.imencode(".png", gradcam_img)
            img_b64 = base64.b64encode(buffer).decode("utf-8")
            data_uri = f"data:image/png;base64,{img_b64}"
            half = _OVERLAY_SPAN_DEG / 2
            bounds = [
                [latitude - half, longitude - half],
                [latitude + half, longitude + half],
            ]
            folium.raster_layers.ImageOverlay(
                image=data_uri,
                bounds=bounds,
                opacity=0.5,
                name="Grad-CAM",
            ).add_to(gradcam_layer)
    except Exception as e:
        # Best-effort layer: a bad payload must not prevent the map from rendering.
        logger.warning(f"Failed to add Grad-CAM overlay: {e}")
    gradcam_layer.add_to(m)


def _format_confidence(confidence) -> str:
    """Render a confidence value for a popup, HTML-escaped.

    Numeric values get a '%' suffix; non-numeric labels are shown as-is.
    """
    import html

    text = str(confidence)
    try:
        float(confidence)
    except (TypeError, ValueError):
        return html.escape(text)
    return html.escape(f"{text}%")


def _add_fire_layer(m, latitude, longitude, risk_score, fire_points):
    """Add fire detection markers: real ``fire_points`` or seeded synthetic demo points."""
    import folium

    fire_layer = folium.FeatureGroup(name="📍 Active Fire Detections")
    if fire_points:
        for pt in fire_points:
            conf_text = _format_confidence(pt.get("confidence", "N/A"))
            folium.CircleMarker(
                location=[pt["latitude"], pt["longitude"]],
                radius=6,
                color="red",
                fill=True,
                fillColor="red",
                fillOpacity=0.8,
                popup=folium.Popup(
                    f"<b>Fire Detection</b><br>"
                    f"Confidence: {conf_text}<br>"
                    f"Lat: {pt['latitude']:.4f}<br>"
                    f"Lon: {pt['longitude']:.4f}",
                    max_width=200,
                ),
                tooltip=f"Fire ({conf_text})",
            ).add_to(fire_layer)
    else:
        # Demo fallback with deterministic synthetic detections. A private
        # RandomState draws the same values np.random.seed(42) would, without
        # resetting the process-wide NumPy RNG that callers may rely on.
        rng = np.random.RandomState(42)
        num_points = int(risk_score * 10) + 2
        for _ in range(num_points):
            pt_lat = latitude + rng.uniform(-0.2, 0.2)
            pt_lon = longitude + rng.uniform(-0.2, 0.2)
            conf = rng.randint(60, 100)
            folium.CircleMarker(
                location=[pt_lat, pt_lon],
                radius=6,
                color="red",
                fill=True,
                fillColor="#FF4444",
                fillOpacity=0.8,
                popup=folium.Popup(
                    f"<b>🔥 Fire Detection</b><br>"
                    f"Confidence: {conf}%<br>"
                    f"Lat: {pt_lat:.4f}<br>"
                    f"Lon: {pt_lon:.4f}",
                    max_width=200,
                ),
                tooltip=f"Fire ({conf}%)",
            ).add_to(fire_layer)
    fire_layer.add_to(m)


def _add_aqi_layer(m, latitude, longitude, aqi_forecast):
    """Add the main AQI station marker plus three seeded synthetic nearby stations."""
    import folium

    aqi_layer = folium.FeatureGroup(name="🌬️ AQI Stations")
    has_forecast = len(aqi_forecast) > 0
    current_aqi = aqi_forecast[0] if has_forecast else 50
    # Fall back to the current value so an empty forecast does not crash max().
    peak_aqi = max(aqi_forecast) if has_forecast else current_aqi
    aqi_color = _get_aqi_color(current_aqi)
    aqi_label = _get_aqi_label(current_aqi)

    # Main station marker: a coloured circle showing the current AQI.
    marker_html = f"""
        <div style="
            background: {aqi_color};
            color: white;
            border-radius: 50%;
            width: 36px;
            height: 36px;
            display: flex;
            align-items: center;
            justify-content: center;
            font-weight: bold;
            font-size: 11px;
            border: 2px solid white;
            box-shadow: 0 2px 6px rgba(0,0,0,0.5);
        ">{int(current_aqi)}</div>
    """
    folium.Marker(
        location=[latitude, longitude],
        icon=folium.DivIcon(
            html=marker_html,
            icon_size=(36, 36),
            icon_anchor=(18, 18),
        ),
        popup=folium.Popup(
            f"<b>🌬️ AQI Station</b><br>"
            f"Current AQI: {current_aqi:.0f}<br>"
            f"Level: {aqi_label}<br>"
            f"Peak ({len(aqi_forecast)}h): {peak_aqi:.0f}",
            max_width=200,
        ),
        tooltip=f"AQI: {current_aqi:.0f} ({aqi_label})",
    ).add_to(aqi_layer)

    # Synthetic nearby stations. A private RandomState keeps the old seed-123
    # values without touching the global NumPy RNG.
    rng = np.random.RandomState(123)
    for _ in range(3):
        st_lat = latitude + rng.uniform(-0.3, 0.3)
        st_lon = longitude + rng.uniform(-0.3, 0.3)
        # AQI is non-negative. Clamp so the demo jitter never shows a negative reading.
        st_aqi = max(0.0, current_aqi + rng.uniform(-20, 30))
        st_color = _get_aqi_color(st_aqi)
        folium.CircleMarker(
            location=[st_lat, st_lon],
            radius=10,
            color=st_color,
            fill=True,
            fillColor=st_color,
            fillOpacity=0.7,
            popup=f"AQI: {st_aqi:.0f}",
            tooltip=f"AQI: {st_aqi:.0f}",
        ).add_to(aqi_layer)
    aqi_layer.add_to(m)
def _get_aqi_color(aqi: float) -> str:
"""Map AQI value to display color."""
if aqi <= 50:
return "#00E676" # Green
elif aqi <= 100:
return "#FFEB3B" # Yellow
elif aqi <= 150:
return "#FF9800" # Orange
elif aqi <= 200:
return "#F44336" # Red
elif aqi <= 300:
return "#9C27B0" # Purple
else:
return "#7B1FA2" # Dark red
def _get_aqi_label(aqi: float) -> str:
"""Map AQI value to category label."""
for label, (low, high) in AQI_CATEGORIES.items():
if low <= aqi <= high:
return label
return "Hazardous"
def save_folium_map(
    html: str,
    filename: str = "prediction_map.html",
    output_dir: Optional[Path] = None,
) -> str:
    """
    Write Folium map HTML to disk and return the written path.

    Args:
        html: Complete HTML document, e.g. the output of ``create_folium_map``.
        filename: File name to create inside ``output_dir``.
        output_dir: Destination directory; defaults to ``OUTPUTS_DIR``. It is
            created, with parents, if missing, so a fresh checkout does not
            fail with FileNotFoundError.

    Returns:
        The path of the written file, as a string.
    """
    directory = Path(OUTPUTS_DIR if output_dir is None else output_dir)
    directory.mkdir(parents=True, exist_ok=True)
    path = directory / filename
    path.write_text(html, encoding="utf-8")
    logger.info("Folium map saved: %s", path)
    return str(path)
def create_satellite_comparison(
    original_image: np.ndarray,
    gradcam_overlay: np.ndarray,
) -> str:
    """
    Create a side-by-side composite: satellite image | grey separator | Grad-CAM.

    Args:
        original_image: Satellite image in CHW layout (assumed when the first
            axis has <= 4 entries) or HWC layout; only the first three
            channels are shown, treated as RGB. A 2-D or single-channel image
            is replicated to three channels. Values are min–max normalised
            over finite pixels, and NaN/inf pixels are clamped into range.
        gradcam_overlay: Overlay image, resized to the satellite image size if
            needed. Assumed to be a 3-channel uint8 BGR image; TODO confirm
            with the Grad-CAM producer.

    Returns:
        Base64-encoded PNG of the composite.

    Raises:
        RuntimeError: If OpenCV fails to encode the composite as PNG.
    """
    img = np.asarray(original_image)
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    elif img.shape[0] <= 4:  # CHW format -> HWC
        img = np.transpose(img, (1, 2, 0))
    img = img[:, :, :3].astype(np.float64)
    if img.shape[2] < 3:
        # Grey (or 2-band) input: replicate the first band so cvtColor/hstack work.
        img = np.repeat(img[:, :, :1], 3, axis=2)

    # Min–max normalise using finite pixels only.
    finite = np.isfinite(img)
    if finite.any():
        lo, hi = img[finite].min(), img[finite].max()
    else:
        lo = hi = 0.0
    if hi > lo:
        img = (img - lo) / (hi - lo)
    # Flat images skip normalisation. Clamping keeps them (and NaN/inf pixels)
    # inside [0, 1]; the old code multiplied raw values by 255, which
    # overflowed uint8.
    img = np.clip(np.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0), 0.0, 1.0)
    img_uint8 = (255 * img).astype(np.uint8)

    # Match the overlay to the satellite image size.
    h, w = img_uint8.shape[:2]
    overlay = np.asarray(gradcam_overlay)
    if overlay.shape[:2] != (h, w):
        overlay = cv2.resize(overlay, (w, h))

    # The satellite image is RGB, but OpenCV encodes BGR, so convert before stacking.
    separator = np.full((h, 4, 3), 128, dtype=np.uint8)
    composite = np.hstack([
        cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR),
        separator,
        overlay,
    ])

    # Labels: the second one starts just past the image width plus the 4 px separator.
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(composite, "Satellite", (10, 20), font, 0.5, (255, 255, 255), 1)
    cv2.putText(composite, "Grad-CAM", (w + 14, 20), font, 0.5, (255, 255, 255), 1)

    ok, buffer = cv2.imencode(".png", composite)
    if not ok:
        raise RuntimeError("Failed to encode satellite comparison as PNG")
    return base64.b64encode(buffer).decode("utf-8")
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Smoke-test the AQI chart. create_aqi_forecast_chart returns a Plotly
    # Figure (len() on it fails), so render it to HTML before measuring.
    forecast = list(np.random.rand(72) * 100 + 30)
    chart_fig = create_aqi_forecast_chart(forecast)
    chart_html = chart_fig.to_html(full_html=False, include_plotlyjs="cdn")
    print(f"AQI chart HTML length: {len(chart_html)}")

    # Smoke-test the Folium map with a random risk map. No real fire points
    # are passed, so synthetic ones are drawn.
    risk_map = np.random.rand(128, 128).astype(np.float32)
    map_html = create_folium_map(
        latitude=37.5,
        longitude=-120.3,
        risk_map=risk_map,
        risk_score=0.65,
        aqi_forecast=forecast,
    )
    print(f"Folium map HTML length: {len(map_html)}")
    save_folium_map(map_html)