# Hugging Face Spaces deploy metadata: jeffliulab — Initial deploy (commit feac21c, verified)
"""
GradCAM Explainer — See where the CNN looks
Course: 215 AI Safety ch8
"""
import json
import urllib.request
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
import gradio as gr
from PIL import Image
# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
# CPU-only: keeps the demo runnable on GPU-less hosts (e.g. free Spaces tier).
device = torch.device("cpu")

# Display name -> pretrained backbone; add entries here to offer more models.
MODELS = {
    "ResNet-50": models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1),
}
for m in MODELS.values():
    m.eval().to(device)  # inference mode: freezes batch-norm / disables dropout

# Module name (as reported by model.named_modules()) whose activations and
# gradients GradCAM hooks; for ResNets the last conv stage works best.
TARGET_LAYERS = {
    "ResNet-50": "layer4",
}

# Standard ImageNet preprocessing (matches the weights' training recipe).
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Human-readable ImageNet class names; fall back to numeric strings offline.
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
try:
    # timeout prevents app startup from hanging forever on a dead network
    # (urlopen's default is to block indefinitely).
    with urllib.request.urlopen(LABELS_URL, timeout=10) as resp:
        LABELS = json.loads(resp.read().decode())
except Exception:
    LABELS = [str(i) for i in range(1000)]
# ---------------------------------------------------------------------------
# GradCAM implementation
# ---------------------------------------------------------------------------
class GradCAM:
    """Grad-CAM saliency maps (Selvaraju et al., 2017).

    Hooks a target layer to capture its forward activations and backward
    gradients, then weights the activation channels by their spatially
    averaged gradients to localize class-discriminative image regions.
    """

    def __init__(self, model, target_layer_name):
        """Attach forward/backward hooks to ``target_layer_name`` of ``model``.

        Raises:
            KeyError: if ``target_layer_name`` is not a module of ``model``.
        """
        self.model = model
        self.gradients = None    # d(score)/d(activations), set by backward hook
        self.activations = None  # target layer output, set by forward hook
        target_layer = dict(model.named_modules())[target_layer_name]
        # Keep the handles so the hooks can be detached later; the original
        # code dropped them, leaking the hooks for the model's lifetime.
        self._handles = [
            target_layer.register_forward_hook(self._save_activation),
            target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def _save_activation(self, module, input, output):
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def remove_hooks(self):
        """Detach the registered hooks from the model (idempotent)."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def generate(self, input_tensor, target_class=None, output_size=(224, 224)):
        """Compute the class activation map for a (1, C, H, W) input batch.

        Args:
            input_tensor: preprocessed batch of one image.
            target_class: class index to explain; defaults to the argmax.
            output_size: (H, W) of the returned map; default matches the
                224x224 display resolution used elsewhere in this app.

        Returns:
            (cam, class_index) where cam is a float numpy array of shape
            ``output_size`` min-max normalized to [0, 1] (all zeros when
            the map has no positive response).

        Raises:
            RuntimeError: if the hooks captured nothing (e.g. the target
                layer never ran or gradients did not flow through it).
        """
        self.model.zero_grad()
        output = self.model(input_tensor)
        if target_class is None:
            target_class = output.argmax(1).item()
        # Backprop a one-hot vector so gradients reflect only the target class.
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot)
        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM hooks captured no data; check the target layer name."
            )
        # Channel weights = global-average-pooled gradients (GradCAM eq. 1).
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)  # keep only regions that positively support the class
        cam = F.interpolate(cam, size=output_size, mode="bilinear", align_corners=False)
        cam = cam.squeeze()
        if cam.max() > 0:
            cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam.numpy(), target_class
# Build GradCAM instances
# One explainer per model, each hooked into that model's designated target layer.
gradcams = {name: GradCAM(m, TARGET_LAYERS[name]) for name, m in MODELS.items()}
def get_top5(logits):
    """Map the five most probable ImageNet labels to their probabilities."""
    class_probs = torch.softmax(logits, dim=1)[0]
    values, indices = torch.topk(class_probs, 5)
    return {LABELS[int(i)]: float(p) for p, i in zip(values, indices)}
# ---------------------------------------------------------------------------
# Main function
# ---------------------------------------------------------------------------
def explain(image: Image.Image, model_name: str, target_class_name: str):
    """Run GradCAM on ``image`` under ``model_name``.

    Args:
        image: uploaded PIL image, or None when the user ran with no upload.
        model_name: key into MODELS / gradcams.
        target_class_name: optional ImageNet label to explain; whitespace is
            tolerated. Unknown/empty values fall back to the top prediction.

    Returns:
        (original, heatmap, overlay, top5) — three 224x224 uint8 RGB arrays
        and a label->probability dict, or (None, None, None, {}) if no image.
    """
    if image is None:
        return None, None, None, {}
    img = image.convert("RGB")
    inp = preprocess(img).unsqueeze(0).to(device)
    model = MODELS[model_name]
    gradcam = gradcams[model_name]

    # Forward pass (no gradients needed) for the top-5 probability readout.
    with torch.no_grad():
        logits = model(inp)
    top5 = get_top5(logits)

    # Resolve the requested class. Strip whitespace first: previously a value
    # like " cat " never matched and silently fell back to the argmax.
    requested = (target_class_name or "").strip()
    if requested and requested in LABELS:
        target_idx = LABELS.index(requested)
    else:
        target_idx = None  # use argmax

    # GradCAM itself needs gradients, so it runs outside torch.no_grad().
    cam, _ = gradcam.generate(inp, target_idx)

    # Render everything at the model's 224x224 working resolution.
    display_img = img.resize((224, 224))
    img_np = np.array(display_img)

    # Colorize the [0, 1] map; OpenCV returns BGR, so convert for display.
    heatmap = cv2.applyColorMap((cam * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # 50/50 blend of image and heatmap.
    overlay = (img_np * 0.5 + heatmap * 0.5).astype(np.uint8)
    return img_np, heatmap, overlay, top5
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Assemble the Gradio interface: controls on the left, visualizations right.
with gr.Blocks(title="GradCAM Explainer") as demo:
    gr.Markdown(
        "# GradCAM Explainer\n"
        "Upload an image to visualize which regions a CNN focuses on for its prediction.\n"
        "*Course: 215 AI Safety — Explainability*"
    )

    with gr.Row():
        # Input side: image, model picker, optional class override.
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Upload Image")
            model_dd = gr.Dropdown(
                list(MODELS.keys()), value="ResNet-50", label="Model"
            )
            class_box = gr.Textbox(
                label="Target Class (optional)",
                placeholder="Leave empty for top prediction",
            )
            generate_btn = gr.Button("Generate GradCAM", variant="primary")

        # Output side: the three image views plus the probability readout.
        with gr.Column(scale=2):
            with gr.Row():
                original_view = gr.Image(label="Original (224x224)")
                heatmap_view = gr.Image(label="GradCAM Heatmap")
                overlay_view = gr.Image(label="Overlay")
            predictions_view = gr.Label(num_top_classes=5, label="Top-5 Predictions")

    generate_btn.click(
        fn=explain,
        inputs=[image_in, model_dd, class_box],
        outputs=[original_view, heatmap_view, overlay_view, predictions_view],
    )

    gr.Examples(
        examples=[
            ["examples/cat.jpg", "ResNet-50", ""],
            ["examples/dog.jpg", "ResNet-50", ""],
        ],
        inputs=[image_in, model_dd, class_box],
    )

if __name__ == "__main__":
    demo.launch()