""" predict.py — End-to-end inference pipeline given latitude, longitude, and date. Orchestrates the full prediction workflow: 1. Fetch satellite/fire data for the given location 2. Fetch 7-day weather + AQI history 3. Preprocess both inputs using saved scalers 4. Run MTL model inference 5. Generate Grad-CAM overlay 6. Return structured prediction results """ import base64 import logging import time from datetime import datetime from pathlib import Path from typing import Dict, Optional, Tuple import cv2 import numpy as np import torch from src.training.config import ( DEVICE, PATCH_SIZE, TIMESERIES_WINDOW, TIMESERIES_FEATURES, AQI_FORECAST_HOURS, BEST_MODEL_PATH, PROCESSED_DIR, RISK_THRESHOLDS, AQI_CATEGORIES, ) logger = logging.getLogger(__name__) class Predictor: """ End-to-end inference engine for wildfire risk + AQI predictions. Loads the trained MTL model, preprocesses inputs, runs inference, generates Grad-CAM overlays, and returns structured results. """ def __init__(self, model=None, device: str = DEVICE): """ Args: model: Pre-loaded model instance. If None, loads from checkpoint. device: Computation device. """ self.device = device if model is not None: self.model = model.to(device).eval() else: self.model = self._load_model() # Load preprocessing scalers if available self.img_normalizer = self._load_image_normalizer() self.ts_scaler = self._load_timeseries_scaler() logger.info("Predictor initialized.") def _load_model(self): """Load trained model from checkpoint.""" from src.models.fusion_model import load_model return load_model(str(BEST_MODEL_PATH), device=self.device) def _load_image_normalizer(self): """Load image normalizer if saved.""" from src.data.preprocess import ImageNormalizer path = PROCESSED_DIR / "image_normalizer.npz" normalizer = ImageNormalizer() if path.exists(): normalizer.load(path) logger.info("Loaded image normalizer from disk.") else: logger.warning("No saved image normalizer found — using identity transform.") normalizer.channel_means = np.zeros(4) normalizer.channel_stds = np.ones(4) return normalizer def _load_timeseries_scaler(self): """Load time-series scaler if saved.""" from src.data.preprocess import TimeSeriesScaler path = PROCESSED_DIR / "timeseries_scaler.npz" scaler = TimeSeriesScaler() if path.exists(): scaler.load(path) logger.info("Loaded timeseries scaler from disk.") else: logger.warning("No saved timeseries scaler found — using identity transform.") scaler._is_fitted = True scaler.scaler.data_min_ = np.zeros(TIMESERIES_FEATURES) scaler.scaler.data_max_ = np.ones(TIMESERIES_FEATURES) scaler.scaler.scale_ = np.ones(TIMESERIES_FEATURES) scaler.scaler.min_ = np.zeros(TIMESERIES_FEATURES) scaler.scaler.n_features_in_ = TIMESERIES_FEATURES return scaler def fetch_inputs( self, latitude: float, longitude: float, date: str, ) -> Tuple[np.ndarray, np.ndarray]: """ Fetch and preprocess inputs from data sources. Args: latitude: Location latitude. longitude: Location longitude. date: Target date (YYYY-MM-DD). Returns: Tuple of (image_patch, timeseries): - image_patch: (1, 4, 128, 128) - timeseries: (1, 7, 6) """ from src.data.fetch_firms import generate_synthetic_fire_data from src.data.fetch_weather import fetch_weather_for_location from src.data.fetch_aqi import fetch_aqi_for_location # Fetch satellite/fire data # For inference, generate a single synthetic patch seeded by location seed = int(abs(latitude * 1000 + longitude * 100)) % (2**31) np.random.seed(seed) image, _ = generate_synthetic_fire_data(num_samples=1) image = image[0:1] # (1, 4, 128, 128) # Fetch weather data weather = fetch_weather_for_location(latitude, longitude, date) # weather shape: (7, 5) # Fetch AQI data aqi_hist, _ = fetch_aqi_for_location(latitude, longitude, date) # aqi_hist shape: (7, 1) # Combine weather + AQI into timeseries timeseries = np.concatenate([weather, aqi_hist], axis=1) # (7, 6) timeseries = timeseries[np.newaxis, ...] # (1, 7, 6) # Normalize image = self.img_normalizer.transform(image) timeseries = self.ts_scaler.transform(timeseries) return image, timeseries @torch.no_grad() def predict( self, latitude: float, longitude: float, date: str, ) -> Dict: """ Run end-to-end prediction for a given location and date. Args: latitude: Location latitude (-90 to 90). longitude: Location longitude (-180 to 180). date: Target date (YYYY-MM-DD). Returns: Dict with all prediction results. """ start_time = time.time() logger.info(f"Predicting for ({latitude}, {longitude}) on {date}...") # Step 1: Fetch and preprocess inputs image_np, timeseries_np = self.fetch_inputs(latitude, longitude, date) # Step 2: Convert to tensors image_tensor = torch.tensor(image_np, dtype=torch.float32).to(self.device) ts_tensor = torch.tensor(timeseries_np, dtype=torch.float32).to(self.device) # Step 3: Run MTL model inference output = self.model(image_tensor, ts_tensor) risk_map = output["risk_map"][0].cpu().numpy() # (128, 128) aqi_forecast = output["aqi_forecast"][0].cpu().numpy() * 500.0 # (72,) # Step 4: Compute risk level risk_score = float(risk_map.mean()) risk_level = self._classify_risk(risk_score) # Step 5: Generate Grad-CAM overlay heatmap_b64, gradcam_b64 = self._generate_visualizations( image_tensor, ts_tensor, risk_map, image_np ) # Step 6: AQI analysis peak_hour = int(np.argmax(aqi_forecast)) peak_value = float(aqi_forecast[peak_hour]) elapsed = time.time() - start_time logger.info(f"Prediction complete in {elapsed:.2f}s — " f"Risk: {risk_level} ({risk_score:.3f})") return { "risk_level": risk_level, "risk_score": round(risk_score, 4), "risk_map": risk_map, "heatmap_base64": heatmap_b64, "gradcam_base64": gradcam_b64, "aqi_forecast": aqi_forecast.tolist(), "forecast_hours": list(range(1, AQI_FORECAST_HOURS + 1)), "peak_aqi_hour": peak_hour + 1, "peak_aqi_value": round(peak_value, 1), "latitude": latitude, "longitude": longitude, "date": date, "prediction_timestamp": datetime.now().isoformat(), "inference_time_ms": round(elapsed * 1000, 1), } def _classify_risk(self, score: float) -> str: """Classify risk score into level.""" for level, (low, high) in RISK_THRESHOLDS.items(): if low <= score < high: return level return "Extreme" def _generate_visualizations( self, image_tensor: torch.Tensor, ts_tensor: torch.Tensor, risk_map: np.ndarray, image_np: np.ndarray, ) -> Tuple[str, str]: """ Generate heatmap and Grad-CAM base64 images. Returns: Tuple of (heatmap_base64, gradcam_base64). """ # Heatmap: risk map as colored image heatmap_uint8 = np.uint8(255 * np.clip(risk_map, 0, 1)) heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET) _, heatmap_buf = cv2.imencode(".png", heatmap_color) heatmap_b64 = base64.b64encode(heatmap_buf).decode("utf-8") # Grad-CAM overlay try: from src.models.gradcam import GradCAM gradcam = GradCAM(self.model) overlay = gradcam.generate_overlay(image_tensor, ts_tensor) _, gradcam_buf = cv2.imencode(".png", overlay) gradcam_b64 = base64.b64encode(gradcam_buf).decode("utf-8") except Exception as e: logger.warning(f"Grad-CAM generation failed: {e}") gradcam_b64 = heatmap_b64 # Fallback to regular heatmap return heatmap_b64, gradcam_b64 # Module-level singleton predictor _predictor: Optional[Predictor] = None def get_predictor(model=None) -> Predictor: """Get or create singleton Predictor instance.""" global _predictor if _predictor is None: _predictor = Predictor(model=model) return _predictor def predict(latitude: float, longitude: float, date: str) -> Dict: """ Convenience function for single prediction. Args: latitude: Location latitude. longitude: Location longitude. date: Target date (YYYY-MM-DD). Returns: Prediction results dict. """ predictor = get_predictor() return predictor.predict(latitude, longitude, date) if __name__ == "__main__": logging.basicConfig(level=logging.INFO) result = predict(37.5, -120.3, "2024-06-15") print(f"\nRisk Level: {result['risk_level']}") print(f"Risk Score: {result['risk_score']}") print(f"Peak AQI: {result['peak_aqi_value']} at hour {result['peak_aqi_hour']}") print(f"Inference Time: {result['inference_time_ms']}ms") print(f"Forecast length: {len(result['aqi_forecast'])} hours")