File size: 3,374 Bytes
519ffcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import cv2
import requests
import gradio as gr
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt

# Pick GPU when available; all tensors and the model are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained ImageNet ResNet-18 in inference mode (eval disables dropout/batchnorm updates).
weights = torchvision.models.ResNet18_Weights.DEFAULT
model = torchvision.models.resnet18(weights=weights)
model.eval()
model.to(device)

# ImageNet class index: maps "class_idx" (as a string) -> [wnid, human_readable_label].
# NOTE(review): fetched at import time over the network — module import fails offline.
labels_url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
labels = requests.get(labels_url).json()

# Standard ImageNet preprocessing: fixed 224x224 input, per-channel normalization
# with the dataset's published mean/std (required by the pretrained weights).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

def get_top3_predictions(img):
    """Return the model's top-3 ImageNet predictions for *img*.

    Args:
        img: a PIL image (any mode; converted to RGB internally).

    Returns:
        A newline-joined string of "label: probability" lines, best first.
    """
    rgb = img.convert("RGB")
    batch = transform(rgb).unsqueeze(0).to(device)

    # Inference only — no gradients needed for classification.
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)
        scores, indices = torch.topk(probabilities, 3)

    # labels maps the stringified class index to [wnid, readable_name].
    lines = [
        f"{labels[str(idx.item())][1]}: {score.item():.4f}"
        for score, idx in zip(scores, indices)
    ]
    return "\n".join(lines)

def generate_gradcam(img):
    """Compute a Grad-CAM heatmap overlay for the model's top prediction.

    Hooks the last conv layer of ResNet-18's layer4 to capture its activations
    (forward) and their gradients (backward), weights each channel by the
    spatially-averaged gradient, and overlays the resulting class-activation
    map on the input image.

    Args:
        img: a PIL image (any mode; converted to RGB internally).

    Returns:
        (overlay, pred_label): the blended PIL image and the predicted
        class name for the argmax class.
    """
    img = img.convert("RGB")
    input_tensor = transform(img).unsqueeze(0).to(device)

    gradients = []
    activations = []

    def forward_hook(module, input, output):
        activations.append(output)

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0])

    target_layer = model.layer4[1].conv2
    forward_handle = target_layer.register_forward_hook(forward_hook)
    backward_handle = target_layer.register_full_backward_hook(backward_hook)

    # BUG FIX: remove the hooks in a finally block. Previously they were only
    # removed at the end of the function, so any exception during the forward
    # or backward pass leaked the hooks onto the shared module-level model,
    # corrupting every subsequent call.
    try:
        output = model(input_tensor)
        pred_class = output.argmax()

        model.zero_grad()
        # Backprop only the logit of the predicted class to get its gradients
        # at the target layer.
        output[0, pred_class].backward()
    finally:
        forward_handle.remove()
        backward_handle.remove()

    grads = gradients[0]
    acts = activations[0]

    # Grad-CAM: channel weights = spatial mean of gradients; weighted sum of
    # activation maps, then ReLU to keep only positively-contributing regions.
    weights_cam = grads.mean(dim=[2, 3], keepdim=True)
    cam = (weights_cam * acts).sum(dim=1)
    cam = torch.relu(cam)

    # Normalize to [0, 1]; epsilon guards against a flat (all-zero) map.
    cam = cam.squeeze().detach().cpu().numpy()
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

    # Upsample to the original image size and colorize. applyColorMap emits
    # BGR, so convert to RGB before blending with the RGB PIL image.
    heatmap = cv2.resize(cam, (img.size[0], img.size[1]))
    heatmap_uint8 = np.uint8(255 * heatmap)
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

    img_np = np.array(img)
    overlay = cv2.addWeighted(img_np, 0.6, heatmap_color, 0.4, 0)

    pred_label = labels[str(pred_class.item())][1]

    return Image.fromarray(overlay), pred_label

def gradcam_app(image):
    """Gradio callback: run Grad-CAM and top-3 classification on an upload.

    Args:
        image: PIL image from the Gradio input, or None if nothing uploaded.

    Returns:
        (overlay_image, predicted_label, top3_text); overlay is None with a
        placeholder message when no image was provided.
    """
    # Guard clause: nothing uploaded yet.
    if image is None:
        return None, "No image uploaded.", ""

    cam_overlay, best_label = generate_gradcam(image)
    summary = get_top3_predictions(image)
    return cam_overlay, best_label, summary

# Gradio UI: one image input, three outputs (overlay image + two text boxes)
# matching the 3-tuple returned by gradcam_app.
demo = gr.Interface(
    fn=gradcam_app,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM Overlay"),
        gr.Textbox(label="Predicted Class"),
        gr.Textbox(label="Top-3 Predictions")
    ],
    title="Vision Model Interpretability with Grad-CAM",
    description="Upload an image to see a ResNet-18 prediction, top-3 classes, and a Grad-CAM heatmap."
)

# Launch the local web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()