# Commit 9466fff by kishore-9: "Add road scene classifier app"
"""
src/explain.py
Grad-CAM wrapper for multi-label inference.
Why Grad-CAM on the last conv block?
The last conv block (model.features[-1]) is the deepest layer that still
retains spatial information before global average pooling collapses it to
a vector. Earlier layers are too fine-grained and noisy; later layers have
no spatial dimension to show.
For multi-label, each output neuron has its own gradient path back through
the network, so we get a separate heatmap per predicted label — not a single
heatmap for the "winning" class.
Public API:
explainer = GradCAMExplainer(model)
overlay_img = explainer.explain(img_pil, label_name="rainy")
overlays = explainer.explain_predicted(img_pil, thresholds)
# CLI sanity check (saves 20 overlays to experiments/gradcam_samples/)
python -m src.explain --checkpoint <path> --split val --n 20
"""
import argparse
import logging
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
from src.config import IMAGE_SIZE, IMAGENET_MEAN, IMAGENET_STD, LABELS
from src.dataset import BDDMultiLabelDataset, get_transforms
from src.evaluate import load_thresholds
from src.model import build_model
from src.utils import get_device
# NOTE(review): basicConfig runs at import time, so any application importing
# this module gets its root logger configured here; consider moving it under
# the __main__ guard if that is not intended.
logging.basicConfig(format="%(levelname)s %(message)s", level=logging.INFO)
log = logging.getLogger(__name__)

