Spaces:
Runtime error
Runtime error
Commit
·
f5c7805
1
Parent(s):
eab3f1d
added app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import gradio as gr
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from pytorch_grad_cam import GradCAM
|
| 5 |
+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
| 6 |
+
from pytorch_grad_cam.utils.image import show_cam_on_image
|
| 7 |
+
import torch
|
| 8 |
+
from torchvision import datasets, transforms
|
| 9 |
+
from model import LightningDavidNet
|
| 10 |
+
import random
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# --- Model & constants -----------------------------------------------------
# BUG FIX: `load_from_checkpoint` is a *classmethod* that RETURNS a new,
# weight-loaded model. The original called it on an instance and discarded
# the result, leaving `model` with randomly initialized weights.
model = LightningDavidNet.load_from_checkpoint('model.ckpt')
model.eval()  # inference mode: disable dropout, use running BN statistics


# CIFAR-10 class names, index-aligned with the model's output logits.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Module-level cache of misclassified-image paths, lazily filled by get_images().
images = []
|
| 22 |
+
|
| 23 |
+
def run_model(input_img, input_radio_gradcam, transparency = 0.5, target_layer = 3, input_slider_classes = 3):
    """Classify one image and optionally overlay a GradCAM heat-map.

    Args:
        input_img: HxWx3 uint8 numpy image (CIFAR-10 sized) -- assumed uint8
            because it is later divided by 255 for show_cam_on_image; confirm.
        input_radio_gradcam: "Yes"/"No" -- whether to build the GradCAM overlay.
        transparency: blend weight of the original image in the CAM overlay.
        target_layer: 1-3, selects which network layer the CAM is taken from.
        input_slider_classes: unused here; kept for interface compatibility
            with the Gradio callbacks that pass it through.

    Returns:
        (confidences, visualization): dict mapping class name -> probability,
        and either the untouched input image or the CAM overlay.
    """
    mean = [0.49139968, 0.48215827, 0.44653124]
    std = [0.24703233, 0.24348505, 0.26158768]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    original_img = input_img
    input_tensor = transform(input_img).unsqueeze(0)  # add batch dimension

    # Plain classification forward pass; no gradients needed for this part
    # (GradCAM below runs its own gradient-enabled forward).
    with torch.no_grad():
        outputs = model(input_tensor)
    probs = torch.nn.functional.softmax(outputs.flatten(), dim=0)
    confidences = {classes[i]: float(probs[i]) for i in range(10)}

    if input_radio_gradcam == "No":
        return confidences, original_img

    # Map the layer slider to a concrete module; 3 (the default) is the first
    # conv of the last residual block.  The original recomputed the default in
    # a redundant `if target_layer == 3` branch and also computed an unused
    # `prediction` via torch.max -- both removed.
    layer_by_index = {
        1: [model.l2X[0]],
        2: [model.l3X[0]],
        3: [model.r2.block1[0]],
    }
    target_layers = layer_by_index.get(target_layer, [model.r2.block1[0]])

    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    # targets=None -> CAM for the highest-scoring class.
    grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
    visualization = show_cam_on_image(original_img / 255, grayscale_cam,
                                      use_rgb=True, image_weight=transparency)

    return confidences, visualization
|
| 53 |
+
|
| 54 |
+
def inference(input_img, input_radio_gradcam, transparency = 0.5, target_layer = 3, input_slider_classes = 3, input_radio_misclassification="No", input_slider_misclassified=29):
    """Top-level Gradio callback: classify the image, optionally overlay
    GradCAM, and optionally return the misclassified-image gallery items.

    Returns:
        (confidences, visualization, gallery_items_or_None)
    """
    confidences, visualization = run_model(input_img, input_radio_gradcam, transparency, target_layer, input_slider_classes)
    if input_radio_misclassification == "Yes":
        gallery_paths = get_images()
        # The original also set `misclassified_output_box.visible = True`
        # here; mutating an attribute of an already-rendered Gradio component
        # has no effect on the UI (visibility is driven by the .change
        # handlers), so that no-op was removed.
        # int() guards against the slider value arriving as a float.
        return confidences, visualization, gallery_paths[:int(input_slider_misclassified)]
    # No gallery requested -> third output slot stays empty.
    return confidences, visualization, None
