dschechter27 commited on
Commit
519ffcb
·
verified ·
1 Parent(s): 39fd9bf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import torchvision.transforms as transforms
4
+ import numpy as np
5
+ import cv2
6
+ import requests
7
+ import gradio as gr
8
+ from io import BytesIO
9
+ from PIL import Image
10
+ import matplotlib.pyplot as plt
11
+
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ weights = torchvision.models.ResNet18_Weights.DEFAULT
15
+ model = torchvision.models.resnet18(weights=weights)
16
+ model.eval()
17
+ model.to(device)
18
+
19
+ labels_url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
20
+ labels = requests.get(labels_url).json()
21
+
22
+ transform = transforms.Compose([
23
+ transforms.Resize((224, 224)),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(
26
+ mean=[0.485, 0.456, 0.406],
27
+ std=[0.229, 0.224, 0.225]
28
+ )
29
+ ])
30
+
31
+ def get_top3_predictions(img):
32
+ img = img.convert("RGB")
33
+ input_tensor = transform(img).unsqueeze(0).to(device)
34
+
35
+ with torch.no_grad():
36
+ output = model(input_tensor)
37
+ probs = torch.nn.functional.softmax(output[0], dim=0)
38
+ top_probs, top_idxs = torch.topk(probs, 3)
39
+
40
+ results = []
41
+ for prob, idx in zip(top_probs, top_idxs):
42
+ label = labels[str(idx.item())][1]
43
+ results.append(f"{label}: {prob.item():.4f}")
44
+
45
+ return "\n".join(results)
46
+
47
+ def generate_gradcam(img):
48
+ img = img.convert("RGB")
49
+ input_tensor = transform(img).unsqueeze(0).to(device)
50
+
51
+ gradients = []
52
+ activations = []
53
+
54
+ def forward_hook(module, input, output):
55
+ activations.append(output)
56
+
57
+ def backward_hook(module, grad_input, grad_output):
58
+ gradients.append(grad_output[0])
59
+
60
+ target_layer = model.layer4[1].conv2
61
+ forward_handle = target_layer.register_forward_hook(forward_hook)
62
+ backward_handle = target_layer.register_full_backward_hook(backward_hook)
63
+
64
+ output = model(input_tensor)
65
+ pred_class = output.argmax()
66
+
67
+ model.zero_grad()
68
+ output[0, pred_class].backward()
69
+
70
+ grads = gradients[0]
71
+ acts = activations[0]
72
+
73
+ weights_cam = grads.mean(dim=[2, 3], keepdim=True)
74
+ cam = (weights_cam * acts).sum(dim=1)
75
+ cam = torch.relu(cam)
76
+
77
+ cam = cam.squeeze().detach().cpu().numpy()
78
+ cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
79
+
80
+ heatmap = cv2.resize(cam, (img.size[0], img.size[1]))
81
+ heatmap_uint8 = np.uint8(255 * heatmap)
82
+ heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
83
+ heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
84
+
85
+ img_np = np.array(img)
86
+ overlay = cv2.addWeighted(img_np, 0.6, heatmap_color, 0.4, 0)
87
+
88
+ pred_label = labels[str(pred_class.item())][1]
89
+
90
+ forward_handle.remove()
91
+ backward_handle.remove()
92
+
93
+ return Image.fromarray(overlay), pred_label
94
+
95
+ def gradcam_app(image):
96
+ if image is None:
97
+ return None, "No image uploaded.", ""
98
+
99
+ overlay, pred_label = generate_gradcam(image)
100
+ top3 = get_top3_predictions(image)
101
+
102
+ return overlay, pred_label, top3
103
+
104
+ demo = gr.Interface(
105
+ fn=gradcam_app,
106
+ inputs=gr.Image(type="pil", label="Upload an image"),
107
+ outputs=[
108
+ gr.Image(type="pil", label="Grad-CAM Overlay"),
109
+ gr.Textbox(label="Predicted Class"),
110
+ gr.Textbox(label="Top-3 Predictions")
111
+ ],
112
+ title="Vision Model Interpretability with Grad-CAM",
113
+ description="Upload an image to see a ResNet-18 prediction, top-3 classes, and a Grad-CAM heatmap."
114
+ )
115
+
116
+ if __name__ == "__main__":
117
+ demo.launch()