Shivdutta commited on
Commit
2fd201f
·
verified ·
1 Parent(s): c61ddd5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -29
app.py CHANGED
@@ -19,7 +19,7 @@ classes = ('plane', 'car', 'bird', 'cat', 'deer',
19
  model = LITResNet(classes)
20
  model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
21
 
22
- def inference(input_img, input_img_label, show_gradcam, num_gradcam, layer_num, opacity, show_misclassified, num_misclassified, num_top_classes):
23
  input_img = np.array(Image.fromarray(np.array(input_img)).resize((32,32)))
24
  org_img = input_img
25
 
@@ -30,60 +30,52 @@ def inference(input_img, input_img_label, show_gradcam, num_gradcam, layer_num,
30
  softmax = torch.nn.Softmax(dim=0)
31
  o = softmax(outputs.flatten())
32
  confidences = {classes[i]: float(o[i]) for i in range(10)}
33
-
34
  _, prediction = torch.max(outputs, 1)
35
- is_misclassified = (prediction != classes.index(input_img_label))
36
 
37
- if show_gradcam:
38
- target_layers = [model.layer2[layer_num]]
39
  cam = GradCAM(model=model, target_layers=target_layers)
40
  grayscale_cam = cam(input_tensor=input_img, targets=None)
41
  grayscale_cam = grayscale_cam[0, :]
 
42
  img = input_img.squeeze(0)
43
  img = inv_normalize(img)
44
- visualization = [show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=opacity) for _ in range(num_gradcam)]
45
- else:
46
- visualization = []
47
 
48
- if show_misclassified:
49
- misclassified_imgs = [input_img for _ in range(num_misclassified)]
50
- else:
51
- misclassified_imgs = []
52
 
53
- sorted_confidences = dict(sorted(confidences.items(), key=lambda item: item[1], reverse=True)[:num_top_classes])
 
54
 
55
- return prediction[0].item(), classes[prediction[0].item()], is_misclassified, sorted_confidences, visualization, misclassified_imgs
 
 
56
 
57
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
58
- description = "A simple Gradio interface to infer on ResNet model, get GradCAM results, and view misclassified images"
59
 
60
  examples = [
61
- ["plane.jpeg", "plane", True, 1, -1, 0.5, False, 0, 3],
62
- ["car.jpeg", "car", True, 2, -2, 0.7, True, 1, 5],
63
- ["bird.jpeg", "bird", False, 0, -1, 0.5, False, 0, 3],
64
- # Add more examples as needed
65
  ]
66
 
67
  demo = gr.Interface(
68
  inference,
69
  inputs=[
70
  gr.Image(width=256, height=256, label="Input Image"),
71
- gr.Dropdown(choices=classes, label="Ground Truth Label"),
72
- gr.Checkbox(value=True, label="Show GradCAM"),
73
  gr.Slider(1, 5, value=1, step=1, label="Number of GradCAM Images"),
74
  gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?"),
75
  gr.Slider(0, 1, value=0.5, label="Overall Opacity of Image"),
76
- gr.Checkbox(value=False, label="Show Misclassified Images"),
77
- gr.Slider(1, 5, value=1, step=1, label="Number of Misclassified Images"),
78
- gr.Slider(1, 10, value=3, step=1, label="Number of Top Classes to Show")
79
  ],
80
  outputs=[
81
  "text",
82
- "text",
83
- "text",
84
- gr.Label(num_top_classes=10),
85
- gr.Gallery(label="GradCAM Visualizations"),
86
- gr.Gallery(label="Misclassified Images")
87
  ],
88
  title=title,
89
  description=description,
 
19
  model = LITResNet(classes)
20
  model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
21
 
22
+ def inference(input_img, num_gradcam_images=1, target_layer_number=-1, transparency=0.5, show_misclassified=False, num_top_classes=3):
23
  input_img = np.array(Image.fromarray(np.array(input_img)).resize((32,32)))
24
  org_img = input_img
25
 
 
30
  softmax = torch.nn.Softmax(dim=0)
31
  o = softmax(outputs.flatten())
32
  confidences = {classes[i]: float(o[i]) for i in range(10)}
 
33
  _, prediction = torch.max(outputs, 1)
 
34
 
35
+ if not show_misclassified or prediction[0].item() == np.argmax(list(confidences.values())):
36
+ target_layers = [model.layer2[target_layer_number]]
37
  cam = GradCAM(model=model, target_layers=target_layers)
38
  grayscale_cam = cam(input_tensor=input_img, targets=None)
39
  grayscale_cam = grayscale_cam[0, :]
40
+
41
  img = input_img.squeeze(0)
42
  img = inv_normalize(img)
 
 
 
43
 
44
+ visualizations = []
45
+ for _ in range(num_gradcam_images):
46
+ visualization = show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency)
47
+ visualizations.append(visualization)
48
 
49
+ top_classes = sorted(confidences.items(), key=lambda x: x[1], reverse=True)[:num_top_classes]
50
+ top_classes = [f"{cls}: {conf:.2f}" for cls, conf in top_classes]
51
 
52
+ return prediction[0].item(), visualizations, top_classes
53
+ else:
54
+ return None, None, None
55
 
56
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
57
+ description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
58
 
59
  examples = [
60
+ ["cat.jpg", 1, -1, 0.5, False, 3],
61
+ ["dog.jpg", 1, -1, 0.5, False, 3],
62
+ # Add more example images here
 
63
  ]
64
 
65
  demo = gr.Interface(
66
  inference,
67
  inputs=[
68
  gr.Image(width=256, height=256, label="Input Image"),
 
 
69
  gr.Slider(1, 5, value=1, step=1, label="Number of GradCAM Images"),
70
  gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?"),
71
  gr.Slider(0, 1, value=0.5, label="Overall Opacity of Image"),
72
+ gr.Checkbox(label="Show Misclassified Images"),
73
+ gr.Slider(1, 10, value=3, step=1, label="Number of Top Classes")
 
74
  ],
75
  outputs=[
76
  "text",
77
+ gr.Gallery(label="GradCAM Images"),
78
+ gr.Label(num_top_classes=3, label="Top Classes")
 
 
 
79
  ],
80
  title=title,
81
  description=description,