|
| 62 |
+
|
| 63 |
+
def change_gradcam_view(choice):
    """Show the GradCAM options column only when the radio says "Yes"."""
    show_options = choice == "Yes"
    return gradcam_dialog_box.update(visible=show_options)
|
| 68 |
+
|
| 69 |
+
def update_top_classes(input_img, input_slider_gradcam, transparency, target_layer_number, topk):
    # Callback for the "how many classes" slider: re-run inference and return
    # only the confidences dict (first element of inference()'s 3-tuple).
    # NOTE(review): mutating num_top_classes on an already-rendered Gradio
    # component likely has no effect on the live UI -- confirm; the usual fix
    # is returning a component update (gr.update / Label.update) instead.
    output_classes.num_top_classes=topk
    return inference(input_img, input_slider_gradcam, transparency, target_layer_number, topk)[0]
|
| 72 |
+
|
| 73 |
+
def change_missclassified_view(choice):
    """Toggle the misclassified-images options column from the radio value."""
    wants_gallery = choice == "Yes"
    return misclassified_dialog_box.update(visible=wants_gallery)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_images(count=29):
    """Return the cached misclassified-image paths, populating the shared
    module-level `images` cache on first use.

    Args:
        count: number of images available on disk (default 29 preserves the
            original hard-coded value). Paths are ordered count..1, matching
            the original descending while-loop.

    Returns:
        The (shared) list of image path strings.
    """
    if not images:  # idiomatic emptiness test; fill the cache exactly once
        images.extend(f'/content/Misclassified_images/{i}.jpg'
                      for i in range(count, 0, -1))
    return images
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def show_misclassified_images(number_of_missclassified, gradcam, transparency, target_layer):
    """Build the gallery of misclassified images, optionally with GradCAM.

    Args:
        number_of_missclassified: how many images to show in the gallery.
        gradcam: "Yes"/"No" overlay flag forwarded to inference().
        transparency, target_layer: GradCAM settings forwarded to inference().

    Returns:
        Dict of Gradio component updates: make the output box visible and
        set the gallery items.
    """
    output_gallery = []
    for image_path in get_images():
        image_array = np.asarray(Image.open(image_path))
        # BUG FIX: inference() returns (confidences, visualization, gallery).
        # The original indexed [-1] -- the gallery slot, which is None when
        # the misclassification radio defaults to "No" -- so the gallery was
        # filled with Nones.  The visualization is element [1].
        visualization = inference(image_array, gradcam, transparency, target_layer)[1]
        output_gallery.append(visualization)

    return {
        misclassified_output_box: gr.update(visible=True),
        gallery: output_gallery[:number_of_missclassified],
    }
