"""
GradCAM Explainer — See where the CNN looks
Course: 215 AI Safety ch8
"""

import json
import urllib.request
import warnings

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
import gradio as gr
from PIL import Image

# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------

device = torch.device("cpu")

MODELS = {
    "ResNet-50": models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1),
}
for m in MODELS.values():
    m.eval().to(device)

# Target layers for GradCAM (names as listed by ``model.named_modules()``).
TARGET_LAYERS = {
    "ResNet-50": "layer4",
}

# Geometric part of the preprocessing. It is kept as its own transform so the
# image shown in the UI is cropped exactly like the tensor the CAM is computed
# on; otherwise the heatmap would not line up with the displayed picture.
resize_crop = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
])

preprocess = T.Compose([
    resize_crop,
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ImageNet labels
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
NUM_CLASSES = 1000


def _load_labels(url, timeout=10.0):
    """Download ImageNet class names, falling back to index strings.

    The download is best-effort: on any failure (no network, timeout, bad
    JSON, unexpected length) a warning is emitted and ``["0", ..., "999"]``
    is returned so the app can still start.
    """
    try:
        # Without a timeout a stalled connection would block startup forever.
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            labels = json.loads(resp.read().decode("utf-8"))
        if not isinstance(labels, list) or len(labels) != NUM_CLASSES:
            raise ValueError("unexpected label file format")
    except Exception as err:  # deliberate best-effort fallback
        warnings.warn(f"Could not load ImageNet labels ({err}); using class indices.")
        return [str(i) for i in range(NUM_CLASSES)]
    return [str(label) for label in labels]


LABELS = _load_labels(LABELS_URL)

# Normalized label -> class index, for forgiving matching of user input.
# setdefault keeps the first occurrence, matching LABELS.index() semantics.
_LABEL_INDEX = {}
for _i, _label in enumerate(LABELS):
    _LABEL_INDEX.setdefault(_label.strip().casefold(), _i)

# ---------------------------------------------------------------------------
# GradCAM implementation
# ---------------------------------------------------------------------------


class GradCAM:
    """Grad-CAM for a single named layer of ``model``.

    Hooks on the target layer capture its forward activations and the
    gradient flowing back into its output. The instance stores the most
    recent activations, gradients and logits as attributes, so a single
    instance must not serve concurrent ``generate`` calls.
    """

    def __init__(self, model, target_layer_name):
        self.model = model
        self.gradients = None
        self.activations = None
        self.logits = None  # detached model output from the last generate()
        target_layer = dict(model.named_modules())[target_layer_name]
        # Keep the handles so the hooks can be detached later.
        self._handles = [
            target_layer.register_forward_hook(self._save_activation),
            target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def remove_hooks(self):
        """Detach the hooks installed by ``__init__``."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def _save_activation(self, module, input, output):
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate(self, input_tensor, target_class=None):
        """Compute a Grad-CAM map for a single-image batch.

        Args:
            input_tensor: preprocessed batch with one image (index 0 is
                used for the class score).
            target_class: class index to explain; ``None`` uses the
                top-1 prediction.

        Returns:
            ``(cam, target_class)``: ``cam`` is a NumPy array with the
            spatial size of ``input_tensor``, min-max scaled to ``[0, 1]``
            (all zeros if the ReLU'd map has no positive value), and
            ``target_class`` is the index that was explained.
        """
        self.model.zero_grad()
        output = self.model(input_tensor)
        self.logits = output.detach()  # lets callers reuse this forward pass
        if target_class is None:
            target_class = output.argmax(1).item()

        # Back-propagate only the chosen class score.
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot)
        # Parameter gradients are not needed; free them right away.
        self.model.zero_grad(set_to_none=True)

        # Channel weights = gradients averaged over the spatial dims.
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        cam = F.relu((weights * self.activations).sum(dim=1, keepdim=True))
        cam = F.interpolate(
            cam,
            size=tuple(input_tensor.shape[-2:]),
            mode="bilinear",
            align_corners=False,
        )
        cam = cam.squeeze()
        if cam.max() > 0:
            cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam.cpu().numpy(), target_class


# Build GradCAM instances
gradcams = {name: GradCAM(m, TARGET_LAYERS[name]) for name, m in MODELS.items()}


def get_top5(logits):
    """Return ``{label: probability}`` for the 5 most likely classes of image 0."""
    probs = F.softmax(logits, dim=1)[0]
    top5 = torch.topk(probs, 5)
    return {LABELS[int(idx)]: float(prob) for prob, idx in zip(top5.values, top5.indices)}


# ---------------------------------------------------------------------------
# Main function
# ---------------------------------------------------------------------------


def _resolve_target(target_class_name):
    """Map user text to a class index; ``None`` means "use the top prediction".

    Matching ignores case and surrounding whitespace. Unknown names trigger
    a UI warning instead of being silently ignored.
    """
    key = (target_class_name or "").strip().casefold()
    if not key:
        return None
    idx = _LABEL_INDEX.get(key)
    if idx is None:
        gr.Warning(
            f"Unknown class '{target_class_name.strip()}'; "
            "showing GradCAM for the top prediction instead."
        )
    return idx


def explain(image: Image.Image, model_name: str, target_class_name: str):
    """Run Grad-CAM on ``image`` and build the outputs shown in the UI.

    Args:
        image: uploaded PIL image, or ``None`` when nothing is uploaded.
        model_name: key into ``MODELS``.
        target_class_name: ImageNet label to explain; empty or unknown
            explains the top prediction.

    Returns:
        ``(input_crop, heatmap, overlay, top5)``: the 224x224 center crop
        the model saw, the JET-colored CAM, their 50/50 blend (all RGB
        uint8 arrays), and a ``{label: probability}`` dict.
    """
    if image is None:
        return None, None, None, {}

    img = image.convert("RGB")
    inp = preprocess(img).unsqueeze(0).to(device)
    gradcam = gradcams[model_name]

    target_idx = _resolve_target(target_class_name)

    # A single forward/backward pass serves both the CAM and the top-5.
    cam, _ = gradcam.generate(inp, target_idx)
    top5 = get_top5(gradcam.logits)

    # Display exactly the region the model saw so the heatmap lines up.
    img_np = np.array(resize_crop(img))

    # Heatmap (OpenCV colormaps are BGR; Gradio expects RGB).
    heatmap = cv2.applyColorMap((cam * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay
    overlay = (img_np * 0.5 + heatmap * 0.5).astype(np.uint8)

    return img_np, heatmap, overlay, top5


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

with gr.Blocks(title="GradCAM Explainer") as demo:
    gr.Markdown(
        "# GradCAM Explainer\n"
        "Upload an image to visualize which regions a CNN focuses on for its prediction.\n"
        "*Course: 215 AI Safety — Explainability*"
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            model_choice = gr.Dropdown(
                list(MODELS.keys()), value="ResNet-50", label="Model"
            )
            target_class = gr.Textbox(
                label="Target Class (optional)",
                placeholder="Leave empty for top prediction",
            )
            run_btn = gr.Button("Generate GradCAM", variant="primary")

        with gr.Column(scale=2):
            with gr.Row():
                orig_out = gr.Image(label="Original (224x224)")
                heat_out = gr.Image(label="GradCAM Heatmap")
                over_out = gr.Image(label="Overlay")
            top5_out = gr.Label(num_top_classes=5, label="Top-5 Predictions")

    run_btn.click(
        fn=explain,
        inputs=[input_image, model_choice, target_class],
        outputs=[orig_out, heat_out, over_out, top5_out],
    )

    gr.Examples(
        examples=[
            ["examples/cat.jpg", "ResNet-50", ""],
            ["examples/dog.jpg", "ResNet-50", ""],
        ],
        inputs=[input_image, model_choice, target_class],
    )

if __name__ == "__main__":
    demo.launch()