| """ |
src/explain.py

Grad-CAM wrapper for multi-label inference.

Why Grad-CAM on the last conv block?
The last conv block (model.features[-1]) is the deepest layer that still
retains spatial information before global average pooling collapses it to
a vector. Earlier layers are too fine-grained and noisy; later layers have
no spatial dimension to show.

For multi-label classification, each output neuron has its own gradient path
back through the network, so we get a separate heatmap per predicted label —
not a single heatmap for the "winning" class.

Public API:
    explainer = GradCAMExplainer(model)
    overlay_img = explainer.explain(img_pil, label_name="rainy")
    overlays = explainer.explain_predicted(img_pil, thresholds)

    # CLI sanity check (saves up to --n overlays to experiments/gradcam_samples/;
    # sampled images with no label above threshold are skipped)
    python -m src.explain --checkpoint <path> --split val --n 20
| """ |
|
|
| import argparse |
| import logging |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from pytorch_grad_cam import GradCAM |
| from pytorch_grad_cam.utils.image import show_cam_on_image |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
| from torchvision import transforms |
|
|
| from src.config import IMAGE_SIZE, IMAGENET_MEAN, IMAGENET_STD, LABELS |
| from src.dataset import BDDMultiLabelDataset, get_transforms |
| from src.evaluate import load_thresholds |
| from src.model import build_model |
| from src.utils import get_device |
|
|
# Root-logger configuration used for the INFO progress lines this module logs.
# NOTE(review): this runs at import time, so any importer of src.explain gets
# its root logger configured too — consider moving it under __main__.
logging.basicConfig(format="%(levelname)s %(message)s", level=logging.INFO)
log = logging.getLogger(__name__)


# Target length for the shorter image side before centre-cropping: 10% larger
# than the crop so the crop trims a margin rather than the whole frame edge.
_RESIZE_SHORT_SIDE = int(IMAGE_SIZE * 1.1)

# Deterministic eval-time preprocessing applied to every model input:
# resize (shorter side) -> centre crop to IMAGE_SIZE x IMAGE_SIZE ->
# float tensor in [0, 1] -> ImageNet mean/std normalisation.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(_RESIZE_SHORT_SIDE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)
|
|
|
|
def _to_tensor(img_pil: Image.Image) -> torch.Tensor:
    """
    Convert a PIL image into a normalised, single-image model input batch.

    The image is forced to RGB, passed through ``_PREPROCESS`` and given a
    leading batch axis, giving shape (1, 3, IMAGE_SIZE, IMAGE_SIZE).
    """
    rgb = img_pil.convert("RGB")
    chw = _PREPROCESS(rgb)
    # Indexing with None inserts the batch dimension (same as unsqueeze(0)).
    return chw[None]
|
|
|
|
def _to_rgb_array(img_pil: Image.Image) -> np.ndarray:
    """
    PIL image -> float32 (IMAGE_SIZE, IMAGE_SIZE, 3) array in [0, 1] for
    show_cam_on_image.

    The image goes through the same spatial steps as ``_PREPROCESS``
    (shorter-side resize + centre crop), so the base image shows exactly the
    region the model saw. The Grad-CAM map is computed on that crop, and
    overlaying it on a plain square resize of the full frame would misalign
    the heatmap. The crop drops the border, and for non-square inputs such as
    16:9 dashcam frames it also drops much of the longer side.
    """
    # Reuse _PREPROCESS's geometric transforms (everything except the
    # tensor conversion / normalisation) so the two views cannot drift apart.
    spatial = transforms.Compose([
        step for step in _PREPROCESS.transforms
        if not isinstance(step, (transforms.ToTensor, transforms.Normalize))
    ])
    img = spatial(img_pil.convert("RGB"))
    return np.float32(np.array(img)) / 255.0
