krupal02's picture
Deploy Multi-Hazard Warning System - MTL model for wildfire risk + AQI forecasting
d5b0af1
Raw
History Blame Contribute Delete
10.6 kB
"""
fusion_model.py — Multi-Task Learning fusion model with dual output heads.
This is the core MTL architecture. A single model with:
- CNN branch (ResNet-50) for spatial features
- LSTM branch for temporal features
- Shared FC layers fusing both representations
- Task 1 Head: Wildfire risk heatmap (128×128, sigmoid)
- Task 2 Head: AQI forecast (72 hourly values, sigmoid-bounded to [0, 1])
The shared backbone learns representations useful for BOTH tasks,
enabling knowledge transfer between wildfire risk and air quality.
"""
import logging
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
from src.models.cnn_branch import CNNBranch
from src.models.lstm_branch import LSTMBranch
from src.training.config import (
CNN_FEATURE_DIM, LSTM_FEATURE_DIM, FUSION_DIM,
SHARED_FC_DIMS, DROPOUT_RATE,
PATCH_SIZE, AQI_FORECAST_HOURS, IMAGE_CHANNELS, TIMESERIES_FEATURES,
DEVICE,
)
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(__name__)
class MultiTaskFusionModel(nn.Module):
    """
    Multi-Task Learning (MTL) model: one shared backbone, two task heads.

    All dimensions come from ``src.training.config``; the figures in
    parentheses are the nominal values quoted by the original documentation
    and should be confirmed against the config.

    Architecture:
        CNN branch:  image      -> CNN_FEATURE_DIM features          (2048)
        LSTM branch: timeseries -> LSTM_FEATURE_DIM features         (256)
        Fusion:      concat(CNN, LSTM) -> shared_fc
                     -> SHARED_FC_DIMS[-1] features          (2304 -> 512 -> 256)
        Risk head:   shared -> PATCH_SIZE x PATCH_SIZE, sigmoid (256 -> 128x128)
        AQI head:    concat(shared, LSTM) -> AQI_FORECAST_HOURS,
                     sigmoid                                   (256 + 256 -> 72)

    The AQI head receives the LSTM features a second time through a skip
    connection, so its input is wider than the shared representation alone.
    """

    def __init__(
        self,
        pretrained_cnn: bool = True,
        freeze_cnn_early: bool = False,
    ) -> None:
        """
        Build both encoder branches, the shared fusion stack and the two heads.

        Args:
            pretrained_cnn: Forwarded to ``CNNBranch`` as ``pretrained``.
            freeze_cnn_early: Forwarded to ``CNNBranch`` as ``freeze_early``.
        """
        super().__init__()

        # ---- Branch 1: CNN spatial encoder ----
        self.cnn_branch = CNNBranch(
            in_channels=IMAGE_CHANNELS,
            out_features=CNN_FEATURE_DIM,
            pretrained=pretrained_cnn,
            freeze_early=freeze_cnn_early,
        )

        # ---- Branch 2: LSTM temporal encoder ----
        # NOTE(review): no output size is passed here; the rest of this class
        # assumes the branch emits LSTM_FEATURE_DIM features — confirm against
        # LSTMBranch's defaults.
        self.lstm_branch = LSTMBranch(
            input_size=TIMESERIES_FEATURES,
        )

        # ---- Shared fusion layers ----
        # FUSION_DIM is expected to equal CNN_FEATURE_DIM + LSTM_FEATURE_DIM,
        # the width of the concatenated branch outputs.
        fusion_layers = []
        prev_dim = FUSION_DIM
        for dim in SHARED_FC_DIMS:
            fusion_layers.extend([
                nn.Linear(prev_dim, dim),
                nn.BatchNorm1d(dim),
                nn.ReLU(inplace=True),
                nn.Dropout(DROPOUT_RATE),
            ])
            prev_dim = dim
        self.shared_fc = nn.Sequential(*fusion_layers)

        # ---- Task 1 head: wildfire risk heatmap ----
        # shared -> flattened PATCH_SIZE x PATCH_SIZE map of probabilities.
        self.risk_head = nn.Sequential(
            nn.Linear(SHARED_FC_DIMS[-1], 512),
            nn.ReLU(inplace=True),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, PATCH_SIZE * PATCH_SIZE),
            nn.Sigmoid(),  # Output probabilities in [0, 1]
        )

        # ---- Task 2 head: AQI forecast ----
        # Input is (shared + LSTM features): the LSTM output is concatenated
        # again as a skip connection so the temporal signal reaches this head
        # directly rather than only through the fused representation.
        self.aqi_head = nn.Sequential(
            nn.Linear(SHARED_FC_DIMS[-1] + LSTM_FEATURE_DIM, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, AQI_FORECAST_HOURS),
            # Bounded to [0, 1]; per the original note, callers rescale by x500
            # at inference time.
            nn.Sigmoid(),
        )

        self._log_architecture()

    def _log_architecture(self) -> None:
        """Log a summary of branch/head dimensions and parameter counts."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        logger.info(
            f"MultiTaskFusionModel initialized:\n"
            f" CNN branch: {IMAGE_CHANNELS}ch → {CNN_FEATURE_DIM}d\n"
            f" LSTM branch: {TIMESERIES_FEATURES}feat → {LSTM_FEATURE_DIM}d\n"
            f" Fusion: {FUSION_DIM}d → {SHARED_FC_DIMS}d\n"
            f" Risk head: {SHARED_FC_DIMS[-1]}d → {PATCH_SIZE}×{PATCH_SIZE}\n"
            # The AQI head also consumes the LSTM skip connection, so report
            # both parts of its input width.
            f" AQI head: {SHARED_FC_DIMS[-1]}d + {LSTM_FEATURE_DIM}d (LSTM skip)"
            f" → {AQI_FORECAST_HOURS}h\n"
            f" Parameters: {total:,} total, {trainable:,} trainable"
        )

    def _fuse_features(
        self,
        cnn_features: torch.Tensor,
        lstm_features: torch.Tensor,
    ) -> torch.Tensor:
        """Concatenate both branch outputs and run the shared FC stack."""
        fused = torch.cat([cnn_features, lstm_features], dim=1)
        return self.shared_fc(fused)

    def _predict_risk_map(self, shared: torch.Tensor) -> torch.Tensor:
        """Run the risk head and reshape its flat output to a 2-D map."""
        risk_flat = self.risk_head(shared)
        return risk_flat.view(-1, PATCH_SIZE, PATCH_SIZE)

    def _predict_aqi(
        self,
        shared: torch.Tensor,
        lstm_features: torch.Tensor,
    ) -> torch.Tensor:
        """Run the AQI head on the shared features plus the LSTM skip input."""
        aqi_input = torch.cat([shared, lstm_features], dim=1)
        return self.aqi_head(aqi_input)

    def forward(
        self,
        image: torch.Tensor,
        timeseries: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the full MTL model.

        Args:
            image: Batch of image patches for the CNN branch
                (documented upstream as (batch, 4, 128, 128)).
            timeseries: Batch of sequences for the LSTM branch
                (documented upstream as (batch, 7, 6)).

        Returns:
            Dict with:
                - 'risk_map': (batch, PATCH_SIZE, PATCH_SIZE) values in [0, 1]
                - 'aqi_forecast': (batch, AQI_FORECAST_HOURS) values in [0, 1]
                - 'shared_features': output of the shared FC stack
        """
        cnn_features = self.cnn_branch(image)
        lstm_features = self.lstm_branch(timeseries)

        shared = self._fuse_features(cnn_features, lstm_features)

        # Heads are evaluated risk-first, then AQI.
        risk_map = self._predict_risk_map(shared)
        aqi_forecast = self._predict_aqi(shared, lstm_features)

        return {
            "risk_map": risk_map,
            "aqi_forecast": aqi_forecast,
            "shared_features": shared,
        }

    def forward_cnn_only(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass using only the CNN branch (ablation Config A).

        The LSTM features are replaced by zeros, so only the risk map is
        produced.

        Returns:
            Dict with 'risk_map' and 'shared_features'.
        """
        cnn_features = self.cnn_branch(image)
        # new_zeros inherits dtype and device from the real features, so the
        # placeholder stays compatible when the model runs in half/double.
        lstm_features = cnn_features.new_zeros(image.size(0), LSTM_FEATURE_DIM)

        shared = self._fuse_features(cnn_features, lstm_features)
        return {
            "risk_map": self._predict_risk_map(shared),
            "shared_features": shared,
        }

    def forward_lstm_only(self, timeseries: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass using only the LSTM branch (ablation Config B).

        The CNN features are replaced by zeros, so only the AQI forecast is
        produced.

        Returns:
            Dict with 'aqi_forecast' and 'shared_features'.
        """
        lstm_features = self.lstm_branch(timeseries)
        # Match dtype/device of the real features (see forward_cnn_only).
        cnn_features = lstm_features.new_zeros(timeseries.size(0), CNN_FEATURE_DIM)

        shared = self._fuse_features(cnn_features, lstm_features)
        return {
            "aqi_forecast": self._predict_aqi(shared, lstm_features),
            "shared_features": shared,
        }

    def get_risk_score(self, risk_map: torch.Tensor) -> Tuple[float, str]:
        """
        Compute an overall risk score and level from a risk heatmap.

        Args:
            risk_map: Tensor of risk probabilities (nominally a single
                PATCH_SIZE x PATCH_SIZE map). The mean is taken over every
                element, whatever the shape.

        Returns:
            Tuple of (risk_score, risk_level), where risk_level is "Low"
            (< 0.25), "Medium" (< 0.50), "High" (< 0.75) or "Extreme".
            A NaN mean (e.g. from an empty tensor) fails every comparison
            and therefore reports "Extreme".
        """
        score = float(risk_map.mean())
        if score < 0.25:
            level = "Low"
        elif score < 0.50:
            level = "Medium"
        elif score < 0.75:
            level = "High"
        else:
            level = "Extreme"
        return score, level
def load_model(
    checkpoint_path: Optional[str] = None,
    device: str = DEVICE,
) -> MultiTaskFusionModel:
    """
    Load a MultiTaskFusionModel, optionally from a checkpoint.

    Supports both full-precision and dynamically quantized checkpoints.
    For quantized checkpoints the architecture is built first, dynamic
    quantization is applied, and only then is the state dict loaded.

    If ``checkpoint_path`` is given but the file does not exist, a warning is
    logged and a fresh model (pretrained CNN backbone only) is returned, so
    callers that rely on this best-effort fallback keep working.

    Args:
        checkpoint_path: Path to a .pth checkpoint file, or None.
        device: Device to load the model onto.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model = MultiTaskFusionModel(pretrained_cnn=True)

    if not checkpoint_path:
        logger.info("No checkpoint provided — using fresh model with pretrained CNN backbone.")
    elif not Path(checkpoint_path).exists():
        # A path was supplied but is missing: say so explicitly instead of
        # reporting that no checkpoint was provided.
        logger.warning(
            "Checkpoint not found at %s — using fresh model with pretrained CNN backbone.",
            checkpoint_path,
        )
    else:
        logger.info(f"Loading checkpoint: {checkpoint_path}")
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects and can
        # execute code embedded in the file. Only load checkpoints from a
        # trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

        # Accept either a training checkpoint dict or a bare state dict.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint

        # Heuristic: quantized state dicts carry many packed-param/scale keys.
        quantized_keys = [k for k in state_dict if "_packed_params" in k or "scale" in k]
        is_quantized = len(quantized_keys) > 5

        if is_quantized:
            logger.info("Detected quantized checkpoint — applying dynamic quantization first...")
            model = torch.quantization.quantize_dynamic(
                model,
                {torch.nn.Linear, torch.nn.LSTM},
                dtype=torch.qint8,
            )
            # strict=False tolerates key differences; report whatever was
            # skipped so a partially loaded model does not go unnoticed.
            load_result = model.load_state_dict(state_dict, strict=False)
            if load_result.missing_keys or load_result.unexpected_keys:
                logger.warning(
                    "Quantized checkpoint loaded non-strictly: %d missing keys %s, "
                    "%d unexpected keys %s",
                    len(load_result.missing_keys),
                    load_result.missing_keys,
                    len(load_result.unexpected_keys),
                    load_result.unexpected_keys,
                )
        else:
            model.load_state_dict(state_dict)
        logger.info("Checkpoint loaded successfully.")

    model = model.to(device)
    model.eval()
    return model
from pathlib import Path
if __name__ == "__main__":
    # Smoke test: build an untrained model and push random inputs through
    # the full forward pass and both ablation paths.
    logging.basicConfig(level=logging.INFO)
    model = MultiTaskFusionModel(pretrained_cnn=False)

    # Random batch of two samples for each modality.
    image_batch = torch.randn(2, 4, 128, 128)
    series_batch = torch.randn(2, 7, 6)

    # Full forward pass: report the shape of every output.
    output = model(image_batch, series_batch)
    print("\nFull MTL forward pass:")
    for label, key in (
        (" Risk map: ", "risk_map"),
        (" AQI forecast: ", "aqi_forecast"),
        (" Shared feat: ", "shared_features"),
    ):
        print(f"{label}{output[key].shape}")

    # Ablation Config A: CNN branch only.
    cnn_out = model.forward_cnn_only(image_batch)
    print("\nCNN-only (Config A):")
    print(f" Risk map: {cnn_out['risk_map'].shape}")

    # Ablation Config B: LSTM branch only.
    lstm_out = model.forward_lstm_only(series_batch)
    print("\nLSTM-only (Config B):")
    print(f" AQI forecast: {lstm_out['aqi_forecast'].shape}")

    # Aggregate risk score for the first sample of the full pass.
    first_map = output["risk_map"][0]
    score, level = model.get_risk_score(first_map)
    print(f"\nRisk score: {score:.3f} ({level})")