| """ |
| fusion_model.py — Multi-Task Learning fusion model with dual output heads. |
| |
| This is the core MTL architecture. A single model with: |
| - CNN branch (ResNet-50) for spatial features |
| - LSTM branch for temporal features |
| - Shared FC layers fusing both representations |
| - Task 1 Head: Wildfire risk heatmap (128×128, sigmoid) |
| - Task 2 Head: AQI forecast (72 hourly values, linear) |
| |
| The shared backbone learns representations useful for BOTH tasks, |
| enabling knowledge transfer between wildfire risk and air quality. |
| """ |
|
|
| import logging |
| from typing import Dict, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from src.models.cnn_branch import CNNBranch |
| from src.models.lstm_branch import LSTMBranch |
| from src.training.config import ( |
| CNN_FEATURE_DIM, LSTM_FEATURE_DIM, FUSION_DIM, |
| SHARED_FC_DIMS, DROPOUT_RATE, |
| PATCH_SIZE, AQI_FORECAST_HOURS, IMAGE_CHANNELS, TIMESERIES_FEATURES, |
| DEVICE, |
| ) |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class MultiTaskFusionModel(nn.Module):
    """
    Multi-Task Learning (MTL) model combining a CNN and an LSTM branch.

    One model, two tasks: both branches feed a shared fully-connected
    backbone, and two task-specific heads read from it.

    Architecture (all sizes come from ``src.training.config``; the numbers
    in parentheses are the ones quoted in the project documentation):
        CNN branch:  image      → CNN_FEATURE_DIM   (2048)
        LSTM branch: timeseries → LSTM_FEATURE_DIM  (256)
        Fusion:      concat → FUSION_DIM → SHARED_FC_DIMS  (2304 → 512 → 256)
        Head 1:      SHARED_FC_DIMS[-1] → PATCH_SIZE × PATCH_SIZE
                     (wildfire risk heatmap, sigmoid)
        Head 2:      SHARED_FC_DIMS[-1] + LSTM_FEATURE_DIM → AQI_FORECAST_HOURS
                     (AQI forecast; the shared features are concatenated
                     with the raw LSTM features before this head)
    """

    def __init__(
        self,
        pretrained_cnn: bool = True,
        freeze_cnn_early: bool = False,
    ):
        """
        Build both branches, the shared fusion stack and the two heads.

        Args:
            pretrained_cnn: Forwarded to ``CNNBranch`` as ``pretrained``.
            freeze_cnn_early: Forwarded to ``CNNBranch`` as ``freeze_early``.

        Raises:
            ValueError: If ``FUSION_DIM`` does not equal
                ``CNN_FEATURE_DIM + LSTM_FEATURE_DIM``.
        """
        super().__init__()

        # Every forward path concatenates a CNN_FEATURE_DIM-wide tensor with
        # an LSTM_FEATURE_DIM-wide tensor and feeds the result to a Linear
        # layer sized by FUSION_DIM. Fail here with a clear message instead
        # of with a shape error on the first forward pass.
        expected_fusion_dim = CNN_FEATURE_DIM + LSTM_FEATURE_DIM
        if FUSION_DIM != expected_fusion_dim:
            raise ValueError(
                f"FUSION_DIM ({FUSION_DIM}) must equal CNN_FEATURE_DIM + "
                f"LSTM_FEATURE_DIM ({expected_fusion_dim})"
            )

        # Spatial branch.
        self.cnn_branch = CNNBranch(
            in_channels=IMAGE_CHANNELS,
            out_features=CNN_FEATURE_DIM,
            pretrained=pretrained_cnn,
            freeze_early=freeze_cnn_early,
        )

        # Temporal branch.
        # NOTE(review): the output width is left at LSTMBranch's default and
        # is assumed to equal LSTM_FEATURE_DIM — confirm in lstm_branch.py.
        self.lstm_branch = LSTMBranch(
            input_size=TIMESERIES_FEATURES,
        )

        # Shared backbone: one Linear → BatchNorm → ReLU → Dropout group per
        # entry in SHARED_FC_DIMS.
        fusion_layers = []
        prev_dim = FUSION_DIM
        for dim in SHARED_FC_DIMS:
            fusion_layers.extend([
                nn.Linear(prev_dim, dim),
                nn.BatchNorm1d(dim),
                nn.ReLU(inplace=True),
                nn.Dropout(DROPOUT_RATE),
            ])
            prev_dim = dim
        self.shared_fc = nn.Sequential(*fusion_layers)

        # Task 1 head: flat PATCH_SIZE*PATCH_SIZE vector squashed to (0, 1);
        # it is reshaped to a square map in _predict_risk.
        self.risk_head = nn.Sequential(
            nn.Linear(SHARED_FC_DIMS[-1], 512),
            nn.ReLU(inplace=True),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, PATCH_SIZE * PATCH_SIZE),
            nn.Sigmoid(),
        )

        # Task 2 head: reads the shared features concatenated with the raw
        # LSTM features, hence the wider input.
        # NOTE(review): the final Sigmoid bounds every forecast to (0, 1),
        # while the module docstring describes this head as "linear".
        # Presumably AQI targets are normalised — confirm against the
        # training pipeline before changing either side.
        self.aqi_head = nn.Sequential(
            nn.Linear(SHARED_FC_DIMS[-1] + LSTM_FEATURE_DIM, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, AQI_FORECAST_HOURS),
            nn.Sigmoid(),
        )

        self._log_architecture()

    def _log_architecture(self):
        """Log a one-shot summary of layer sizes and parameter counts."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        # The AQI head input is the shared width plus the LSTM width.
        aqi_in_dim = SHARED_FC_DIMS[-1] + LSTM_FEATURE_DIM
        logger.info(
            f"MultiTaskFusionModel initialized:\n"
            f"  CNN branch:  {IMAGE_CHANNELS}ch → {CNN_FEATURE_DIM}d\n"
            f"  LSTM branch: {TIMESERIES_FEATURES}feat → {LSTM_FEATURE_DIM}d\n"
            f"  Fusion:      {FUSION_DIM}d → {SHARED_FC_DIMS}d\n"
            f"  Risk head:   {SHARED_FC_DIMS[-1]}d → {PATCH_SIZE}×{PATCH_SIZE}\n"
            f"  AQI head:    {aqi_in_dim}d → {AQI_FORECAST_HOURS}h\n"
            f"  Parameters:  {total:,} total, {trainable:,} trainable"
        )

    def _fuse(
        self,
        cnn_features: torch.Tensor,
        lstm_features: torch.Tensor,
    ) -> torch.Tensor:
        """Concatenate both branch outputs and run the shared FC stack."""
        return self.shared_fc(torch.cat([cnn_features, lstm_features], dim=1))

    def _predict_risk(self, shared: torch.Tensor) -> torch.Tensor:
        """Run the risk head and reshape its flat output to a square map."""
        return self.risk_head(shared).view(-1, PATCH_SIZE, PATCH_SIZE)

    def _predict_aqi(
        self,
        shared: torch.Tensor,
        lstm_features: torch.Tensor,
    ) -> torch.Tensor:
        """Run the AQI head on shared features joined with LSTM features."""
        return self.aqi_head(torch.cat([shared, lstm_features], dim=1))

    def forward(
        self,
        image: torch.Tensor,
        timeseries: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the full MTL model.

        Args:
            image: Batch of image patches for the CNN branch; the project
                docs quote (batch, 4, 128, 128).
            timeseries: Batch of sequences for the LSTM branch; the project
                docs quote (batch, 7, 6).

        Returns:
            Dict with:
                - 'risk_map': (batch, PATCH_SIZE, PATCH_SIZE), values in (0, 1)
                - 'aqi_forecast': (batch, AQI_FORECAST_HOURS), values in (0, 1)
                - 'shared_features': (batch, SHARED_FC_DIMS[-1])
        """
        cnn_features = self.cnn_branch(image)
        lstm_features = self.lstm_branch(timeseries)

        shared = self._fuse(cnn_features, lstm_features)

        # Risk head first, then AQI head, matching the original call order.
        risk_map = self._predict_risk(shared)
        aqi_forecast = self._predict_aqi(shared, lstm_features)

        return {
            "risk_map": risk_map,
            "aqi_forecast": aqi_forecast,
            "shared_features": shared,
        }

    def forward_cnn_only(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass using only the CNN branch (ablation Config A).

        The LSTM slot of the fused vector is filled with zeros and only the
        risk head is evaluated.

        Returns:
            Dict with 'risk_map' and 'shared_features'.
        """
        cnn_features = self.cnn_branch(image)
        # new_zeros inherits dtype and device from the real features, so the
        # placeholder also works for half/double precision models.
        lstm_features = cnn_features.new_zeros(
            cnn_features.size(0), LSTM_FEATURE_DIM
        )

        shared = self._fuse(cnn_features, lstm_features)
        risk_map = self._predict_risk(shared)

        return {"risk_map": risk_map, "shared_features": shared}

    def forward_lstm_only(self, timeseries: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass using only the LSTM branch (ablation Config B).

        The CNN slot of the fused vector is filled with zeros and only the
        AQI head is evaluated.

        Returns:
            Dict with 'aqi_forecast' and 'shared_features'.
        """
        lstm_features = self.lstm_branch(timeseries)
        # Placeholder matches the dtype and device of the real LSTM features.
        cnn_features = lstm_features.new_zeros(
            lstm_features.size(0), CNN_FEATURE_DIM
        )

        shared = self._fuse(cnn_features, lstm_features)
        aqi_forecast = self._predict_aqi(shared, lstm_features)

        return {"aqi_forecast": aqi_forecast, "shared_features": shared}

    def get_risk_score(self, risk_map: torch.Tensor) -> Tuple[float, str]:
        """
        Compute an overall risk score and level from a risk heatmap.

        Args:
            risk_map: Risk probability map; the score is the mean over all
                of its elements.

        Returns:
            Tuple of (risk_score, risk_level), where risk_level is
            "Low" (< 0.25), "Medium" (< 0.50), "High" (< 0.75) or
            "Extreme" (anything else).
        """
        score = float(risk_map.mean())

        if score < 0.25:
            level = "Low"
        elif score < 0.50:
            level = "Medium"
        elif score < 0.75:
            level = "High"
        else:
            level = "Extreme"

        return score, level
|
|
|
|
def load_model(
    checkpoint_path: Optional[str] = None,
    device: str = DEVICE,
) -> MultiTaskFusionModel:
    """
    Build a MultiTaskFusionModel and optionally load weights into it.

    Supports both full-precision and dynamically quantized checkpoints.
    For a quantized checkpoint the architecture is built first, dynamic
    quantization is applied, and only then is the state dict loaded
    (non-strictly).

    Args:
        checkpoint_path: Path to a .pth checkpoint. The file may hold either
            a bare state dict or a dict with a "model_state_dict" entry.
            If the path is empty or does not exist, a fresh model is
            returned.
        device: Device to move the model onto. Quantized models are kept on
            the CPU regardless, because dynamic quantization kernels are
            CPU-only.

    Returns:
        The model in eval mode.
    """
    # Imported locally so this function does not rely on the module-level
    # import order.
    from pathlib import Path

    model = MultiTaskFusionModel(pretrained_cnn=True)
    target_device = device

    if checkpoint_path and Path(checkpoint_path).exists():
        logger.info("Loading checkpoint: %s", checkpoint_path)
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from trusted sources.
        # The checkpoint is deserialized on the CPU because the model is
        # still on the CPU when load_state_dict runs; the single transfer to
        # the target device happens at the end of this function.
        checkpoint = torch.load(
            checkpoint_path, map_location="cpu", weights_only=False
        )

        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint

        # Heuristic: quantized modules store packed params and scales. The
        # threshold of more than five matches avoids stray key names.
        quantized_keys = [
            key for key in state_dict
            if "_packed_params" in key or "scale" in key
        ]
        is_quantized = len(quantized_keys) > 5

        if is_quantized:
            logger.info("Detected quantized checkpoint — applying dynamic quantization first...")
            model = torch.quantization.quantize_dynamic(
                model,
                {torch.nn.Linear, torch.nn.LSTM},
                dtype=torch.qint8,
            )
            load_result = model.load_state_dict(state_dict, strict=False)
            # strict=False hides mismatches, so report them.
            if load_result.missing_keys or load_result.unexpected_keys:
                logger.warning(
                    "Quantized checkpoint loaded non-strictly: "
                    "%d missing key(s) %s, %d unexpected key(s) %s",
                    len(load_result.missing_keys),
                    load_result.missing_keys[:5],
                    len(load_result.unexpected_keys),
                    load_result.unexpected_keys[:5],
                )
            # Dynamically quantized layers only run on the CPU.
            if torch.device(device).type != "cpu":
                logger.warning(
                    "Quantized models only run on CPU; ignoring requested "
                    "device %s.",
                    device,
                )
                target_device = "cpu"
        else:
            model.load_state_dict(state_dict)

        logger.info("Checkpoint loaded successfully.")
    elif checkpoint_path:
        # A path was given but nothing is there. Keep the fallback to a
        # fresh model, but make the situation visible.
        logger.warning(
            "Checkpoint not found at %s — using fresh model with pretrained "
            "CNN backbone.",
            checkpoint_path,
        )
    else:
        logger.info("No checkpoint provided — using fresh model with pretrained CNN backbone.")

    model = model.to(target_device)
    model.eval()
    return model
|
|
|
|
| from pathlib import Path |
|
|
if __name__ == "__main__":
    # Smoke test: build an untrained model and check the output shapes of
    # all three forward paths on random inputs.
    logging.basicConfig(level=logging.INFO)

    model = MultiTaskFusionModel(pretrained_cnn=False)
    # Eval mode disables dropout and uses BatchNorm running statistics;
    # no_grad below skips autograd bookkeeping for this shape check.
    model.eval()

    # Shapes follow the config constants the model is built from. The
    # sequence length is not a config constant in this file, so it stays
    # hard-coded at 7.
    batch_size, seq_len = 2, 7
    img = torch.randn(batch_size, IMAGE_CHANNELS, PATCH_SIZE, PATCH_SIZE)
    ts = torch.randn(batch_size, seq_len, TIMESERIES_FEATURES)

    with torch.no_grad():
        output = model(img, ts)
        print("\nFull MTL forward pass:")
        print(f"  Risk map:      {output['risk_map'].shape}")
        print(f"  AQI forecast:  {output['aqi_forecast'].shape}")
        print(f"  Shared feat:   {output['shared_features'].shape}")

        # Ablation Config A: image branch only.
        cnn_out = model.forward_cnn_only(img)
        print("\nCNN-only (Config A):")
        print(f"  Risk map: {cnn_out['risk_map'].shape}")

        # Ablation Config B: time-series branch only.
        lstm_out = model.forward_lstm_only(ts)
        print("\nLSTM-only (Config B):")
        print(f"  AQI forecast: {lstm_out['aqi_forecast'].shape}")

    # Aggregate score for the first sample in the batch.
    score, level = model.get_risk_score(output["risk_map"][0])
    print(f"\nRisk score: {score:.3f} ({level})")
|
|