Megatron17 commited on
Commit
f5c7805
·
1 Parent(s): eab3f1d

added app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from pytorch_grad_cam import GradCAM
5
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
6
+ from pytorch_grad_cam.utils.image import show_cam_on_image
7
+ import torch
8
+ from torchvision import datasets, transforms
9
+ from model import LightningDavidNet
10
+ import random
11
+
12
+
13
+ model = LightningDavidNet()
14
+ model.load_from_checkpoint('model.ckpt')
15
+ model.eval()
16
+
17
+
18
+ classes = ('plane', 'car', 'bird', 'cat', 'deer',
19
+ 'dog', 'frog', 'horse', 'ship', 'truck')
20
+
21
+ images = []
22
+
23
+ def run_model(input_img, input_radio_gradcam, transparency = 0.5, target_layer = 3, input_slider_classes = 3):
24
+ mean=[0.49139968, 0.48215827, 0.44653124]
25
+ std=[0.24703233, 0.24348505, 0.26158768]
26
+ transform = transforms.Compose([
27
+ transforms.ToTensor(),
28
+ transforms.Normalize(mean, std)
29
+ ])
30
+ orginal_img = input_img
31
+ input_img = transform(input_img)
32
+ input_img = input_img.unsqueeze(0)
33
+ outputs = model(input_img)
34
+ softmax = torch.nn.Softmax(dim=0)
35
+ o = softmax(outputs.flatten())
36
+ confidences = {classes[i]: float(o[i]) for i in range(10)}
37
+ if input_radio_gradcam == "No":
38
+ return confidences, orginal_img
39
+ _, prediction = torch.max(outputs, 1)
40
+ target_layers = [model.r2.block1[0]]
41
+ if target_layer == 1:
42
+ target_layers = [model.l2X[0]]
43
+ if target_layer == 2:
44
+ target_layers = [model.l3X[0]]
45
+ if target_layer == 3:
46
+ target_layers = [model.r2.block1[0]]
47
+ cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
48
+ grayscale_cam = cam(input_tensor=input_img, targets=None)
49
+ grayscale_cam = grayscale_cam[0, :]
50
+ visualization = show_cam_on_image(orginal_img/255, grayscale_cam, use_rgb=True, image_weight=transparency)
51
+
52
+ return confidences, visualization
53
+
54
+ def inference(input_img, input_radio_gradcam, transparency = 0.5, target_layer = 3, input_slider_classes = 3, input_radio_misclassification="No",input_slider_misclassified=29):
55
+ confidences, visualization = run_model(input_img, input_radio_gradcam, transparency, target_layer, input_slider_classes)
56
+ if input_radio_misclassification =="Yes":
57
+ images = get_images()
58
+ misclassified_output_box.visible = True
59
+ return confidences, visualization,images[:input_slider_misclassified]
60
+ else:
61
+ return confidences, visualization,None
62
+
63
+ def change_gradcam_view(choice):
64
+ if choice == "Yes":
65
+ return gradcam_dialog_box.update(visible=True)
66
+ else:
67
+ return gradcam_dialog_box.update(visible=False)
68
+
69
+ def update_top_classes(input_img, input_slider_gradcam, transparency, target_layer_number, topk):
70
+ output_classes.num_top_classes=topk
71
+ return inference(input_img, input_slider_gradcam, transparency, target_layer_number, topk)[0]
72
+
73
+ def change_missclassified_view(choice):
74
+ if choice == "Yes":
75
+ return misclassified_dialog_box.update(visible=True)
76
+ else:
77
+ return misclassified_dialog_box.update(visible=False)
78
+
79
+
80
+ def get_images():
81
+ counter = 29
82
+ if images == []:
83
+ while counter>0:
84
+ image_path = f'/content/Misclassified_images/{counter}.jpg'
85
+ images.append(image_path)
86
+ counter -=1
87
+ return images
88
+
89
+
90
+ def show_misclassified_images(number_of_missclassified, gradcam, transparency, target_layer):
91
+ images = get_images()
92
+ output_gallery = []
93
+ for image_path in images:
94
+ image = Image.open(image_path)
95
+ image_array = np.asarray(image)
96
+ visualization = inference(image_array, gradcam, transparency, target_layer)[-1]
97
+ output_gallery.append(visualization)
98
+
99
+ return {
100
+ misclassified_output_box: gr.update(visible=True),
101
+ gallery: output_gallery[:number_of_missclassified]
102
+ }
103
+
104
+ with gr.Blocks() as demo:
105
+ gr.Markdown("# Lighting DavidNet")
106
+ gr.Markdown("### CIFAR 10 Classifier with GradCAM with DavidNet")
107
+ gr.Markdown("## Classification")
108
+ with gr.Row():
109
+ with gr.Column(scale=1):
110
+ input_image = gr.Image(shape=(32, 32), label="Input Image")
111
+ with gr.Row():
112
+ clear_btn_main = gr.ClearButton()
113
+ submit_btn_main = gr.Button("Submit")
114
+ with gr.Accordion("Advanced options", open=False):
115
+
116
+ input_radio_gradcam = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to overlay GradCAM output")
117
+ with gr.Column(visible=False) as gradcam_dialog_box:
118
+ input_slider1 = gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM")
119
+ input_slider2 = gr.Slider(1, 3, value = 3, step=1, label="Which Layer?")
120
+ input_slider_classes = gr.Slider(1, 10, value = 3, step=1, label="How Many Classes you want to see?")
121
+ input_radio_misclassification = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to see misclassified images?")
122
+ with gr.Column(visible=False) as misclassified_dialog_box:
123
+ input_slider_misclassified = gr.Slider(1, 29, value = 29, step=1, label="Number of misclassified images to view?")
124
+
125
+ with gr.Column(scale=1):
126
+ output_classes = gr.Label(num_top_classes=3,label="Output Labels(Default: 3)")
127
+ output_image = gr.Image(shape=(32, 32), label="Classification Output(Default: Without GradCAM)").style(width=512, height=512)
128
+ with gr.Column(visible=True) as misclassified_output_box:
129
+ gallery = gr.Gallery(label="Misclassified Gallery", show_label=False, elem_id="gallery").style(columns=[5], rows=[6], object_fit="contain", height="auto")
130
+
131
+ submit_btn_main.click(
132
+ fn=inference, inputs=[
133
+ input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes,
134
+ input_radio_misclassification,input_slider_misclassified
135
+ ],
136
+ outputs=[
137
+ output_classes,
138
+ output_image,
139
+ gallery
140
+ ]
141
+ )
142
+
143
+ clear_btn_main.click(
144
+ lambda: [None, "No", 0.5, 3, 3,"No",3,3, None,None],
145
+ outputs=[input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes, input_radio_misclassification,input_slider_misclassified, output_classes, output_image, gallery])
146
+ input_slider_classes.change(update_top_classes, inputs=[input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes], outputs=[output_classes])
147
+ input_radio_gradcam.change(fn=change_gradcam_view, inputs=input_radio_gradcam, outputs=[gradcam_dialog_box])
148
+ input_radio_misclassification.change(fn=change_missclassified_view, inputs=input_radio_misclassification, outputs=[misclassified_dialog_box])
149
+ with gr.Row():
150
+ with gr.Column(scale=1):
151
+ gr.Markdown("## Examples")
152
+ gr.Examples(
153
+ examples=[["Examples/1.jpg", "Yes", 0.5, 3, 3,"Yes",29],
154
+ ["Examples/2.jpg", "Yes", 0.7, 2, 5,"Yes",29],
155
+ ["Examples/3.jpg", "Yes", 0.9, 1, 4,"Yes",29],
156
+ ["Examples/4.jpg", "Yes", 0.3, 1, 7,"Yes",29],
157
+ ["Examples/5.jpg", "Yes", 0.7, 3, 4,"Yes",29],
158
+ ["Examples/6.jpg", "Yes", 0.8, 3, 6,"Yes",29],
159
+ ["Examples/7.jpg", "Yes", 0.9, 1, 7,"Yes",29],
160
+ ["Examples/8.jpg", "Yes", 0.3, 1, 3,"Yes",29],
161
+ ["Examples/9.jpg", "Yes", 0.4, 3, 4,"Yes",29],
162
+ ["Examples/10.jpg", "Yes", 0.5, 2, 5,"Yes",29]
163
+ ],
164
+ inputs=[input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes,
165
+ input_radio_misclassification,input_slider_misclassified],
166
+ outputs=[output_classes, output_image,gallery],
167
+ fn=inference,
168
+ cache_examples=True,
169
+ )
170
+
171
+ if __name__ == "__main__":
172
+ demo.launch(debug=False)
173
+ # demo.launch(share=True,debug = True)