"""Gradio demo: ResNet-18 ImageNet classification with a Grad-CAM overlay."""

import warnings
from io import BytesIO

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image

# NOTE: BytesIO and plt are not used below; the imports are kept so that any
# code relying on them being present in this module keeps working.

device = "cuda" if torch.cuda.is_available() else "cpu"

weights = torchvision.models.ResNet18_Weights.DEFAULT
model = torchvision.models.resnet18(weights=weights)
model.eval()
model.to(device)

labels_url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"


def _load_labels(url, timeout=10.0):
    """Return the ImageNet index -> [wordnet_id, name] mapping.

    The mapping is downloaded from *url*. If the download or JSON decoding
    fails, the category names bundled with the torchvision weights are used
    instead (with an empty wordnet id), so the app can still start offline.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as err:
        warnings.warn(
            f"Could not fetch ImageNet labels from {url} ({err}); "
            "falling back to the categories bundled with the model weights."
        )
        categories = weights.meta["categories"]
        # Same shape as the downloaded JSON: {"<idx>": [wordnet_id, name]}.
        return {str(idx): ["", name] for idx, name in enumerate(categories)}


labels = _load_labels(labels_url)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def _preprocess(rgb_img):
    """Turn an RGB PIL image into a normalized 1x3x224x224 tensor on `device`."""
    return transform(rgb_img).unsqueeze(0).to(device)


def _overlay_heatmap(rgb_img, cam):
    """Blend a [0, 1] class-activation map over *rgb_img* and return a PIL image."""
    # cv2.resize takes (width, height), which is what PIL's `size` provides.
    heatmap = cv2.resize(cam, (rgb_img.size[0], rgb_img.size[1]))
    heatmap_uint8 = np.uint8(255 * heatmap)
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR; the PIL image is RGB.
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(np.array(rgb_img), 0.6, heatmap_color, 0.4, 0)
    return Image.fromarray(blended)


def get_top3_predictions(img):
    """Return the three most probable classes for *img*.

    Args:
        img: PIL image in any mode; it is converted to RGB.

    Returns:
        A string with one ``"<label>: <probability>"`` entry per line,
        ordered from most to least probable.
    """
    input_tensor = _preprocess(img.convert("RGB"))

    with torch.no_grad():
        output = model(input_tensor)
        probs = torch.nn.functional.softmax(output[0], dim=0)

    top_probs, top_idxs = torch.topk(probs, 3)
    return "\n".join(
        f"{labels[str(idx.item())][1]}: {prob.item():.4f}"
        for prob, idx in zip(top_probs, top_idxs)
    )


def generate_gradcam(img):
    """Compute a Grad-CAM overlay for the model's top predicted class.

    Args:
        img: PIL image in any mode; it is converted to RGB.

    Returns:
        A ``(overlay, pred_label)`` tuple: the input image blended with the
        Grad-CAM heatmap (PIL image, original size) and the predicted
        class name.
    """
    img = img.convert("RGB")
    input_tensor = _preprocess(img)

    activations = []
    gradients = []

    def forward_hook(module, inputs, output):
        activations.append(output)

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0])

    target_layer = model.layer4[1].conv2
    forward_handle = target_layer.register_forward_hook(forward_hook)
    backward_handle = target_layer.register_full_backward_hook(backward_hook)
    try:
        output = model(input_tensor)
        pred_class = output.argmax()

        model.zero_grad()
        output[0, pred_class].backward()
    finally:
        # The model is a shared global: always detach the hooks, even when
        # the forward/backward pass raises, so they do not pile up and fire
        # on later calls.
        forward_handle.remove()
        backward_handle.remove()

    grads = gradients[0]
    acts = activations[0]

    # Grad-CAM: weight each channel by its spatially averaged gradient, sum
    # over channels and keep only the positive evidence.
    weights_cam = grads.mean(dim=[2, 3], keepdim=True)
    cam = torch.relu((weights_cam * acts).sum(dim=1))
    cam = cam.squeeze().detach().cpu().numpy()
    # Normalize to [0, 1]; the epsilon guards against a constant map.
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

    pred_label = labels[str(pred_class.item())][1]
    return _overlay_heatmap(img, cam), pred_label


def gradcam_app(image):
    """Gradio callback: return (overlay image, predicted class, top-3 text)."""
    if image is None:
        return None, "No image uploaded.", ""

    overlay, pred_label = generate_gradcam(image)
    top3 = get_top3_predictions(image)
    return overlay, pred_label, top3


demo = gr.Interface(
    fn=gradcam_app,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM Overlay"),
        gr.Textbox(label="Predicted Class"),
        gr.Textbox(label="Top-3 Predictions"),
    ],
    title="Vision Model Interpretability with Grad-CAM",
    description="Upload an image to see a ResNet-18 prediction, top-3 classes, and a Grad-CAM heatmap.",
)

if __name__ == "__main__":
    demo.launch()