# Spaces: Sleeping (hosting-page status text captured with this file; not program code)
"""
GradCAM Explainer — See where the CNN looks
Course: 215 AI Safety ch8
"""
import json
import logging
import urllib.request

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
# Every model is moved to, and every input tensor is created on, the CPU.
device = torch.device("cpu")

# Registry of networks selectable in the UI, keyed by display name.
_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
MODELS = {"ResNet-50": _resnet50}

# Put each registered network into eval mode and place it on `device`.
for m in MODELS.values():
    m.eval()
    m.to(device)

# Target layers for GradCAM: display name -> name of the submodule to hook.
TARGET_LAYERS = {"ResNet-50": "layer4"}
# Per-channel statistics handed to T.Normalize.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Input pipeline: resize to 256, center-crop to 224, convert to a tensor and
# normalize each channel.
_PREPROCESS_STEPS = [
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
preprocess = T.Compose(_PREPROCESS_STEPS)
# ImageNet labels
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
LABELS_TIMEOUT_S = 5  # upper bound on socket waits so startup cannot hang forever
NUM_CLASSES = 1000


def load_labels(url=LABELS_URL, timeout=LABELS_TIMEOUT_S, num_classes=NUM_CLASSES):
    """Return the class-name list fetched from *url*, or a numeric fallback.

    The download is best-effort: on any failure (network error, timeout,
    invalid JSON, or a payload that is not a list of exactly *num_classes*
    entries) a warning is logged and the labels "0" .. str(num_classes - 1)
    are returned instead, so the app can still start.

    Args:
        url: Location of a JSON array of label strings.
        timeout: Seconds passed to ``urlopen`` as its socket timeout.
        num_classes: Number of labels the payload must contain.

    Returns:
        A list of ``num_classes`` strings.
    """
    logger = logging.getLogger(__name__)
    fallback = [str(i) for i in range(num_classes)]
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            labels = json.loads(resp.read().decode("utf-8"))
    except Exception as exc:
        # Deliberately broad: this is a startup boundary and label download
        # must never prevent the app from launching. The failure is logged
        # instead of being swallowed silently.
        logger.warning("Could not load labels from %s (%s); using numeric labels.", url, exc)
        return fallback

    # A payload of the wrong shape would later cause index errors or
    # mislabelled predictions, so reject it up front.
    if not isinstance(labels, list) or len(labels) != num_classes:
        logger.warning(
            "Label file at %s is not a list of %d entries; using numeric labels.",
            url,
            num_classes,
        )
        return fallback
    return [str(label) for label in labels]


LABELS = load_labels()
# ---------------------------------------------------------------------------
# GradCAM implementation
# ---------------------------------------------------------------------------
class GradCAM:
    """Compute Grad-CAM heatmaps for one named layer of a model.

    A forward hook stores the target layer's output and a full backward hook
    stores the gradient flowing into that output. ``generate`` combines the
    two into a class-activation map.

    Args:
        model: The network to explain.
        target_layer_name: Key of the layer in ``model.named_modules()``.

    Raises:
        KeyError: If ``target_layer_name`` is not a module of ``model``.
    """

    def __init__(self, model, target_layer_name):
        self.model = model
        self.gradients = None
        self.activations = None

        modules = dict(model.named_modules())
        if target_layer_name not in modules:
            raise KeyError(f"model has no module named {target_layer_name!r}")
        target_layer = modules[target_layer_name]

        # Keep the handles so the hooks can be detached via remove_hooks().
        self._hook_handles = [
            target_layer.register_forward_hook(self._save_activation),
            target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def _save_activation(self, module, input, output):
        """Forward hook: remember the target layer's output."""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the layer's output."""
        self.gradients = grad_output[0].detach()

    def remove_hooks(self):
        """Detach the hooks registered by this instance from the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate(self, input_tensor, target_class=None):
        """Return ``(cam, target_class)`` for a single-image batch.

        Args:
            input_tensor: Model input; the class is selected and
                back-propagated for batch element 0 only.
            target_class: Class index to explain. When ``None`` the model's
                arg-max prediction is used.

        Returns:
            A tuple ``(cam, target_class)`` where ``cam`` is a NumPy array
            resized to the spatial size of ``input_tensor`` (its last two
            dimensions) and, unless it is all zeros, min-max scaled to
            [0, 1]; ``target_class`` is the class index that was explained.

        Raises:
            RuntimeError: If the hooks did not capture activations and
                gradients during this call.
        """
        # Drop state from earlier calls so a hook that fails to fire is
        # detected below instead of silently reusing stale tensors.
        self.activations = None
        self.gradients = None

        self.model.zero_grad()
        output = self.model(input_tensor)
        if target_class is None:
            target_class = output.argmax(1).item()

        # Back-propagate only the selected class score of batch element 0.
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot)

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM hooks did not capture activations/gradients; "
                "were the hooks removed or the target layer bypassed?"
            )

        # Channel weights are the spatially averaged gradients.
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        # Upsample to the input's own spatial size rather than a fixed 224x224.
        cam = F.interpolate(
            cam,
            size=tuple(input_tensor.shape[-2:]),
            mode="bilinear",
            align_corners=False,
        )
        cam = cam.squeeze()
        if cam.max() > 0:
            cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        # .cpu() is a no-op for CPU tensors and keeps this working elsewhere.
        return cam.cpu().numpy(), target_class
# One GradCAM explainer per registered model, keyed by the model's display
# name and hooked into that model's entry in TARGET_LAYERS.
gradcams = {
    model_name: GradCAM(network, TARGET_LAYERS[model_name])
    for model_name, network in MODELS.items()
}
def get_top5(logits, k=5, labels=None):
    """Return the top-*k* classes of the first batch row as ``{label: prob}``.

    Args:
        logits: Tensor of raw scores with classes along dimension 1; only
            row 0 is used.
        k: Maximum number of classes to return. Clamped to the number of
            classes so small heads do not make ``torch.topk`` raise.
        labels: Sequence mapping class index to name. Defaults to the
            module-level ``LABELS``.

    Returns:
        A dict ordered from most to least probable. Keys are label strings
        (the class index as a string when ``labels`` has no entry for it)
        and values are softmax probabilities as Python floats.
    """
    if labels is None:
        labels = LABELS
    probs = F.softmax(logits, dim=1)[0]
    k = max(0, min(k, probs.numel()))
    top = torch.topk(probs, k)
    # tolist() yields plain ints/floats, so indexing and the returned values
    # do not carry tensors around.
    return {
        (labels[idx] if idx < len(labels) else str(idx)): prob
        for prob, idx in zip(top.values.tolist(), top.indices.tolist())
    }
# ---------------------------------------------------------------------------
# Main function
# ---------------------------------------------------------------------------
def _resolve_target_index(target_class_name):
    """Map a user-typed class name to its index in LABELS, or None for arg-max."""
    if not target_class_name:
        return None
    if target_class_name in LABELS:
        return LABELS.index(target_class_name)
    # Fall back to a whitespace- and case-insensitive match so "Tabby Cat "
    # still selects "tabby cat" instead of silently using the top prediction.
    wanted = target_class_name.strip().casefold()
    if not wanted:
        return None
    for idx, label in enumerate(LABELS):
        if str(label).strip().casefold() == wanted:
            return idx
    return None


def _display_image(img):
    """Apply the PIL-level steps of `preprocess` (those before ToTensor) to *img*."""
    shown = img
    for step in preprocess.transforms:
        if isinstance(step, T.ToTensor):
            break
        shown = step(shown)
    return shown


def explain(image: Image.Image, model_name: str, target_class_name: str):
    """Run the selected model on *image* and build the GradCAM visualisations.

    Args:
        image: Uploaded picture, or ``None`` when nothing was uploaded.
        model_name: Key into ``MODELS`` / ``gradcams``.
        target_class_name: Label to explain; empty or unknown names fall
            back to the model's top prediction.

    Returns:
        ``(img_np, heatmap, overlay, top5)``: the displayed image, the
        colour-mapped heatmap and their 50/50 blend as RGB uint8 arrays,
        plus the ``get_top5`` dict. ``(None, None, None, {})`` when *image*
        is ``None``.
    """
    if image is None:
        return None, None, None, {}

    img = image.convert("RGB")
    inp = preprocess(img).unsqueeze(0).to(device)
    model = MODELS[model_name]
    gradcam = gradcams[model_name]

    # Forward pass for top-5 (gradients are not needed here).
    with torch.no_grad():
        logits = model(inp)
    top5 = get_top5(logits)

    # Determine target class (None means "use arg-max") and generate GradCAM.
    target_idx = _resolve_target_index(target_class_name)
    cam, _ = gradcam.generate(inp, target_idx)

    # The displayed image must go through the same resize/crop as the model
    # input; squashing the whole picture to 224x224 would misalign the heatmap.
    img_np = np.array(_display_image(img))

    # Heatmap: OpenCV colour maps are BGR, the UI expects RGB.
    heatmap = cv2.applyColorMap((cam * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay: equal-weight blend of picture and heatmap.
    overlay = (img_np * 0.5 + heatmap * 0.5).astype(np.uint8)
    return img_np, heatmap, overlay, top5
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Markdown shown at the top of the page.
_HEADER_MARKDOWN = (
    "# GradCAM Explainer\n"
    "Upload an image to visualize which regions a CNN focuses on for its prediction.\n"
    "*Course: 215 AI Safety — Explainability*"
)

# Example rows: image path, model name, target class.
_EXAMPLE_ROWS = [
    ["examples/cat.jpg", "ResNet-50", ""],
    ["examples/dog.jpg", "ResNet-50", ""],
]

with gr.Blocks(title="GradCAM Explainer") as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    with gr.Row():
        # Left column: inputs and the run button.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            model_choice = gr.Dropdown(list(MODELS), value="ResNet-50", label="Model")
            target_class = gr.Textbox(
                label="Target Class (optional)",
                placeholder="Leave empty for top prediction",
            )
            run_btn = gr.Button("Generate GradCAM", variant="primary")

        # Right column: the three images side by side, predictions below.
        with gr.Column(scale=2):
            with gr.Row():
                orig_out = gr.Image(label="Original (224x224)")
                heat_out = gr.Image(label="GradCAM Heatmap")
                over_out = gr.Image(label="Overlay")
            top5_out = gr.Label(num_top_classes=5, label="Top-5 Predictions")

    # The button and the example rows feed the same three input components.
    _explain_inputs = [input_image, model_choice, target_class]
    _explain_outputs = [orig_out, heat_out, over_out, top5_out]

    run_btn.click(fn=explain, inputs=_explain_inputs, outputs=_explain_outputs)
    gr.Examples(examples=_EXAMPLE_ROWS, inputs=_explain_inputs)

# Start the app only when run as a script.
if __name__ == "__main__":
    demo.launch()