krupal02's picture
Deploy Multi-Hazard Warning System - MTL model for wildfire risk + AQI forecasting
d5b0af1
Raw
History Blame Contribute Delete
9.93 kB
"""
predict.py — End-to-end inference pipeline given latitude, longitude, and date.
Orchestrates the full prediction workflow:
1. Fetch satellite/fire data for the given location
2. Fetch 7-day weather + AQI history
3. Preprocess both inputs using saved scalers
4. Run MTL model inference
5. Generate Grad-CAM overlay
6. Return structured prediction results
"""
import base64
import logging
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Tuple
import cv2
import numpy as np
import torch
from src.training.config import (
DEVICE, PATCH_SIZE, TIMESERIES_WINDOW, TIMESERIES_FEATURES,
AQI_FORECAST_HOURS, BEST_MODEL_PATH, PROCESSED_DIR,
RISK_THRESHOLDS, AQI_CATEGORIES,
)
# Module-level logger named after this module; used by Predictor and get_predictor().
logger: logging.Logger = logging.getLogger(__name__)
class Predictor:
    """
    End-to-end inference engine for wildfire risk + AQI predictions.

    Loads the trained MTL model, preprocesses inputs, runs inference,
    generates Grad-CAM overlays, and returns structured results.
    """

    # Multiplier applied to the model's ``aqi_forecast`` head before reporting.
    # NOTE(review): presumably the AQI targets were scaled by 1/500 during
    # training — confirm against the training pipeline.
    AQI_SCALE = 500.0

    def __init__(self, model=None, device: str = DEVICE):
        """
        Args:
            model: Pre-loaded model instance. If None, loads from checkpoint.
            device: Computation device.
        """
        self.device = device
        if model is not None:
            self.model = model.to(device).eval()
        else:
            # Force inference mode regardless of how the checkpoint loader
            # left the module (dropout / batch-norm must not be in train mode).
            self.model = self._load_model().eval()
        # Load preprocessing scalers if available (identity fallback otherwise).
        self.img_normalizer = self._load_image_normalizer()
        self.ts_scaler = self._load_timeseries_scaler()
        logger.info("Predictor initialized.")

    def _load_model(self):
        """Load the trained model from ``BEST_MODEL_PATH`` onto ``self.device``."""
        from src.models.fusion_model import load_model
        return load_model(str(BEST_MODEL_PATH), device=self.device)

    def _load_image_normalizer(self):
        """Load the saved image normalizer, or fall back to an identity transform."""
        from src.data.preprocess import ImageNormalizer
        path = PROCESSED_DIR / "image_normalizer.npz"
        normalizer = ImageNormalizer()
        if path.exists():
            normalizer.load(path)
            logger.info("Loaded image normalizer from disk.")
        else:
            logger.warning("No saved image normalizer found — using identity transform.")
            # Zero mean / unit std per channel == identity; 4 matches the
            # 4-channel image patches produced by fetch_inputs.
            normalizer.channel_means = np.zeros(4)
            normalizer.channel_stds = np.ones(4)
        return normalizer

    def _load_timeseries_scaler(self):
        """Load the saved time-series scaler, or fall back to an identity transform."""
        from src.data.preprocess import TimeSeriesScaler
        path = PROCESSED_DIR / "timeseries_scaler.npz"
        scaler = TimeSeriesScaler()
        if path.exists():
            scaler.load(path)
            logger.info("Loaded timeseries scaler from disk.")
        else:
            logger.warning("No saved timeseries scaler found — using identity transform.")
            # NOTE(review): these attribute names match sklearn's MinMaxScaler;
            # assumes TimeSeriesScaler wraps one as ``.scaler`` — confirm.
            # scale_=1 and min_=0 make transform() an identity map.
            scaler._is_fitted = True
            scaler.scaler.data_min_ = np.zeros(TIMESERIES_FEATURES)
            scaler.scaler.data_max_ = np.ones(TIMESERIES_FEATURES)
            scaler.scaler.scale_ = np.ones(TIMESERIES_FEATURES)
            scaler.scaler.min_ = np.zeros(TIMESERIES_FEATURES)
            scaler.scaler.n_features_in_ = TIMESERIES_FEATURES
        return scaler

    def fetch_inputs(
        self,
        latitude: float,
        longitude: float,
        date: str,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Fetch and preprocess inputs from data sources.

        Args:
            latitude: Location latitude.
            longitude: Location longitude.
            date: Target date (YYYY-MM-DD).

        Returns:
            Tuple of (image_patch, timeseries):
                - image_patch: (1, 4, 128, 128)
                - timeseries: (1, 7, 6)
        """
        from src.data.fetch_firms import generate_synthetic_fire_data
        from src.data.fetch_weather import fetch_weather_for_location
        from src.data.fetch_aqi import fetch_aqi_for_location

        # Satellite/fire data: for inference, generate a single synthetic patch
        # seeded by location. The formula is kept as-is so existing locations
        # reproduce the same patch; note it ignores ``date`` and abs() makes
        # (lat, lon) and (-lat, -lon) share a seed.
        seed = int(abs(latitude * 1000 + longitude * 100)) % (2**31)
        # Seed NumPy's global RNG for reproducibility, but restore the previous
        # state afterwards so other users of np.random are not affected.
        saved_rng_state = np.random.get_state()
        try:
            np.random.seed(seed)
            image, _ = generate_synthetic_fire_data(num_samples=1)
        finally:
            np.random.set_state(saved_rng_state)
        image = image[0:1]  # keep the batch axis: (1, 4, 128, 128)

        # Weather history, expected (7, 5).
        weather = fetch_weather_for_location(latitude, longitude, date)

        # AQI history, expected (7, 1); a flat (7,) vector is promoted to a column.
        aqi_hist, _ = fetch_aqi_for_location(latitude, longitude, date)
        aqi_hist = np.asarray(aqi_hist)
        if aqi_hist.ndim == 1:
            aqi_hist = aqi_hist[:, np.newaxis]

        # Combine weather + AQI into one feature matrix and add a batch axis.
        timeseries = np.concatenate([weather, aqi_hist], axis=1)  # (7, 6)
        timeseries = timeseries[np.newaxis, ...]  # (1, 7, 6)

        # Normalize with the training-time (or identity) transforms.
        image = self.img_normalizer.transform(image)
        timeseries = self.ts_scaler.transform(timeseries)
        return image, timeseries

    @staticmethod
    def _validate_coordinates(latitude: float, longitude: float) -> None:
        """Raise ValueError unless latitude/longitude are within their valid ranges."""
        # Written as ``not (lo <= x <= hi)`` so NaN, which fails every
        # comparison, is rejected as well.
        if not (-90.0 <= latitude <= 90.0):
            raise ValueError(f"latitude must be in [-90, 90], got {latitude!r}")
        if not (-180.0 <= longitude <= 180.0):
            raise ValueError(f"longitude must be in [-180, 180], got {longitude!r}")

    @torch.no_grad()
    def predict(
        self,
        latitude: float,
        longitude: float,
        date: str,
    ) -> Dict:
        """
        Run end-to-end prediction for a given location and date.

        Args:
            latitude: Location latitude (-90 to 90).
            longitude: Location longitude (-180 to 180).
            date: Target date (YYYY-MM-DD).

        Returns:
            Dict with all prediction results (risk level/score/map, base64
            heatmap and Grad-CAM PNGs, hourly AQI forecast and its peak,
            echoed inputs, timestamp and inference time).

        Raises:
            ValueError: If latitude or longitude is out of range or NaN.
        """
        self._validate_coordinates(latitude, longitude)
        start_time = time.perf_counter()  # monotonic clock for durations
        logger.info("Predicting for (%s, %s) on %s...", latitude, longitude, date)

        # Step 1: Fetch and preprocess inputs
        image_np, timeseries_np = self.fetch_inputs(latitude, longitude, date)

        # Step 2: Convert to tensors on the target device
        image_tensor = torch.tensor(image_np, dtype=torch.float32).to(self.device)
        ts_tensor = torch.tensor(timeseries_np, dtype=torch.float32).to(self.device)

        # Step 3: Run MTL model inference; [0] selects the single batch element.
        output = self.model(image_tensor, ts_tensor)
        risk_map = output["risk_map"][0].cpu().numpy()  # expected (128, 128)
        aqi_forecast = output["aqi_forecast"][0].cpu().numpy() * self.AQI_SCALE

        # Step 4: Compute risk level from the mean of the risk map
        risk_score = float(risk_map.mean())
        risk_level = self._classify_risk(risk_score)

        # Step 5: Generate heatmap + Grad-CAM overlay
        heatmap_b64, gradcam_b64 = self._generate_visualizations(
            image_tensor, ts_tensor, risk_map, image_np
        )

        # Step 6: AQI analysis
        peak_hour = int(np.argmax(aqi_forecast))
        peak_value = float(aqi_forecast[peak_hour])
        # Derive the hour axis from the forecast itself so the two lists can
        # never disagree in length; flag a model/config mismatch loudly.
        if len(aqi_forecast) != AQI_FORECAST_HOURS:
            logger.warning(
                "Model produced %d AQI forecast steps; config expects %d.",
                len(aqi_forecast), AQI_FORECAST_HOURS,
            )
        forecast_hours = list(range(1, len(aqi_forecast) + 1))

        elapsed = time.perf_counter() - start_time
        logger.info("Prediction complete in %.2fs — Risk: %s (%.3f)",
                    elapsed, risk_level, risk_score)

        return {
            "risk_level": risk_level,
            "risk_score": round(risk_score, 4),
            "risk_map": risk_map,
            "heatmap_base64": heatmap_b64,
            "gradcam_base64": gradcam_b64,
            "aqi_forecast": aqi_forecast.tolist(),
            "forecast_hours": forecast_hours,
            "peak_aqi_hour": peak_hour + 1,  # 1-based to match forecast_hours
            "peak_aqi_value": round(peak_value, 1),
            "latitude": latitude,
            "longitude": longitude,
            "date": date,
            "prediction_timestamp": datetime.now().isoformat(),
            "inference_time_ms": round(elapsed * 1000, 1),
        }

    def _classify_risk(self, score: float) -> str:
        """
        Map a risk score to the first RISK_THRESHOLDS level whose
        half-open interval [low, high) contains it.

        Any score that matches no interval (e.g. at/above the highest upper
        bound, below the lowest lower bound, or NaN) is reported as "Extreme".
        """
        for level, (low, high) in RISK_THRESHOLDS.items():
            if low <= score < high:
                return level
        return "Extreme"

    @staticmethod
    def _encode_png_base64(image: np.ndarray) -> str:
        """PNG-encode an image with OpenCV and return it as a base64 string."""
        ok, buffer = cv2.imencode(".png", image)
        if not ok:
            raise RuntimeError("PNG encoding failed")
        return base64.b64encode(buffer).decode("utf-8")

    def _generate_visualizations(
        self,
        image_tensor: torch.Tensor,
        ts_tensor: torch.Tensor,
        risk_map: np.ndarray,
        image_np: np.ndarray,
    ) -> Tuple[str, str]:
        """
        Generate heatmap and Grad-CAM base64 images.

        Grad-CAM is best-effort: on any failure the plain heatmap is returned
        in its place and a warning is logged. ``image_np`` is currently unused.

        Returns:
            Tuple of (heatmap_base64, gradcam_base64).
        """
        # Heatmap: risk map clipped to [0, 1], scaled to 0-255, JET-colored.
        heatmap_uint8 = np.uint8(255 * np.clip(risk_map, 0, 1))
        heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        heatmap_b64 = self._encode_png_base64(heatmap_color)

        # Grad-CAM overlay
        try:
            from src.models.gradcam import GradCAM
            # predict() runs under @torch.no_grad(); Grad-CAM needs gradients,
            # so re-enable autograd just for this computation.
            # NOTE(review): a new GradCAM is built per call — confirm it does
            # not leave hooks registered on self.model.
            with torch.enable_grad():
                gradcam = GradCAM(self.model)
                overlay = gradcam.generate_overlay(image_tensor, ts_tensor)
            gradcam_b64 = self._encode_png_base64(overlay)
        except Exception as e:
            logger.warning("Grad-CAM generation failed: %s", e, exc_info=True)
            gradcam_b64 = heatmap_b64  # Fallback to regular heatmap
        finally:
            # Drop any parameter gradients left by the backward pass so the
            # long-lived serving model does not hold them between requests.
            self.model.zero_grad(set_to_none=True)

        return heatmap_b64, gradcam_b64
# Module-level singleton predictor, created lazily by get_predictor().
_predictor: Optional[Predictor] = None


def get_predictor(model=None) -> Predictor:
    """
    Get or create the singleton Predictor instance.

    Args:
        model: Optional pre-loaded model. It is only used when the singleton
            is first created; if a Predictor already exists, a different
            ``model`` cannot be applied and a warning is logged.

    Returns:
        The shared Predictor instance.
    """
    global _predictor
    if _predictor is None:
        _predictor = Predictor(model=model)
    elif model is not None and model is not _predictor.model:
        # Previously this argument was dropped silently, which hides the fact
        # that the caller's model is not the one serving predictions.
        logger.warning(
            "get_predictor() received a model but a Predictor already exists; "
            "the supplied model is ignored."
        )
    return _predictor
def predict(latitude: float, longitude: float, date: str) -> Dict:
    """
    Run one end-to-end prediction using the shared module-level Predictor.

    The Predictor (and therefore the model) is created on first use.

    Args:
        latitude: Location latitude.
        longitude: Location longitude.
        date: Target date (YYYY-MM-DD).

    Returns:
        The result dict produced by ``Predictor.predict``.
    """
    return get_predictor().predict(latitude, longitude, date)
if __name__ == "__main__":
    # Smoke run: one prediction for a fixed location/date, summary to stdout.
    logging.basicConfig(level=logging.INFO)
    result = predict(37.5, -120.3, "2024-06-15")
    summary = (
        f"\nRisk Level: {result['risk_level']}",
        f"Risk Score: {result['risk_score']}",
        f"Peak AQI: {result['peak_aqi_value']} at hour {result['peak_aqi_hour']}",
        f"Inference Time: {result['inference_time_ms']}ms",
        f"Forecast length: {len(result['aqi_forecast'])} hours",
    )
    for line in summary:
        print(line)