"""
src/explain.py

Grad-CAM wrapper for multi-label inference.

Why Grad-CAM on the last conv block?
  The last conv block (model.features[-1]) is the deepest layer that still
  retains spatial information before global average pooling collapses it to
  a vector. Earlier layers are too fine-grained and noisy; later layers have
  no spatial dimension to show.

  For multi-label, each output neuron has its own gradient path back through
  the network, so we get a separate heatmap per predicted label — not a single
  heatmap for the "winning" class.

Public API:
    explainer = GradCAMExplainer(model)
    overlay_img = explainer.explain(img_pil, label_name="rainy")
    overlays = explainer.explain_predicted(img_pil, thresholds)

    # CLI sanity check (saves up to 20 overlays to experiments/gradcam_samples/)
    python -m src.explain --checkpoint <path/to/checkpoint> --split val --n 20
"""

import argparse
import logging
import random
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms

from src.config import IMAGE_SIZE, IMAGENET_MEAN, IMAGENET_STD, LABELS
from src.dataset import BDDMultiLabelDataset, get_transforms
from src.evaluate import load_thresholds
from src.model import build_model
from src.utils import get_device

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger(__name__)

# Geometric half of the preprocessing pipeline. It is kept separate so that the
# image drawn underneath the heatmap goes through exactly the same resize +
# centre crop as the tensor fed to the model; otherwise the heatmap would be
# painted over pixels the model saw at a different position (or not at all).
_GEOMETRY = transforms.Compose([
    transforms.Resize(int(IMAGE_SIZE * 1.1)),
    transforms.CenterCrop(IMAGE_SIZE),
])

_PREPROCESS = transforms.Compose([
    _GEOMETRY,
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])


def _to_tensor(img_pil: Image.Image) -> torch.Tensor:
    """PIL image → (1, 3, H, W) float tensor, normalised."""
    return _PREPROCESS(img_pil.convert("RGB")).unsqueeze(0)


def _to_display_image(img_pil: Image.Image) -> Image.Image:
    """PIL image → RGB PIL image resized and centre-cropped like the model input."""
    return _GEOMETRY(img_pil.convert("RGB"))


def _to_rgb_array(img_pil: Image.Image) -> np.ndarray:
    """PIL image → float32 (H, W, 3) in [0, 1] for show_cam_on_image.

    Uses the same resize + centre crop as `_to_tensor`, so every pixel lines
    up with the corresponding Grad-CAM heatmap cell.
    """
    return np.asarray(_to_display_image(img_pil), dtype=np.float32) / 255.0


def _labels_above_threshold(
    probs: dict[str, float],
    thresholds: dict[str, float],
) -> list[str]:
    """Return the labels whose probability reaches their threshold.

    Labels missing from `thresholds` fall back to a cut-off of 0.5.
    """
    return [
        label for label, p in probs.items()
        if p >= thresholds.get(label, 0.5)
    ]


class GradCAMExplainer:
    """
    Wraps pytorch-grad-cam for multi-label EfficientNet-B0.

    Usage:
        explainer = GradCAMExplainer(model, device)
        overlay = explainer.explain(img_pil, "rainy")  # PIL image
        all_overlays = explainer.explain_predicted(img_pil, thresholds)
    """

    def __init__(self, model: torch.nn.Module,
                 device: torch.device | None = None):
        """
        Args:
            model: Network exposing a `features` sequence; its last element is
                used as the Grad-CAM target layer. Moved to `device` and put
                in eval mode.
            device: Device to run on; defaults to `get_device()`.
        """
        self.device = device if device is not None else get_device()
        self.model = model.to(self.device).eval()
        # Target the last conv block; this is where spatial info is preserved
        target_layers = [self.model.features[-1]]
        self.cam = GradCAM(model=self.model, target_layers=target_layers)

    def explain(self, img_pil: Image.Image, label_name: str) -> Image.Image:
        """
        Generate a Grad-CAM overlay for a single label.

        Args:
            img_pil: Input PIL image (any size; will be resized and
                centre-cropped internally)
            label_name: One of the strings in LABELS

        Returns:
            PIL image with the heatmap overlaid on the resized, centre-cropped
            input (the same view the model received)

        Raises:
            ValueError: If `label_name` is not in LABELS.
        """
        if label_name not in LABELS:
            raise ValueError(f"Unknown label '{label_name}'. Must be one of: {LABELS}")

        label_idx = LABELS.index(label_name)
        input_tensor = _to_tensor(img_pil).to(self.device)
        rgb_array = _to_rgb_array(img_pil)
        targets = [ClassifierOutputTarget(label_idx)]

        # Grad-CAM needs gradients even when the caller is running inside a
        # torch.no_grad() block (common in inference code paths).
        with torch.enable_grad():
            grayscale_cam = self.cam(input_tensor=input_tensor, targets=targets)

        # grayscale_cam shape: (1, H, W) — take the first (and only) batch item
        overlay = show_cam_on_image(rgb_array, grayscale_cam[0], use_rgb=True)
        return Image.fromarray(overlay)

    @torch.no_grad()
    def get_probs(self, img_pil: Image.Image) -> dict[str, float]:
        """Return post-sigmoid probabilities for all labels, rounded to 4 places."""
        input_tensor = _to_tensor(img_pil).to(self.device)
        logits = self.model(input_tensor)
        # flatten() rather than squeeze(): with a single label squeeze() would
        # yield a 0-d tensor, tolist() a bare float, and zip() would fail.
        probs = torch.sigmoid(logits).flatten().cpu().tolist()
        return {label: round(p, 4) for label, p in zip(LABELS, probs)}

    def explain_predicted(self, img_pil: Image.Image,
                          thresholds: dict[str, float] | None = None
                          ) -> dict[str, Image.Image]:
        """
        Run inference, then generate Grad-CAM for every label that exceeds
        its threshold.

        Args:
            img_pil: Input PIL image.
            thresholds: Per-label cut-offs; loaded with `load_thresholds()`
                when omitted. Labels without an entry use 0.5.

        Returns:
            {label_name: overlay_PIL_image}; empty if no label qualifies.
        """
        if thresholds is None:
            thresholds = load_thresholds()

        probs = self.get_probs(img_pil)
        predicted = _labels_above_threshold(probs, thresholds)
        return {label: self.explain(img_pil, label) for label in predicted}


