import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import cv2
import requests
import gradio as gr
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt
# Pick the compute device once: CUDA when torch reports it as available,
# CPU otherwise. Every tensor and the model are moved to this device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ResNet-18 initialised from torchvision's DEFAULT weight set, moved to
# `device` and put in eval mode.
weights = torchvision.models.ResNet18_Weights.DEFAULT
model = torchvision.models.resnet18(weights=weights).to(device).eval()
# Class-index -> label table, downloaded once at import time.
# The rest of this file looks entries up as ``labels[str(class_index)][1]``,
# i.e. the JSON object is keyed by the stringified class index and element 1
# of each entry is used as the display label.
labels_url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
# A timeout keeps a stalled connection from hanging the import forever, and
# raise_for_status() fails fast on an HTTP error instead of trying to parse an
# error page as the label table.
_labels_response = requests.get(labels_url, timeout=30)
_labels_response.raise_for_status()
labels = _labels_response.json()
# Preprocessing applied to every image before it is passed to the model:
# resize to 224x224, convert to a tensor, then normalise each of the three
# channels with the mean / std constants given below.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        ),
    ]
)
def get_top3_predictions(img):
    """Return the three most probable classes for ``img`` as display text.

    The image is converted to RGB, run through the module-level ``transform``
    and ``model`` without gradient tracking, and the logits of the single
    batch element are turned into probabilities with a softmax.

    Args:
        img: Image object exposing ``convert`` (presumably a PIL image, since
            the Gradio input in this file is declared with ``type="pil"``).

    Returns:
        A newline-joined string with one ``"<label>: <probability>"`` entry
        per class, in the order returned by ``torch.topk``, with each
        probability formatted to four decimal places.
    """
    rgb_image = img.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    scores = torch.nn.functional.softmax(logits[0], dim=0)
    best_scores, best_ids = torch.topk(scores, 3)

    return "\n".join(
        f"{labels[str(class_id.item())][1]}: {score.item():.4f}"
        for score, class_id in zip(best_scores, best_ids)
    )
def generate_gradcam(img):
    """Build a Grad-CAM overlay for the model's top predicted class.

    Temporary forward and backward hooks on ``model.layer4[1].conv2`` capture
    that layer's output and the gradient of the top class score with respect
    to it. The gradients are averaged over the two spatial dimensions to get
    one weight per channel, the weighted activations are summed over channels,
    passed through ReLU and min-max normalised to [0, 1]. The resulting map is
    resized to the input image size, coloured with OpenCV's JET colormap and
    blended onto the image (60% image, 40% heatmap).

    Args:
        img: Image object exposing ``convert`` and ``size`` (presumably a PIL
            image, since the Gradio input in this file uses ``type="pil"``).

    Returns:
        A ``(overlay, pred_label)`` tuple: the blended overlay as a PIL image
        and the label string of the predicted class.
    """
    img = img.convert("RGB")
    input_tensor = transform(img).unsqueeze(0).to(device)

    gradients = []
    activations = []

    def forward_hook(module, inputs, output):
        activations.append(output)

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0])

    target_layer = model.layer4[1].conv2
    handles = []
    try:
        handles.append(target_layer.register_forward_hook(forward_hook))
        handles.append(target_layer.register_full_backward_hook(backward_hook))

        # Gradients are required here, so this pass is NOT wrapped in no_grad.
        output = model(input_tensor)
        pred_class = output.argmax()
        model.zero_grad()
        output[0, pred_class].backward()
    finally:
        # The model is a shared module-level object: always detach the hooks,
        # even when the forward/backward pass raises, so failed requests do
        # not leave hooks behind that would fire on every later call.
        for handle in handles:
            handle.remove()

    grads = gradients[0]
    acts = activations[0].detach()

    # One weight per channel: the gradient averaged over both spatial dims.
    weights_cam = grads.mean(dim=[2, 3], keepdim=True)
    cam = (weights_cam * acts).sum(dim=1)
    cam = torch.relu(cam)
    cam = cam.squeeze().detach().cpu().numpy()
    # Min-max normalise; the epsilon guards against a constant (all-equal) map.
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

    # cv2.resize takes (width, height), which is the order of PIL's img.size.
    heatmap = cv2.resize(cam, (img.size[0], img.size[1]))
    heatmap_uint8 = np.uint8(255 * heatmap)
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
    # applyColorMap yields BGR; convert so it matches the RGB image array.
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

    img_np = np.array(img)
    overlay = cv2.addWeighted(img_np, 0.6, heatmap_color, 0.4, 0)

    pred_label = labels[str(pred_class.item())][1]
    return Image.fromarray(overlay), pred_label
def gradcam_app(image):
    """Gradio callback producing the three outputs shown in the UI.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(overlay, predicted_label, top3_text)`` tuple. When ``image`` is
        ``None`` the tuple is ``(None, "No image uploaded.", "")``.
    """
    if image is not None:
        heatmap_overlay, top_label = generate_gradcam(image)
        top3_text = get_top3_predictions(image)
        return heatmap_overlay, top_label, top3_text

    return None, "No image uploaded.", ""
# Gradio UI: one image input (handed to the callback as a PIL image) and three
# outputs, listed in the same order as the tuple returned by gradcam_app.
demo = gr.Interface(
    title="Vision Model Interpretability with Grad-CAM",
    description="Upload an image to see a ResNet-18 prediction, top-3 classes, and a Grad-CAM heatmap.",
    fn=gradcam_app,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM Overlay"),
        gr.Textbox(label="Predicted Class"),
        gr.Textbox(label="Top-3 Predictions"),
    ],
)
# Start the Gradio server only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo.launch()