| """ |
| predict.py — End-to-end inference pipeline given latitude, longitude, and date. |
| |
| Orchestrates the full prediction workflow: |
| 1. Fetch satellite/fire data for the given location |
| 2. Fetch 7-day weather + AQI history |
| 3. Preprocess both inputs using saved scalers |
| 4. Run MTL model inference |
| 5. Generate Grad-CAM overlay |
| 6. Return structured prediction results |
| """ |
|
|
| import base64 |
| import logging |
| import time |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Dict, Optional, Tuple |
|
|
| import cv2 |
| import numpy as np |
| import torch |
|
|
| from src.training.config import ( |
| DEVICE, PATCH_SIZE, TIMESERIES_WINDOW, TIMESERIES_FEATURES, |
| AQI_FORECAST_HOURS, BEST_MODEL_PATH, PROCESSED_DIR, |
| RISK_THRESHOLDS, AQI_CATEGORIES, |
| ) |
|
|
# Module-scoped logger; handlers/levels are configured by the application.
logger = logging.getLogger(name=__name__)
|
|
|
|
class Predictor:
    """
    End-to-end inference engine for wildfire risk + AQI predictions.

    Loads the trained MTL model, preprocesses inputs, runs inference,
    generates Grad-CAM overlays, and returns structured results.
    """

    def __init__(self, model=None, device: str = DEVICE):
        """
        Args:
            model: Pre-loaded model instance. If None, loads from the
                checkpoint at ``BEST_MODEL_PATH``.
            device: Computation device.
        """
        self.device = device

        if model is not None:
            self.model = model.to(device).eval()
        else:
            self.model = self._load_model()

        # Fitted preprocessing state. Each loader falls back to an identity
        # transform when its saved artifact is missing.
        self.img_normalizer = self._load_image_normalizer()
        self.ts_scaler = self._load_timeseries_scaler()

        logger.info("Predictor initialized.")

    def _load_model(self):
        """Load the trained model from ``BEST_MODEL_PATH`` onto ``self.device``."""
        from src.models.fusion_model import load_model
        # NOTE(review): unlike the pre-loaded branch in __init__, .eval() is
        # not called here — presumably load_model() already does it; confirm.
        return load_model(str(BEST_MODEL_PATH), device=self.device)

    def _load_image_normalizer(self):
        """
        Load the fitted image normalizer from ``PROCESSED_DIR``.

        Returns:
            An ``ImageNormalizer``. If ``image_normalizer.npz`` does not exist,
            it is configured as an identity transform (mean 0, std 1).
        """
        from src.data.preprocess import ImageNormalizer
        path = PROCESSED_DIR / "image_normalizer.npz"
        normalizer = ImageNormalizer()
        if path.exists():
            normalizer.load(path)
            logger.info("Loaded image normalizer from disk.")
        else:
            logger.warning("No saved image normalizer found — using identity transform.")
            # Four channels, matching the (1, 4, H, W) patch produced by
            # fetch_inputs().
            normalizer.channel_means = np.zeros(4)
            normalizer.channel_stds = np.ones(4)
        return normalizer

    def _load_timeseries_scaler(self):
        """
        Load the fitted time-series scaler from ``PROCESSED_DIR``.

        Returns:
            A ``TimeSeriesScaler``. If ``timeseries_scaler.npz`` does not
            exist, it is marked fitted with unit scale and zero offset so
            that ``transform`` leaves values unchanged.
        """
        from src.data.preprocess import TimeSeriesScaler
        path = PROCESSED_DIR / "timeseries_scaler.npz"
        scaler = TimeSeriesScaler()
        if path.exists():
            scaler.load(path)
            logger.info("Loaded timeseries scaler from disk.")
        else:
            logger.warning("No saved timeseries scaler found — using identity transform.")
            # Hand-populate fitted attributes. Presumably scaler.scaler is an
            # sklearn MinMaxScaler (x * scale_ + min_) — TODO confirm.
            scaler._is_fitted = True
            scaler.scaler.data_min_ = np.zeros(TIMESERIES_FEATURES)
            scaler.scaler.data_max_ = np.ones(TIMESERIES_FEATURES)
            scaler.scaler.scale_ = np.ones(TIMESERIES_FEATURES)
            scaler.scaler.min_ = np.zeros(TIMESERIES_FEATURES)
            scaler.scaler.n_features_in_ = TIMESERIES_FEATURES
        return scaler

    def fetch_inputs(
        self,
        latitude: float,
        longitude: float,
        date: str,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Fetch and preprocess inputs from data sources.

        The image patch is currently synthetic. It is drawn from NumPy's
        global RNG, seeded from the coordinates, so a given location always
        yields the same patch. The caller's global RNG state is restored
        before returning.

        Args:
            latitude: Location latitude.
            longitude: Location longitude.
            date: Target date (YYYY-MM-DD).

        Returns:
            Tuple of (image_patch, timeseries):
                - image_patch: (1, 4, 128, 128)
                - timeseries: (1, 7, 6)
        """
        from src.data.fetch_firms import generate_synthetic_fire_data
        from src.data.fetch_weather import fetch_weather_for_location
        from src.data.fetch_aqi import fetch_aqi_for_location

        # Location-derived seed. NOTE(review): because of abs(), (lat, lon)
        # and (-lat, -lon) map to the same seed. Many other nearby pairs
        # collide too. Acceptable only for placeholder synthetic data.
        seed = int(abs(latitude * 1000 + longitude * 100)) % (2**31)

        # Keep the seeded stream in effect for all fetches, in case the
        # weather/AQI fetchers also draw from the global RNG. Then put the
        # caller's RNG state back instead of leaving it clobbered.
        saved_rng_state = np.random.get_state()
        np.random.seed(seed)
        try:
            image, _ = generate_synthetic_fire_data(num_samples=1)
            image = image[0:1]  # first sample, batch axis kept

            weather = fetch_weather_for_location(latitude, longitude, date)
            aqi_hist, _ = fetch_aqi_for_location(latitude, longitude, date)
        finally:
            np.random.set_state(saved_rng_state)

        # Join weather and AQI features along axis 1, then add a leading
        # batch axis.
        timeseries = np.concatenate([weather, aqi_hist], axis=1)
        timeseries = timeseries[np.newaxis, ...]

        image = self.img_normalizer.transform(image)
        timeseries = self.ts_scaler.transform(timeseries)

        return image, timeseries

    @torch.no_grad()
    def predict(
        self,
        latitude: float,
        longitude: float,
        date: str,
    ) -> Dict:
        """
        Run end-to-end prediction for a given location and date.

        Args:
            latitude: Location latitude (-90 to 90).
            longitude: Location longitude (-180 to 180).
            date: Target date (YYYY-MM-DD).

        Returns:
            Dict with all prediction results: risk level/score/map, base64
            PNG heatmap and Grad-CAM images, hourly AQI forecast with its
            peak, the echoed inputs, a timestamp, and inference time in ms.
        """
        # perf_counter is monotonic, so elapsed time is immune to clock changes.
        start_time = time.perf_counter()
        logger.info("Predicting for (%s, %s) on %s...", latitude, longitude, date)

        image_np, timeseries_np = self.fetch_inputs(latitude, longitude, date)

        image_tensor = torch.tensor(image_np, dtype=torch.float32).to(self.device)
        ts_tensor = torch.tensor(timeseries_np, dtype=torch.float32).to(self.device)

        output = self.model(image_tensor, ts_tensor)

        # Take the single batch element from each head.
        risk_map = output["risk_map"][0].cpu().numpy()
        # NOTE(review): the 500.0 factor presumably undoes the AQI
        # normalization used in training — confirm against the training code.
        aqi_forecast = output["aqi_forecast"][0].cpu().numpy() * 500.0

        risk_score = float(risk_map.mean())
        risk_level = self._classify_risk(risk_score)

        heatmap_b64, gradcam_b64 = self._generate_visualizations(
            image_tensor, ts_tensor, risk_map, image_np
        )

        peak_hour = int(np.argmax(aqi_forecast))
        peak_value = float(aqi_forecast[peak_hour])

        elapsed = time.perf_counter() - start_time
        logger.info(
            "Prediction complete in %.2fs — Risk: %s (%.3f)",
            elapsed, risk_level, risk_score,
        )

        return {
            "risk_level": risk_level,
            "risk_score": round(risk_score, 4),
            "risk_map": risk_map,
            "heatmap_base64": heatmap_b64,
            "gradcam_base64": gradcam_b64,
            "aqi_forecast": aqi_forecast.tolist(),
            "forecast_hours": list(range(1, AQI_FORECAST_HOURS + 1)),
            "peak_aqi_hour": peak_hour + 1,  # 1-based, matching forecast_hours
            "peak_aqi_value": round(peak_value, 1),
            "latitude": latitude,
            "longitude": longitude,
            "date": date,
            "prediction_timestamp": datetime.now().isoformat(),
            "inference_time_ms": round(elapsed * 1000, 1),
        }

    def _classify_risk(self, score: float) -> str:
        """
        Map a risk score to the first ``RISK_THRESHOLDS`` band with
        ``low <= score < high``.

        Scores outside every band, including NaN, fall through to
        "Extreme".
        """
        for level, (low, high) in RISK_THRESHOLDS.items():
            if low <= score < high:
                return level
        return "Extreme"

    def _generate_visualizations(
        self,
        image_tensor: torch.Tensor,
        ts_tensor: torch.Tensor,
        risk_map: np.ndarray,
        image_np: np.ndarray,
    ) -> Tuple[str, str]:
        """
        Generate heatmap and Grad-CAM base64 images.

        The heatmap is the risk map clipped to [0, 1], rendered with the JET
        colormap and PNG-encoded. Grad-CAM is best-effort: on any failure a
        warning is logged and the heatmap is returned in its place.

        Returns:
            Tuple of (heatmap_base64, gradcam_base64).

        Raises:
            RuntimeError: If the heatmap cannot be PNG-encoded.
        """
        heatmap_uint8 = np.uint8(255 * np.clip(risk_map, 0, 1))
        heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        ok, heatmap_buf = cv2.imencode(".png", heatmap_color)
        if not ok:
            raise RuntimeError("Failed to PNG-encode the risk heatmap.")
        heatmap_b64 = base64.b64encode(heatmap_buf).decode("utf-8")

        try:
            from src.models.gradcam import GradCAM
            # predict() runs under torch.no_grad(). Grad-CAM must
            # back-propagate through the model, so re-enable autograd here.
            # Without it the backward pass fails and the fallback below
            # always fires.
            with torch.enable_grad():
                gradcam = GradCAM(self.model)
                overlay = gradcam.generate_overlay(image_tensor, ts_tensor)
            ok, gradcam_buf = cv2.imencode(".png", overlay)
            if not ok:
                raise RuntimeError("cv2.imencode failed for the Grad-CAM overlay")
            gradcam_b64 = base64.b64encode(gradcam_buf).decode("utf-8")
        except Exception as e:
            logger.warning("Grad-CAM generation failed: %s", e)
            gradcam_b64 = heatmap_b64

        return heatmap_b64, gradcam_b64
|
|
|
|
| |
# Lazily created, module-wide Predictor shared by get_predictor()/predict().
_predictor: Optional[Predictor] = None


def get_predictor(model=None) -> Predictor:
    """
    Return the shared Predictor, creating it on the first call.

    Args:
        model: Optional pre-loaded model. It is used only when the singleton
            is first created. If a different model is passed afterwards, it
            is ignored and a warning is logged.

    Returns:
        The shared Predictor instance.
    """
    # NOTE(review): no lock guards initialization. Concurrent first calls
    # could each build a Predictor.
    global _predictor
    if _predictor is None:
        _predictor = Predictor(model=model)
    elif model is not None and model is not _predictor.model:
        logger.warning(
            "get_predictor(): predictor already initialized; ignoring the supplied model."
        )
    return _predictor
|
|
|
|
def predict(latitude: float, longitude: float, date: str) -> Dict:
    """
    Run one prediction through the shared module-level Predictor.

    The Predictor is created on first use via ``get_predictor()``.

    Args:
        latitude: Location latitude.
        longitude: Location longitude.
        date: Target date (YYYY-MM-DD).

    Returns:
        The result dict returned by ``Predictor.predict``.
    """
    return get_predictor().predict(latitude=latitude, longitude=longitude, date=date)
|
|
|
|
if __name__ == "__main__":
    # Demo: one prediction for a fixed location/date, printed as a summary.
    logging.basicConfig(level=logging.INFO)

    demo = predict(37.5, -120.3, "2024-06-15")

    summary_lines = (
        f"\nRisk Level: {demo['risk_level']}",
        f"Risk Score: {demo['risk_score']}",
        f"Peak AQI: {demo['peak_aqi_value']} at hour {demo['peak_aqi_hour']}",
        f"Inference Time: {demo['inference_time_ms']}ms",
        f"Forecast length: {len(demo['aqi_forecast'])} hours",
    )
    for text in summary_lines:
        print(text)
|
|