# ---------------------------------------------------------------------------
# CLI sanity check
# ---------------------------------------------------------------------------

def _run_sanity_check(checkpoint: str, split: str, n: int) -> None:
    """
    Save up to n Grad-CAM overlays for randomly sampled images from `split`.

    Used to visually verify that heatmaps look sensible before the API uses
    them. Images with no label above threshold are skipped, so fewer than n
    files may be written.

    Args:
        checkpoint: Path to a state-dict checkpoint for `build_model()`.
        split: Dataset split name passed to `BDDMultiLabelDataset`.
        n: Number of images to sample (clamped to [0, len(dataset)]).
    """
    device = get_device()
    model = build_model().to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources (consider weights_only=True where the installed torch
    # version supports it).
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    explainer = GradCAMExplainer(model, device)
    thresholds = load_thresholds()

    ds = BDDMultiLabelDataset(split)
    # Clamp so that a negative or oversized n cannot make random.sample raise.
    sample_size = max(0, min(n, len(ds)))
    indices = random.sample(range(len(ds)), sample_size)

    out_dir = Path("experiments/gradcam_samples")
    out_dir.mkdir(parents=True, exist_ok=True)

    gap = 4  # width in pixels of the separator between the two panels
    saved = 0
    for rank, idx in enumerate(indices):
        row = ds.df.iloc[idx]
        # convert() returns a fully loaded copy, so the file can be closed.
        with Image.open(row["image_path"]) as raw:
            img_pil = raw.convert("RGB")

        probs = explainer.get_probs(img_pil)
        predicted = _labels_above_threshold(probs, thresholds)

        if not predicted:
            log.info("Sample %d: no labels above threshold, skipping", idx)
            continue

        # Overlay for the highest-confidence predicted label
        top_label = max(predicted, key=probs.__getitem__)
        overlay = explainer.explain(img_pil, top_label)

        # Side-by-side: original | overlay. The left panel uses the same
        # resize + crop as the overlay so the two can be compared directly.
        combined = Image.new("RGB", (IMAGE_SIZE * 2 + gap, IMAGE_SIZE),
                             color=(40, 40, 40))
        combined.paste(_to_display_image(img_pil), (0, 0))
        combined.paste(overlay, (IMAGE_SIZE + gap, 0))

        fname = out_dir / f"sample_{rank:03d}_{top_label}.png"
        combined.save(fname)
        saved += 1
        log.info("Saved %s | predicted: %s",
                 fname.name,
                 ", ".join(f"{label}={probs[label]:.2f}" for label in predicted))

    # Report what was actually written, not how many images were sampled.
    log.info("Saved %d Grad-CAM samples to %s", saved, out_dir)


def main() -> None:
    """Parse CLI arguments and run the Grad-CAM sanity check."""
    parser = argparse.ArgumentParser(description="Grad-CAM sanity check")
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--split", default="val",
                        choices=["train", "val", "test"])
    parser.add_argument("--n", type=int, default=20,
                        help="Number of samples to visualise")
    args = parser.parse_args()
    _run_sanity_check(args.checkpoint, args.split, args.n)


if __name__ == "__main__":
    main()