""" fusion_model.py — Multi-Task Learning fusion model with dual output heads. This is the core MTL architecture. A single model with: - CNN branch (ResNet-50) for spatial features - LSTM branch for temporal features - Shared FC layers fusing both representations - Task 1 Head: Wildfire risk heatmap (128×128, sigmoid) - Task 2 Head: AQI forecast (72 hourly values, linear) The shared backbone learns representations useful for BOTH tasks, enabling knowledge transfer between wildfire risk and air quality. """ import logging from typing import Dict, Optional, Tuple import torch import torch.nn as nn from src.models.cnn_branch import CNNBranch from src.models.lstm_branch import LSTMBranch from src.training.config import ( CNN_FEATURE_DIM, LSTM_FEATURE_DIM, FUSION_DIM, SHARED_FC_DIMS, DROPOUT_RATE, PATCH_SIZE, AQI_FORECAST_HOURS, IMAGE_CHANNELS, TIMESERIES_FEATURES, DEVICE, ) logger = logging.getLogger(__name__) class MultiTaskFusionModel(nn.Module): """ Multi-Task Learning model combining CNN and LSTM branches. Architecture: CNN Branch: (batch, 4, 128, 128) → (batch, 2048) LSTM Branch: (batch, 7, 6) → (batch, 256) Fusion: concat → 2304 → 512 → 256 (shared) Head 1: 256 → 128×128 (wildfire risk heatmap) Head 2: 256 → 72 (AQI hourly forecast) This is explicitly a Multi-Task Learning (MTL) model — one model, two tasks, shared backbone, task-specific heads. """ def __init__( self, pretrained_cnn: bool = True, freeze_cnn_early: bool = False, ): super().__init__() # ---- Branch 1: CNN Spatial Encoder ---- self.cnn_branch = CNNBranch( in_channels=IMAGE_CHANNELS, out_features=CNN_FEATURE_DIM, pretrained=pretrained_cnn, freeze_early=freeze_cnn_early, ) # ---- Branch 2: LSTM Temporal Encoder ---- self.lstm_branch = LSTMBranch( input_size=TIMESERIES_FEATURES, ) # ---- Shared Fusion Layers ---- # Concatenate: CNN_FEATURE_DIM + LSTM_FEATURE_DIM = 2048 + 256 = 2304 fusion_layers = [] prev_dim = FUSION_DIM for dim in SHARED_FC_DIMS: fusion_layers.extend([ nn.Linear(prev_dim, dim), nn.BatchNorm1d(dim), nn.ReLU(inplace=True), nn.Dropout(DROPOUT_RATE), ]) prev_dim = dim self.shared_fc = nn.Sequential(*fusion_layers) # ---- Task 1 Head: Wildfire Risk Heatmap ---- # 256 → 128×128 spatial map self.risk_head = nn.Sequential( nn.Linear(SHARED_FC_DIMS[-1], 512), nn.ReLU(inplace=True), nn.Dropout(DROPOUT_RATE), nn.Linear(512, 1024), nn.ReLU(inplace=True), nn.Linear(1024, PATCH_SIZE * PATCH_SIZE), nn.Sigmoid(), # Output probabilities [0, 1] ) # ---- Task 2 Head: AQI Forecast ---- # (shared_dim + lstm_dim) → 72 hourly values # Adding skip connection from lstm_features directly to prevent CNN saturation self.aqi_head = nn.Sequential( nn.Linear(SHARED_FC_DIMS[-1] + LSTM_FEATURE_DIM, 256), nn.ReLU(inplace=True), nn.Dropout(DROPOUT_RATE), nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Linear(128, AQI_FORECAST_HOURS), nn.Sigmoid(), # Bound output to [0, 1]; unscaled by ×500 at inference ) self._log_architecture() def _log_architecture(self): """Log model architecture summary.""" total = sum(p.numel() for p in self.parameters()) trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) logger.info( f"MultiTaskFusionModel initialized:\n" f" CNN branch: {IMAGE_CHANNELS}ch → {CNN_FEATURE_DIM}d\n" f" LSTM branch: {TIMESERIES_FEATURES}feat → {LSTM_FEATURE_DIM}d\n" f" Fusion: {FUSION_DIM}d → {SHARED_FC_DIMS}d\n" f" Risk head: {SHARED_FC_DIMS[-1]}d → {PATCH_SIZE}×{PATCH_SIZE}\n" f" AQI head: {SHARED_FC_DIMS[-1]}d → {AQI_FORECAST_HOURS}h\n" f" Parameters: {total:,} total, {trainable:,} trainable" ) def forward( self, image: torch.Tensor, timeseries: torch.Tensor, ) -> Dict[str, torch.Tensor]: """ Forward pass through the full MTL model. Args: image: (batch, 4, 128, 128) satellite image patch. timeseries: (batch, 7, 6) weather + AQI sequence. Returns: Dict with: - 'risk_map': (batch, 128, 128) wildfire risk heatmap - 'aqi_forecast': (batch, 72) hourly AQI predictions - 'shared_features': (batch, 256) shared representation """ # Extract features from both branches cnn_features = self.cnn_branch(image) # (batch, 2048) lstm_features = self.lstm_branch(timeseries) # (batch, 256) # Fuse representations fused = torch.cat([cnn_features, lstm_features], dim=1) # (batch, 2304) shared = self.shared_fc(fused) # (batch, 256) # Task-specific heads risk_flat = self.risk_head(shared) # (batch, 128*128) risk_map = risk_flat.view(-1, PATCH_SIZE, PATCH_SIZE) # (batch, 128, 128) # Skip connection: guarantee temporal structure guides regression. aqi_input = torch.cat([shared, lstm_features], dim=1) aqi_forecast = self.aqi_head(aqi_input) # (batch, 72) return { "risk_map": risk_map, "aqi_forecast": aqi_forecast, "shared_features": shared, } def forward_cnn_only(self, image: torch.Tensor) -> Dict[str, torch.Tensor]: """ Forward pass using only the CNN branch (ablation Config A). Uses zero-filled LSTM features. """ cnn_features = self.cnn_branch(image) lstm_features = torch.zeros( image.size(0), LSTM_FEATURE_DIM, device=image.device ) fused = torch.cat([cnn_features, lstm_features], dim=1) shared = self.shared_fc(fused) risk_flat = self.risk_head(shared) risk_map = risk_flat.view(-1, PATCH_SIZE, PATCH_SIZE) return {"risk_map": risk_map, "shared_features": shared} def forward_lstm_only(self, timeseries: torch.Tensor) -> Dict[str, torch.Tensor]: """ Forward pass using only the LSTM branch (ablation Config B). Uses zero-filled CNN features. """ cnn_features = torch.zeros( timeseries.size(0), CNN_FEATURE_DIM, device=timeseries.device ) lstm_features = self.lstm_branch(timeseries) fused = torch.cat([cnn_features, lstm_features], dim=1) shared = self.shared_fc(fused) aqi_input = torch.cat([shared, lstm_features], dim=1) aqi_forecast = self.aqi_head(aqi_input) return {"aqi_forecast": aqi_forecast, "shared_features": shared} def get_risk_score(self, risk_map: torch.Tensor) -> Tuple[float, str]: """ Compute overall risk score and level from a risk heatmap. Args: risk_map: (128, 128) risk probability map. Returns: Tuple of (risk_score, risk_level). """ score = float(risk_map.mean()) if score < 0.25: level = "Low" elif score < 0.50: level = "Medium" elif score < 0.75: level = "High" else: level = "Extreme" return score, level def load_model( checkpoint_path: Optional[str] = None, device: str = DEVICE, ) -> MultiTaskFusionModel: """ Load a MultiTaskFusionModel, optionally from a checkpoint. Supports both full-precision and dynamically quantized checkpoints. For quantized models, the architecture is first initialized, then dynamic quantization is applied before loading the state dict. Args: checkpoint_path: Path to .pth checkpoint file. device: Device to load model onto. Returns: Loaded model in eval mode. """ model = MultiTaskFusionModel(pretrained_cnn=True) if checkpoint_path and Path(checkpoint_path).exists(): logger.info(f"Loading checkpoint: {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: state_dict = checkpoint["model_state_dict"] else: state_dict = checkpoint # Detect if checkpoint is from a quantized model quantized_keys = [k for k in state_dict.keys() if "_packed_params" in k or "scale" in k] is_quantized = len(quantized_keys) > 5 # Quantized models have many such keys if is_quantized: logger.info("Detected quantized checkpoint — applying dynamic quantization first...") model = torch.quantization.quantize_dynamic( model, {torch.nn.Linear, torch.nn.LSTM}, dtype=torch.qint8, ) model.load_state_dict(state_dict, strict=False) else: model.load_state_dict(state_dict) logger.info("Checkpoint loaded successfully.") else: logger.info("No checkpoint provided — using fresh model with pretrained CNN backbone.") model = model.to(device) model.eval() return model from pathlib import Path if __name__ == "__main__": logging.basicConfig(level=logging.INFO) model = MultiTaskFusionModel(pretrained_cnn=False) # Test full forward pass img = torch.randn(2, 4, 128, 128) ts = torch.randn(2, 7, 6) output = model(img, ts) print(f"\nFull MTL forward pass:") print(f" Risk map: {output['risk_map'].shape}") print(f" AQI forecast: {output['aqi_forecast'].shape}") print(f" Shared feat: {output['shared_features'].shape}") # Test ablation modes cnn_out = model.forward_cnn_only(img) print(f"\nCNN-only (Config A):") print(f" Risk map: {cnn_out['risk_map'].shape}") lstm_out = model.forward_lstm_only(ts) print(f"\nLSTM-only (Config B):") print(f" AQI forecast: {lstm_out['aqi_forecast'].shape}") # Risk score score, level = model.get_risk_score(output["risk_map"][0]) print(f"\nRisk score: {score:.3f} ({level})")