"""
src/explain.py
Grad-CAM wrapper for multi-label inference.
Why Grad-CAM on the last conv block?
The last conv block (model.features[-1]) is the deepest layer that still
retains spatial information before global average pooling collapses it to
a vector. Earlier layers are too fine-grained and noisy; later layers have
no spatial dimension to show.
For multi-label, each output neuron has its own gradient path back through
the network, so we get a separate heatmap per predicted label — not a single
heatmap for the "winning" class.
Public API:
explainer = GradCAMExplainer(model)
overlay_img = explainer.explain(img_pil, label_name="rainy")
overlays = explainer.explain_predicted(img_pil, thresholds)
# CLI sanity check (saves up to --n overlays, default 20, to experiments/gradcam_samples/)
python -m src.explain --checkpoint <path> --split val --n 20
"""
import argparse
import logging
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
from src.config import IMAGE_SIZE, IMAGENET_MEAN, IMAGENET_STD, LABELS
from src.dataset import BDDMultiLabelDataset, get_transforms
from src.evaluate import load_thresholds
from src.model import build_model
from src.utils import get_device
logging.basicConfig(format="%(levelname)s %(message)s", level=logging.INFO)
log = logging.getLogger(__name__)

# Inference-time preprocessing: scale the shorter edge to 110% of IMAGE_SIZE,
# centre-crop to IMAGE_SIZE x IMAGE_SIZE, then convert to a float tensor and
# normalise with IMAGENET_MEAN / IMAGENET_STD.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(size=int(IMAGE_SIZE * 1.1)),
        transforms.CenterCrop(size=IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
def _to_tensor(img_pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalised, batched (1, 3, H, W) float tensor."""
    rgb = img_pil.convert("RGB")
    chw = _PREPROCESS(rgb)
    # Prepend a batch axis of size 1 (equivalent to .unsqueeze(0)).
    return chw[None, ...]
def _to_rgb_array(img_pil: Image.Image) -> np.ndarray:
    """PIL image → float32 (IMAGE_SIZE, IMAGE_SIZE, 3) array in [0, 1] for show_cam_on_image.

    The geometry deliberately mirrors the Resize + CenterCrop steps of
    ``_PREPROCESS``. Grad-CAM produces a heatmap for exactly the pixels the
    model was fed, so the base image it is blended onto must be that same
    crop. A plain square resize would instead squash the full frame, and the
    heatmap would be misaligned for any non-square input.
    """
    img = img_pil.convert("RGB")
    # Keep these two steps in sync with _PREPROCESS (minus ToTensor/Normalize).
    img = transforms.Resize(int(IMAGE_SIZE * 1.1))(img)
    img = transforms.CenterCrop(IMAGE_SIZE)(img)
    return np.float32(np.array(img)) / 255.0
class GradCAMExplainer:
    """
    Wraps pytorch-grad-cam for multi-label EfficientNet-B0.

    Usage:
        explainer = GradCAMExplainer(model, device)
        overlay = explainer.explain(img_pil, "rainy")          # PIL image
        all_overlays = explainer.explain_predicted(img_pil, thresholds)
    """

    def __init__(self, model: torch.nn.Module, device: torch.device | None = None):
        """
        Args:
            model: Classifier exposing a ``features`` sequence; its last
                element is used as the Grad-CAM target layer.
            device: Device to run on; defaults to ``get_device()``.
        """
        self.device = device or get_device()
        # .eval() returns the module itself, so this moves it and switches to eval mode.
        self.model = model.to(self.device).eval()
        # Target the last conv block; this is where spatial info is preserved
        target_layers = [self.model.features[-1]]
        self.cam = GradCAM(model=self.model, target_layers=target_layers)

    def explain(self, img_pil: Image.Image, label_name: str) -> Image.Image:
        """
        Generate a Grad-CAM overlay for a single label.

        Args:
            img_pil: Input PIL image (any size; preprocessed internally)
            label_name: One of the strings in LABELS

        Returns:
            RGB PIL image: the heatmap blended onto ``_to_rgb_array(img_pil)``.

        Raises:
            ValueError: if ``label_name`` is not in LABELS.
        """
        if label_name not in LABELS:
            raise ValueError(f"Unknown label '{label_name}'. Must be one of: {LABELS}")
        label_idx = LABELS.index(label_name)
        input_tensor = _to_tensor(img_pil).to(self.device)
        rgb_array = _to_rgb_array(img_pil)
        # Target the raw output neuron for this label, so each label gets its own
        # gradient path (multi-label, not arg-max over classes).
        targets = [ClassifierOutputTarget(label_idx)]
        grayscale_cam = self.cam(input_tensor=input_tensor, targets=targets)
        # grayscale_cam shape: (1, H, W) — take the first (and only) batch item
        # use_rgb=True because rgb_array is in RGB channel order.
        overlay = show_cam_on_image(rgb_array, grayscale_cam[0], use_rgb=True)
        return Image.fromarray(overlay)

    @torch.no_grad()
    def get_probs(self, img_pil: Image.Image, ndigits: int | None = 4) -> dict[str, float]:
        """
        Return post-sigmoid probabilities for all labels.

        Args:
            img_pil: Input PIL image.
            ndigits: Decimal places to round to (default 4, the historical
                behaviour). Pass ``None`` for unrounded values, e.g. when
                comparing against thresholds.

        Returns:
            ``{label_name: probability}`` in LABELS order.
        """
        input_tensor = _to_tensor(img_pil).to(self.device)
        # NOTE(review): GradCAM registered forward hooks on features[-1], so they
        # presumably also fire (and store activations) on this plain forward
        # pass — confirm they are cleared between calls in the pinned
        # pytorch-grad-cam version if get_probs is called in a hot loop.
        logits = self.model(input_tensor)
        # Take the single batch row rather than .squeeze(): with one label,
        # squeeze() yields a 0-d tensor, .tolist() returns a bare float and the
        # zip below fails. Assumes logits are (1, len(LABELS)) — TODO confirm.
        probs = torch.sigmoid(logits[0]).cpu().tolist()
        if ndigits is not None:
            probs = [round(p, ndigits) for p in probs]
        return dict(zip(LABELS, probs))

    def explain_predicted(self, img_pil: Image.Image,
                          thresholds: dict[str, float] | None = None
                          ) -> dict[str, Image.Image]:
        """
        Run inference, then generate Grad-CAM for every label whose probability
        is >= its threshold (0.5 for labels missing from ``thresholds``).

        Args:
            img_pil: Input PIL image.
            thresholds: Per-label thresholds; loaded via ``load_thresholds()``
                when None.

        Returns:
            ``{label_name: overlay_PIL_image}`` in LABELS order.
        """
        if thresholds is None:
            thresholds = load_thresholds()
        # Compare unrounded probabilities so a value just below a threshold is
        # not rounded up past it (e.g. 0.49996 -> 0.5 >= 0.5).
        probs = self.get_probs(img_pil, ndigits=None)
        return {
            label: self.explain(img_pil, label)
            for label, p in probs.items()
            if p >= thresholds.get(label, 0.5)
        }
# ---------------------------------------------------------------------------
# CLI sanity check
# ---------------------------------------------------------------------------
def _run_sanity_check(checkpoint: str, split: str, n: int) -> None:
    """
    Save up to n Grad-CAM overlays for randomly sampled images from `split`.

    Each PNG shows, side by side, the base image the heatmap is drawn on and
    the Grad-CAM overlay for the highest-confidence predicted label. Images
    with no label above threshold are skipped, so fewer than n files may be
    written. Used to visually verify that heatmaps look sensible before the
    API uses them.

    Args:
        checkpoint: Path to a state_dict loadable into ``build_model()``.
        split: Dataset split passed to ``BDDMultiLabelDataset``.
        n: Number of images to sample (clamped to [0, len(dataset)]).
    """
    import random

    device = get_device()
    model = build_model().to(device)
    # NOTE: torch.load unpickles the file — only point this at trusted
    # checkpoints (or consider weights_only=True on supported torch versions).
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    explainer = GradCAMExplainer(model, device)
    thresholds = load_thresholds()

    ds = BDDMultiLabelDataset(split)
    # Clamp so a negative --n does not make random.sample raise.
    indices = random.sample(range(len(ds)), max(0, min(n, len(ds))))
    out_dir = Path("experiments/gradcam_samples")
    out_dir.mkdir(parents=True, exist_ok=True)

    saved = 0
    for rank, idx in enumerate(indices):
        row = ds.df.iloc[idx]
        # Close the file handle promptly; convert() returns an independent image.
        with Image.open(row["image_path"]) as src:
            img_pil = src.convert("RGB")
        probs = explainer.get_probs(img_pil)
        predicted = [label for label, p in probs.items() if p >= thresholds.get(label, 0.5)]
        if not predicted:
            log.info("Sample %d: no labels above threshold, skipping", idx)
            continue

        # Overlay for the highest-confidence predicted label
        top_label = max(predicted, key=lambda label: probs[label])
        overlay = explainer.explain(img_pil, top_label)

        # Side-by-side: base image | overlay. The left panel is built from the
        # same array explain() blends the heatmap onto, so both panels line up.
        base = Image.fromarray((_to_rgb_array(img_pil) * 255).round().astype(np.uint8))
        combined = Image.new("RGB", (IMAGE_SIZE * 2 + 4, IMAGE_SIZE), color=(40, 40, 40))
        combined.paste(base, (0, 0))
        combined.paste(overlay, (IMAGE_SIZE + 4, 0))
        fname = out_dir / f"sample_{rank:03d}_{top_label}.png"
        combined.save(fname)
        saved += 1
        log.info("Saved %s | predicted: %s", fname.name,
                 ", ".join(f"{label}={probs[label]:.2f}" for label in predicted))

    # Report what was actually written, not how many images were sampled.
    log.info("Saved %d Grad-CAM samples to %s", saved, out_dir)
if __name__ == "__main__":
    # Entry point: python -m src.explain --checkpoint <path> [--split val] [--n 20]
    cli = argparse.ArgumentParser(description="Grad-CAM sanity check")
    cli.add_argument("--checkpoint", required=True)
    cli.add_argument("--split", default="val", choices=["train", "val", "test"])
    cli.add_argument(
        "--n",
        type=int,
        default=20,
        help="Number of samples to visualise",
    )
    opts = cli.parse_args()
    _run_sanity_check(checkpoint=opts.checkpoint, split=opts.split, n=opts.n)
|