|
| 103 |
+
|
| 104 |
+
# ---- Gradio UI layout and event wiring ------------------------------------
# NOTE(review): the scrape destroyed the original indentation; the nesting
# below is reconstructed from the with-blocks' semantics -- confirm against
# the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("# Lighting DavidNet")
    gr.Markdown("### CIFAR 10 Classifier with GradCAM with DavidNet")
    gr.Markdown("## Classification")
    with gr.Row():
        # Left column: input image, action buttons, advanced options.
        with gr.Column(scale=1):
            input_image = gr.Image(shape=(32, 32), label="Input Image")
            with gr.Row():
                clear_btn_main = gr.ClearButton()
                submit_btn_main = gr.Button("Submit")
            with gr.Accordion("Advanced options", open=False):

                input_radio_gradcam = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to overlay GradCAM output")
                # Hidden until the GradCAM radio flips to "Yes"
                # (see input_radio_gradcam.change below).
                with gr.Column(visible=False) as gradcam_dialog_box:
                    input_slider1 = gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM")
                    input_slider2 = gr.Slider(1, 3, value = 3, step=1, label="Which Layer?")
                input_slider_classes = gr.Slider(1, 10, value = 3, step=1, label="How Many Classes you want to see?")
                input_radio_misclassification = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to see misclassified images?")
                # Hidden until the misclassification radio flips to "Yes".
                with gr.Column(visible=False) as misclassified_dialog_box:
                    input_slider_misclassified = gr.Slider(1, 29, value = 29, step=1, label="Number of misclassified images to view?")

        # Right column: prediction label, output image, misclassified gallery.
        with gr.Column(scale=1):
            output_classes = gr.Label(num_top_classes=3,label="Output Labels(Default: 3)")
            output_image = gr.Image(shape=(32, 32), label="Classification Output(Default: Without GradCAM)").style(width=512, height=512)
            with gr.Column(visible=True) as misclassified_output_box:
                gallery = gr.Gallery(label="Misclassified Gallery", show_label=False, elem_id="gallery").style(columns=[5], rows=[6], object_fit="contain", height="auto")

    # Main "Submit" action: full inference with all advanced options.
    submit_btn_main.click(
        fn=inference, inputs=[
            input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes,
            input_radio_misclassification,input_slider_misclassified
        ],
        outputs=[
            output_classes,
            output_image,
            gallery
        ]
    )

    # Reset all components.
    # NOTE(review): the reset values do not all match the components'
    # declared defaults (input_slider_misclassified resets to 3, its default
    # is 29; output_classes receives the value 3) -- confirm the intended
    # alignment of the 10 values with the 10 outputs.
    clear_btn_main.click(
        lambda: [None, "No", 0.5, 3, 3,"No",3,3, None,None],
        outputs=[input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes, input_radio_misclassification,input_slider_misclassified, output_classes, output_image, gallery])
    # Live re-run when the top-k slider moves.
    input_slider_classes.change(update_top_classes, inputs=[input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes], outputs=[output_classes])
    # Show/hide the two option sub-panels from their radios.
    input_radio_gradcam.change(fn=change_gradcam_view, inputs=input_radio_gradcam, outputs=[gradcam_dialog_box])
    input_radio_misclassification.change(fn=change_missclassified_view, inputs=input_radio_misclassification, outputs=[misclassified_dialog_box])
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Examples")
            # Pre-baked example inputs; cache_examples=True runs inference()
            # for each example at startup and caches the outputs.
            gr.Examples(
                examples=[["Examples/1.jpg", "Yes", 0.5, 3, 3,"Yes",29],
                          ["Examples/2.jpg", "Yes", 0.7, 2, 5,"Yes",29],
                          ["Examples/3.jpg", "Yes", 0.9, 1, 4,"Yes",29],
                          ["Examples/4.jpg", "Yes", 0.3, 1, 7,"Yes",29],
                          ["Examples/5.jpg", "Yes", 0.7, 3, 4,"Yes",29],
                          ["Examples/6.jpg", "Yes", 0.8, 3, 6,"Yes",29],
                          ["Examples/7.jpg", "Yes", 0.9, 1, 7,"Yes",29],
                          ["Examples/8.jpg", "Yes", 0.3, 1, 3,"Yes",29],
                          ["Examples/9.jpg", "Yes", 0.4, 3, 4,"Yes",29],
                          ["Examples/10.jpg", "Yes", 0.5, 2, 5,"Yes",29]
                ],
                inputs=[input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes,
                        input_radio_misclassification,input_slider_misclassified],
                outputs=[output_classes, output_image,gallery],
                fn=inference,
                cache_examples=True,
            )
|
| 170 |
+
|
| 171 |
+
# Script entry point: start the Gradio server.
if __name__ == "__main__":
    demo.launch(debug=False)
    # demo.launch(share=True,debug = True)
|