|
|
|
|
class GradCAMExplainer:
    """
    Wraps pytorch-grad-cam for multi-label EfficientNet-B0.

    Each label gets its own heatmap, driven by that label's output logit
    (``ClassifierOutputTarget(label_idx)``) and taken at the last feature
    block ``model.features[-1]``.

    Usage:
        explainer = GradCAMExplainer(model, device)
        overlay = explainer.explain(img_pil, "rainy")             # PIL image
        all_overlays = explainer.explain_predicted(img_pil, thresholds)
    """

    def __init__(self, model: torch.nn.Module, device: torch.device | None = None):
        """
        Args:
            model:  Classifier exposing an indexable ``features`` attribute;
                    its last element is the Grad-CAM target layer.
            device: Device to run on; defaults to ``get_device()``.
        """
        self.device = device or get_device()
        self.model = model.to(self.device).eval()

        target_layers = [self.model.features[-1]]
        self.cam = GradCAM(model=self.model, target_layers=target_layers)

    def _overlay(self, input_tensor: torch.Tensor, rgb_array: np.ndarray,
                 label_idx: int) -> Image.Image:
        """Run Grad-CAM for one label on already-preprocessed inputs."""
        targets = [ClassifierOutputTarget(label_idx)]
        grayscale_cam = self.cam(input_tensor=input_tensor, targets=targets)
        # One map per batch item; the batch holds a single image.
        overlay = show_cam_on_image(rgb_array, grayscale_cam[0], use_rgb=True)
        return Image.fromarray(overlay)

    def explain(self, img_pil: Image.Image, label_name: str) -> Image.Image:
        """
        Generate a Grad-CAM overlay for a single label.

        Args:
            img_pil:    Input PIL image (any size; will be resized internally)
            label_name: One of the strings in LABELS

        Returns:
            PIL image with the heatmap overlaid on the RGB view returned by
            ``_to_rgb_array``.

        Raises:
            ValueError: If ``label_name`` is not in LABELS.
        """
        if label_name not in LABELS:
            raise ValueError(f"Unknown label '{label_name}'. Must be one of: {LABELS}")

        label_idx = LABELS.index(label_name)
        input_tensor = _to_tensor(img_pil).to(self.device)
        return self._overlay(input_tensor, _to_rgb_array(img_pil), label_idx)

    @torch.no_grad()
    def _raw_probs(self, input_tensor: torch.Tensor) -> list[float]:
        """Unrounded post-sigmoid probabilities, in LABELS order."""
        logits = self.model(input_tensor)
        # Take batch row 0 instead of squeeze(): squeeze() would also drop the
        # label axis when there is only one label, and tolist() would then
        # return a bare float that zip() cannot iterate.
        return torch.sigmoid(logits)[0].cpu().tolist()

    def get_probs(self, img_pil: Image.Image) -> dict[str, float]:
        """Return post-sigmoid probabilities for all labels, rounded to 4 dp."""
        input_tensor = _to_tensor(img_pil).to(self.device)
        probs = self._raw_probs(input_tensor)
        return {label: round(p, 4) for label, p in zip(LABELS, probs)}

    def explain_predicted(self, img_pil: Image.Image,
                          thresholds: dict[str, float] | None = None
                          ) -> dict[str, Image.Image]:
        """
        Run inference, then generate Grad-CAM for every label whose
        probability is >= its threshold (0.5 for labels missing from
        ``thresholds``).

        Args:
            img_pil:    Input PIL image.
            thresholds: {label_name: threshold}; loaded via
                        ``load_thresholds()`` when None.

        Returns:
            {label_name: overlay_PIL_image}, in LABELS order.
        """
        if thresholds is None:
            thresholds = load_thresholds()

        # Preprocess once and share across all per-label CAM passes.
        input_tensor = _to_tensor(img_pil).to(self.device)
        rgb_array = _to_rgb_array(img_pil)

        # Compare *unrounded* probabilities with the thresholds; rounding first
        # can push a probability just below a threshold up onto it.
        probs = self._raw_probs(input_tensor)
        return {
            label: self._overlay(input_tensor, rgb_array, label_idx)
            for label_idx, (label, p) in enumerate(zip(LABELS, probs))
            if p >= thresholds.get(label, 0.5)
        }
|
|
|
|
| |
| |
| |
|
|
def _run_sanity_check(checkpoint: str, split: str, n: int) -> None:
    """
    Save up to ``n`` Grad-CAM overlays for randomly sampled images from ``split``.

    Each PNG places the RGB view the overlay is drawn on (left) next to the
    Grad-CAM overlay for the image's highest-probability predicted label
    (right). Sampled images with no label above threshold are skipped, so
    fewer than ``n`` files may be written. Used to visually verify that
    heatmaps look sensible before the API uses them.

    Args:
        checkpoint: Path to a state_dict file loadable with ``torch.load``.
        split:      Dataset split passed to ``BDDMultiLabelDataset``.
        n:          Maximum number of images to sample.

    Raises:
        ValueError: If ``n`` is negative.
    """
    import random

    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    device = get_device()
    model = build_model().to(device)
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (or pass weights_only=True on torch versions supporting it).
    model.load_state_dict(torch.load(checkpoint, map_location=device))

    explainer = GradCAMExplainer(model, device)
    thresholds = load_thresholds()

    ds = BDDMultiLabelDataset(split)
    indices = random.sample(range(len(ds)), min(n, len(ds)))

    out_dir = Path("experiments/gradcam_samples")
    out_dir.mkdir(parents=True, exist_ok=True)

    n_saved = 0
    for rank, idx in enumerate(indices):
        row = ds.df.iloc[idx]
        # The context manager closes the file; convert() returns a detached copy.
        with Image.open(row["image_path"]) as raw:
            img_pil = raw.convert("RGB")
        probs = explainer.get_probs(img_pil)
        predicted = [label for label, p in probs.items()
                     if p >= thresholds.get(label, 0.5)]

        if not predicted:
            log.info("Sample %d: no labels above threshold, skipping", idx)
            continue

        top_label = max(predicted, key=lambda label: probs[label])
        overlay = explainer.explain(img_pil, top_label)

        # Left panel: the exact RGB view the overlay is drawn on, so both panels
        # share framing. The [0, 1] float -> uint8 round trip is exact after rint.
        base = Image.fromarray(
            np.rint(_to_rgb_array(img_pil) * 255.0).astype(np.uint8))

        combined = Image.new("RGB", (IMAGE_SIZE * 2 + 4, IMAGE_SIZE), color=(40, 40, 40))
        combined.paste(base, (0, 0))
        combined.paste(overlay, (IMAGE_SIZE + 4, 0))

        fname = out_dir / f"sample_{rank:03d}_{top_label}.png"
        combined.save(fname)
        n_saved += 1
        log.info("Saved %s | predicted: %s", fname.name,
                 ", ".join(f"{label}={probs[label]:.2f}" for label in predicted))

    # Report files actually written, not the number of sampled indices.
    log.info("Saved %d Grad-CAM samples to %s", n_saved, out_dir)
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: sample images from one split and write Grad-CAM
    # overlays to disk for visual inspection.
    cli = argparse.ArgumentParser(description="Grad-CAM sanity check")
    cli.add_argument("--checkpoint", required=True)
    cli.add_argument("--split", choices=["train", "val", "test"], default="val")
    cli.add_argument(
        "--n",
        default=20,
        type=int,
        help="Number of samples to visualise",
    )
    opts = cli.parse_args()
    _run_sanity_check(checkpoint=opts.checkpoint, split=opts.split, n=opts.n)
|
|