Shivdutta commited on
Commit
fe72ddb
·
verified ·
1 Parent(s): 186a065

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -38
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  from torchvision import transforms
3
  import numpy as np
4
  import gradio as gr
@@ -11,70 +12,81 @@ inv_normalize = transforms.Normalize(
11
  mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
12
  std=[1/0.23, 1/0.23, 1/0.23]
13
  )
 
14
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
15
  'dog', 'frog', 'horse', 'ship', 'truck')
16
 
17
  model = LITResNet(classes)
18
  model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
19
 
20
- def inference(input_img, gradcam=False, num_gradcam=1, transparency=0.5, target_layer_number=-1,
21
- misclassified=False, num_misclassified=1, top_classes=3):
22
- input_img = np.array(input_img)
23
  org_img = input_img
24
- input_img = Image.fromarray(input_img)
25
- transform = transforms.Compose([
26
- transforms.Resize((32, 32)),
27
- transforms.ToTensor(),
28
- ])
29
- input_img = transform(input_img)
30
- input_img = inv_normalize(input_img)
31
- input_img = input_img.unsqueeze(0)
32
  outputs = model(input_img)
33
- softmax = torch.nn.Softmax(dim=1)
34
- o = softmax(outputs)
35
- confidences = {classes[i]: float(o[0][i]) for i in range(10)}
36
- _, prediction = torch.max(outputs, 1)
37
- top_indices = torch.topk(outputs, top_classes).indices.squeeze(0).tolist()
38
 
39
- results = [classes[idx] for idx in top_indices]
 
40
 
41
- if gradcam:
42
- target_layers = [model.layer2[target_layer_number]]
43
  cam = GradCAM(model=model, target_layers=target_layers)
44
- grayscale_cams = cam(input_tensor=input_img, targets=None)
45
- visualizations = []
46
- for i in range(num_gradcam):
47
- grayscale_cam = grayscale_cams[i, :]
48
- visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
49
- visualizations.append(visualization)
50
  else:
51
- visualizations = None
52
 
53
- return results, visualizations, confidences
 
 
 
 
 
 
 
54
 
55
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
56
- description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
57
- examples = [["cat.jpg", False, 1, 0.5, -1, False, 1, 3]]
 
 
 
 
 
 
58
 
59
  demo = gr.Interface(
60
  inference,
61
  inputs=[
62
  gr.Image(width=256, height=256, label="Input Image"),
63
- "checkbox",
64
- gr.Number(value=1, label="Number of GradCAM Images"),
 
65
  gr.Slider(0, 1, value=0.5, label="Overall Opacity of Image"),
66
- gr.Slider(-1, -2, value=-1, step=1, label="Which Layer?"),
67
- "checkbox",
68
- gr.Number(value=1, label="Number of Misclassified Images"),
69
- gr.Number(value=3, minimum=1, maximum=10, label="Top Classes to Show")
70
  ],
71
  outputs=[
72
- gr.Label(num_top_classes=3),
73
- gr.Image(width=256, height=256, label="GradCAM Image", type="numpy"),
74
- gr.Label(num_top_classes=3)
 
 
 
75
  ],
76
  title=title,
77
  description=description,
78
  examples=examples,
79
  )
 
80
  demo.launch()
 
1
  import torch
2
+ import torchvision
3
  from torchvision import transforms
4
  import numpy as np
5
  import gradio as gr
 
12
  mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
13
  std=[1/0.23, 1/0.23, 1/0.23]
14
  )
15
+
16
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
17
  'dog', 'frog', 'horse', 'ship', 'truck')
18
 
19
  model = LITResNet(classes)
20
  model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
21
 
22
+ def inference(input_img, show_gradcam, num_gradcam, layer_num, opacity, show_misclassified, num_misclassified, num_top_classes):
23
+ input_img = np.array(Image.fromarray(np.array(input_img)).resize((32,32)))
 
24
  org_img = input_img
25
+
26
+ transform = transforms.ToTensor()
27
+ input_img = transform(input_img).unsqueeze(0)
28
+
 
 
 
 
29
  outputs = model(input_img)
30
+ softmax = torch.nn.Softmax(dim=0)
31
+ o = softmax(outputs.flatten())
32
+ confidences = {classes[i]: float(o[i]) for i in range(10)}
 
 
33
 
34
+ _, prediction = torch.max(outputs, 1)
35
+ is_misclassified = (prediction != labels.index(input_img_label))
36
 
37
+ if show_gradcam:
38
+ target_layers = [model.layer2[layer_num]]
39
  cam = GradCAM(model=model, target_layers=target_layers)
40
+ grayscale_cam = cam(input_tensor=input_img, targets=None)
41
+ grayscale_cam = grayscale_cam[0, :]
42
+ img = input_img.squeeze(0)
43
+ img = inv_normalize(img)
44
+ visualization = [show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=opacity) for _ in range(num_gradcam)]
 
45
  else:
46
+ visualization = []
47
 
48
+ if show_misclassified:
49
+ misclassified_imgs = [input_img for _ in range(num_misclassified)]
50
+ else:
51
+ misclassified_imgs = []
52
+
53
+ sorted_confidences = dict(sorted(confidences.items(), key=lambda item: item[1], reverse=True)[:num_top_classes])
54
+
55
+ return prediction[0].item(), classes[prediction[0].item()], is_misclassified, sorted_confidences, visualization, misclassified_imgs
56
 
57
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
58
+ description = "A simple Gradio interface to infer on ResNet model, get GradCAM results, and view misclassified images"
59
+
60
+ examples = [
61
+ ["plane.jpg", True, 1, -1, 0.5, False, 0, 3],
62
+ ["car.jpg", True, 2, -2, 0.7, True, 1, 5],
63
+ ["bird.jpg", False, 0, -1, 0.5, False, 0, 3],
64
+ # Add more examples as needed
65
+ ]
66
 
67
  demo = gr.Interface(
68
  inference,
69
  inputs=[
70
  gr.Image(width=256, height=256, label="Input Image"),
71
+ gr.Checkbox(value=True, label="Show GradCAM"),
72
+ gr.Slider(1, 5, value=1, step=1, label="Number of GradCAM Images"),
73
+ gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?"),
74
  gr.Slider(0, 1, value=0.5, label="Overall Opacity of Image"),
75
+ gr.Checkbox(value=False, label="Show Misclassified Images"),
76
+ gr.Slider(1, 5, value=1, step=1, label="Number of Misclassified Images"),
77
+ gr.Slider(1, 10, value=3, step=1, label="Number of Top Classes to Show")
 
78
  ],
79
  outputs=[
80
+ "text",
81
+ "text",
82
+ "text",
83
+ gr.Label(num_top_classes=10),
84
+ gr.Gallery(label="GradCAM Visualizations"),
85
+ gr.Gallery(label="Misclassified Images")
86
  ],
87
  title=title,
88
  description=description,
89
  examples=examples,
90
  )
91
+
92
  demo.launch()