# Model-input pipeline: scale the shorter side to 1.1x the crop size, take the
# central IMAGE_SIZE x IMAGE_SIZE square, then convert to a normalised tensor.
# The resize/crop pair decides which region of the image the network — and
# therefore Grad-CAM — actually sees.
# Presumably mirrors the eval transform from src.dataset.get_transforms — TODO confirm.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(int(IMAGE_SIZE * 1.1)),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)
def _to_tensor(img_pil: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalised model batch of one, shaped (1, 3, H, W)."""
    rgb_img = img_pil.convert("RGB")
    # Indexing with None prepends the batch axis (same as unsqueeze(0)).
    return _PREPROCESS(rgb_img)[None]
# Spatial part of the model-input pipeline (same Resize + CenterCrop as
# _PREPROCESS, without tensor conversion/normalisation). Must stay in sync with
# _PREPROCESS so that CAM pixels line up with the image they are drawn on.
_RGB_GEOMETRY = transforms.Compose([
    transforms.Resize(int(IMAGE_SIZE * 1.1)),
    transforms.CenterCrop(IMAGE_SIZE),
])


def _to_rgb_array(img_pil: Image.Image) -> np.ndarray:
    """
    PIL image → float32 (IMAGE_SIZE, IMAGE_SIZE, 3) array in [0, 1] for show_cam_on_image.

    The image is framed exactly as the model sees it: shorter side resized to
    1.1 x IMAGE_SIZE, then center-cropped. A plain square resize would squash the
    whole frame, so a heatmap computed on the central crop would be stretched
    over regions the model never saw.
    """
    framed = _RGB_GEOMETRY(img_pil.convert("RGB"))
    return np.asarray(framed, dtype=np.float32) / 255.0
class GradCAMExplainer:
    """
    Grad-CAM explainer for a multi-label classifier (e.g. EfficientNet-B0).

    The model must expose a ``features`` sequence; its last element is used as
    the Grad-CAM target layer. Its forward pass is expected to return one logit
    per entry in LABELS, in the same order.

    Usage:
        explainer = GradCAMExplainer(model, device)
        overlay = explainer.explain(img_pil, "rainy")      # PIL image
        all_overlays = explainer.explain_predicted(img_pil, thresholds)
    """

    def __init__(self, model: torch.nn.Module, device: torch.device | None = None):
        """
        Args:
            model: Classifier to explain. It is moved to ``device`` and put in
                eval mode in place (``Module.to`` / ``Module.eval``).
            device: Target device; falls back to ``get_device()`` when None.
        """
        self.device = device or get_device()
        self.model = model.to(self.device).eval()
        # Target the last conv block: the deepest layer that still has a
        # spatial map before global pooling flattens it.
        target_layers = [self.model.features[-1]]
        self.cam = GradCAM(model=self.model, target_layers=target_layers)

    def explain(self, img_pil: Image.Image, label_name: str) -> Image.Image:
        """
        Generate a Grad-CAM overlay for a single label.

        Args:
            img_pil: Input PIL image (any size; converted to RGB and resized internally).
            label_name: One of the strings in LABELS.

        Returns:
            RGB PIL image with the heatmap overlaid on the array produced by
            ``_to_rgb_array`` for the same input.

        Raises:
            ValueError: If ``label_name`` is not in LABELS.
        """
        if label_name not in LABELS:
            raise ValueError(f"Unknown label '{label_name}'. Must be one of: {LABELS}")
        label_idx = LABELS.index(label_name)

        input_tensor = _to_tensor(img_pil).to(self.device)
        rgb_array = _to_rgb_array(img_pil)
        # Target this label's raw model output; every label has its own gradient
        # path, so each call yields an independent heatmap.
        targets = [ClassifierOutputTarget(label_idx)]
        grayscale_cam = self.cam(input_tensor=input_tensor, targets=targets)
        # grayscale_cam shape: (1, H, W) — take the first (and only) batch item.
        overlay = show_cam_on_image(rgb_array, grayscale_cam[0], use_rgb=True)
        return Image.fromarray(overlay)

    @torch.no_grad()
    def get_probs(self, img_pil: Image.Image) -> dict[str, float]:
        """
        Return post-sigmoid probabilities for all labels, rounded to 4 decimals.

        Raises:
            ValueError: If the model emits a different number of outputs than
                there are entries in LABELS.
        """
        input_tensor = _to_tensor(img_pil).to(self.device)
        logits = self.model(input_tensor)
        # Take the single batch item rather than squeeze(): squeeze() would also
        # drop the label axis for a one-label model, and tolist() would then
        # return a bare float that zip() cannot iterate.
        probs = torch.sigmoid(logits)[0].cpu().tolist()
        # strict=True: a model/LABELS size mismatch must fail loudly instead of
        # silently truncating the mapping.
        return {label: round(p, 4) for label, p in zip(LABELS, probs, strict=True)}

    def explain_predicted(self, img_pil: Image.Image,
                          thresholds: dict[str, float] | None = None
                          ) -> dict[str, Image.Image]:
        """
        Run inference, then generate Grad-CAM for every label that reaches its threshold.

        Args:
            img_pil: Input PIL image.
            thresholds: Per-label decision thresholds; labels missing from the
                dict use 0.5. Loaded with ``load_thresholds()`` when None.

        Returns:
            {label_name: overlay PIL image} in LABELS order; empty when no label
            passes its threshold.
        """
        if thresholds is None:
            thresholds = load_thresholds()
        # NOTE: get_probs rounds to 4 decimals, so the comparison uses the
        # rounded probability.
        probs = self.get_probs(img_pil)
        return {
            label: self.explain(img_pil, label)
            for label, p in probs.items()
            if p >= thresholds.get(label, 0.5)
        }
# ---------------------------------------------------------------------------
# CLI sanity check
# ---------------------------------------------------------------------------
def _run_sanity_check(checkpoint: str, split: str, n: int) -> None:
    """
    Save up to n Grad-CAM overlays for randomly sampled images from `split`.

    Each PNG shows the image as framed for the overlay (left) next to the
    Grad-CAM overlay of its highest-probability predicted label (right).
    Images with no label above threshold are skipped, so fewer than n files
    may be written. Used to visually verify that heatmaps look sensible
    before the API uses them.

    Args:
        checkpoint: Path to a model state_dict loadable with ``torch.load``.
        split: Dataset split passed to ``BDDMultiLabelDataset``.
        n: Maximum number of images to sample.
    """
    import random

    device = get_device()
    model = build_model().to(device)
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    explainer = GradCAMExplainer(model, device)
    thresholds = load_thresholds()

    ds = BDDMultiLabelDataset(split)
    indices = random.sample(range(len(ds)), min(n, len(ds)))
    out_dir = Path("experiments/gradcam_samples")
    out_dir.mkdir(parents=True, exist_ok=True)

    saved = 0
    for rank, idx in enumerate(indices):
        row = ds.df.iloc[idx]
        # Context manager closes the underlying file; convert() loads pixels first.
        with Image.open(row["image_path"]) as src_img:
            img_pil = src_img.convert("RGB")

        probs = explainer.get_probs(img_pil)
        predicted = [lbl for lbl, p in probs.items() if p >= thresholds.get(lbl, 0.5)]
        if not predicted:
            log.info("Sample %d: no labels above threshold, skipping", idx)
            continue

        # Overlay for the highest-confidence predicted label
        top_label = max(predicted, key=probs.get)
        overlay = explainer.explain(img_pil, top_label)

        # Left panel is built from the same helper the overlay is drawn on, so
        # both panels always show the identical framing of the image.
        base = Image.fromarray(np.rint(_to_rgb_array(img_pil) * 255.0).astype(np.uint8))

        # Side-by-side: original | overlay
        combined = Image.new("RGB", (IMAGE_SIZE * 2 + 4, IMAGE_SIZE), color=(40, 40, 40))
        combined.paste(base, (0, 0))
        combined.paste(overlay, (IMAGE_SIZE + 4, 0))

        fname = out_dir / f"sample_{rank:03d}_{top_label}.png"
        combined.save(fname)
        saved += 1
        log.info("Saved %s | predicted: %s", fname.name,
                 ", ".join(f"{lbl}={probs[lbl]:.2f}" for lbl in predicted))

    # Report files actually written, not images sampled (skips are excluded).
    log.info("Saved %d Grad-CAM samples to %s", saved, out_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Grad-CAM sanity check")
parser.add_argument("--checkpoint", required=True)
parser.add_argument("--split", default="val", choices=["train", "val", "test"])
parser.add_argument("--n", type=int, default=20, help="Number of samples to visualise")
args = parser.parse_args()
_run_sanity_check(args.checkpoint, args.split, args.n)