# Author: Megatron17
# Commit b1bbea6: "Update app.py"
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import torch
from torchvision import datasets, transforms
from model import LightningDavidNet
import random
# Build the classifier from its Lightning checkpoint.
# `load_from_checkpoint` is a classmethod that *returns* a new, weight-loaded
# model, so its result must be bound; calling it on an instance and discarding
# the return value leaves the network with randomly initialised weights.
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host
# (GradCAM below is also run with use_cuda=False).
# NOTE(review): assumes 'model.pt' is a Lightning checkpoint (e.g. written by
# Trainer.save_checkpoint) rather than a bare state_dict -- TODO confirm.
model = LightningDavidNet.load_from_checkpoint('model.pt', map_location='cpu')
model.eval()
# CIFAR-10 class names; position i labels the model's i-th output score.
classes = tuple('plane car bird cat deer dog frog horse ship truck'.split())
# Paths of the stored misclassified examples, filled lazily by get_images().
images = []
def run_model(input_img, input_radio_gradcam, transparency=0.5, target_layer=3, input_slider_classes=3):
    """Classify one image and optionally overlay a GradCAM heat-map on it.

    Args:
        input_img: The image delivered by ``gr.Image``. Presumably a uint8
            numpy array of shape (H, W, 3); TODO confirm for other Gradio
            settings.
        input_radio_gradcam: ``"No"`` returns the original image unchanged.
            Any other value computes a GradCAM overlay.
        transparency: Passed to ``show_cam_on_image`` as ``image_weight``,
            the weight of the *original* image in the blend. Larger values
            make the heat-map less visible.
        target_layer: 1, 2 or 3, selecting the layer GradCAM hooks. Any
            other value falls back to the layer used for 3.
        input_slider_classes: Unused here; kept for interface compatibility.

    Returns:
        ``(confidences, image)``. ``confidences`` is a
        ``{class_name: probability}`` dict. ``image`` is either the input
        image or the GradCAM visualization.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")

    mean = [0.49139968, 0.48215827, 0.44653124]
    std = [0.24703233, 0.24348505, 0.26158768]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    original_img = input_img
    input_tensor = transform(input_img).unsqueeze(0)  # prepend a batch dimension

    # Plain classification needs no autograd graph; GradCAM runs its own pass.
    with torch.no_grad():
        outputs = model(input_tensor)
    probabilities = torch.softmax(outputs.flatten(), dim=0)
    confidences = {name: float(p) for name, p in zip(classes, probabilities)}

    if input_radio_gradcam == "No":
        return confidences, original_img

    layer_by_choice = {
        1: model.l2X[0],
        2: model.l3X[0],
        3: model.r2.block1[0],
    }
    target_layers = [layer_by_choice.get(target_layer, model.r2.block1[0])]

    # The context manager releases the hooks GradCAM registers on `model`
    # as soon as we are done, instead of whenever the object is collected.
    # targets=None asks GradCAM to explain the highest-scoring class.
    with GradCAM(model=model, target_layers=target_layers, use_cuda=False) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]

    # show_cam_on_image expects the RGB image scaled to [0, 1].
    visualization = show_cam_on_image(original_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return confidences, visualization
def inference(input_img, input_radio_gradcam, transparency=0.5, target_layer=3, input_slider_classes=3, input_radio_misclassification="No", input_slider_misclassified=29):
    """Gradio handler for the Submit button and the cached examples.

    Classifies ``input_img`` via ``run_model``, optionally with a GradCAM
    overlay. When ``input_radio_misclassification == "Yes"``, it also returns
    the first ``input_slider_misclassified`` stored misclassified-image paths.

    Returns:
        ``(confidences, visualization, gallery)``. ``gallery`` is a list of
        image paths, or None when misclassified images were not requested.
    """
    confidences, visualization = run_model(input_img, input_radio_gradcam, transparency, target_layer, input_slider_classes)
    if input_radio_misclassification != "Yes":
        return confidences, visualization, None
    # Slider values may arrive as floats; list slicing requires an int.
    gallery_paths = get_images()[:int(input_slider_misclassified)]
    return confidences, visualization, gallery_paths
def change_gradcam_view(choice):
    """Toggle the GradCAM options column: visible only when *choice* is "Yes"."""
    show_options = choice == "Yes"
    return gradcam_dialog_box.update(visible=show_options)
def update_top_classes(input_img, input_slider_gradcam, transparency, target_layer_number, topk):
    """Re-render the class probabilities when the "how many classes" slider moves.

    Args:
        input_img: Current input image, or None if nothing is uploaded yet.
        input_slider_gradcam: GradCAM radio value. Not needed here, because
            only the probabilities are recomputed.
        transparency, target_layer_number: Forwarded to ``run_model`` unchanged.
        topk: Number of top classes ``output_classes`` should display.

    Returns:
        The ``{class_name: probability}`` dict for ``output_classes``, or None
        when there is no image to classify.
    """
    # NOTE(review): this mutates the shared Label component, so the top-k
    # setting is global to every session of the app, not per user.
    # int() because slider values may arrive as floats; a float would break
    # slicing the sorted predictions.
    output_classes.num_top_classes = int(topk)
    if input_img is None:
        return None
    # Only the probabilities are displayed, so skip the GradCAM pass entirely.
    confidences, _ = run_model(input_img, "No", transparency, target_layer_number, topk)
    return confidences
def change_missclassified_view(choice):
    """Toggle the misclassified-images options column: visible only for "Yes"."""
    show_options = choice == "Yes"
    return misclassified_dialog_box.update(visible=show_options)
def get_images():
    """Return the module-level list of misclassified-example image paths.

    On the first call (while ``images`` is empty), this fills it with
    ``Misclassified_images/29.jpg`` down to ``Misclassified_images/1.jpg``.
    Later calls return the same, already-populated list object.
    """
    if not images:
        images.extend(f'Misclassified_images/{number}.jpg' for number in range(29, 0, -1))
    return images
def show_misclassified_images(number_of_missclassified, gradcam, transparency, target_layer):
    """Render the stored misclassified images, optionally with GradCAM overlays.

    Args:
        number_of_missclassified: How many of the stored images to render.
        gradcam: ``"No"`` shows the raw images. Anything else overlays GradCAM.
        transparency, target_layer: Forwarded to ``inference`` / ``run_model``.

    Returns:
        A dict of Gradio updates. It makes ``misclassified_output_box``
        visible and fills ``gallery`` with the rendered images.
    """
    # Only render what will be shown; int() because slider values may be floats.
    image_paths = get_images()[:int(number_of_missclassified)]
    output_gallery = []
    for image_path in image_paths:
        # Close the file promptly. Force RGB so grayscale/RGBA files still
        # match the model's 3-channel normalization.
        with Image.open(image_path) as image:
            image_array = np.asarray(image.convert("RGB"))
        # inference() returns (confidences, visualization, gallery); index 1 is
        # the rendered image. Index -1 would be the gallery slot, which is None
        # here because misclassification mode defaults to "No".
        visualization = inference(image_array, gradcam, transparency, target_layer)[1]
        output_gallery.append(visualization)
    return {
        misclassified_output_box: gr.update(visible=True),
        gallery: output_gallery,
    }
with gr.Blocks() as demo:
    gr.Markdown("# Lighting DavidNet")
    gr.Markdown("### CIFAR 10 Classifier with GradCAM with DavidNet")
    gr.Markdown("## Classification")
    with gr.Row():
        # Left column: the input image and every user-adjustable option.
        with gr.Column(scale=1):
            input_image = gr.Image(shape=(32, 32), label="Input Image")
            with gr.Row():
                clear_btn_main = gr.ClearButton()
                submit_btn_main = gr.Button("Submit")
            with gr.Accordion("Advanced options", open=False):
                input_radio_gradcam = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to overlay GradCAM output")
                # Hidden until input_radio_gradcam is "Yes" (change_gradcam_view).
                with gr.Column(visible=False) as gradcam_dialog_box:
                    # Reaches show_cam_on_image as image_weight, the weight of the
                    # ORIGINAL image in the blend: 0 = heat-map only, 1 = image only.
                    input_slider1 = gr.Slider(0, 1, value = 0.5, label="Transparency of GradCAM (0 = heat-map only, 1 = original image only)")
                    input_slider2 = gr.Slider(1, 3, value = 3, step=1, label="Which Layer?")
                input_slider_classes = gr.Slider(1, 10, value = 3, step=1, label="How Many Classes you want to see?")
                input_radio_misclassification = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to see misclassified images?")
                # Hidden until input_radio_misclassification is "Yes" (change_missclassified_view).
                with gr.Column(visible=False) as misclassified_dialog_box:
                    input_slider_misclassified = gr.Slider(1, 29, value = 29, step=1, label="Number of misclassified images to view?")
        # Right column: class probabilities and the (optionally GradCAM'd) image.
        with gr.Column(scale=1):
            output_classes = gr.Label(num_top_classes=3,label="Output Labels(Default: 3)")
            output_image = gr.Image(shape=(32, 32), label="Classification Output(Default: Without GradCAM)").style(width=512, height=512)
    with gr.Column(visible=True) as misclassified_output_box:
        gallery = gr.Gallery(label="Misclassified Gallery", show_label=False, elem_id="gallery").style(columns=[5], rows=[6], object_fit="contain", height="auto")

    submit_btn_main.click(
        fn=inference, inputs=[
            input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes,
            input_radio_misclassification, input_slider_misclassified
        ],
        outputs=[
            output_classes,
            output_image,
            gallery
        ]
    )
    # Reset each input to its declared default and blank each output.
    # One value per entry in `outputs`, in the same order: the misclassified
    # slider goes back to 29, and the Label is cleared (None) rather than
    # being handed a stray number.
    clear_btn_main.click(
        lambda: [None, "No", 0.5, 3, 3, "No", 29, None, None, None],
        outputs=[input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes,
                 input_radio_misclassification, input_slider_misclassified,
                 output_classes, output_image, gallery])
    input_slider_classes.change(update_top_classes, inputs=[input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes], outputs=[output_classes])
    input_radio_gradcam.change(fn=change_gradcam_view, inputs=input_radio_gradcam, outputs=[gradcam_dialog_box])
    input_radio_misclassification.change(fn=change_missclassified_view, inputs=input_radio_misclassification, outputs=[misclassified_dialog_box])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Examples")
            # Each row: image, GradCAM?, transparency, layer, top-k, misclassified?, how many.
            gr.Examples(
                examples=[["Examples/1.jpg", "Yes", 0.5, 3, 3,"Yes",29],
                          ["Examples/2.jpg", "Yes", 0.7, 2, 5,"Yes",29],
                          ["Examples/3.jpg", "Yes", 0.9, 1, 4,"Yes",29],
                          ["Examples/4.jpg", "Yes", 0.3, 1, 7,"Yes",29],
                          ["Examples/5.jpg", "Yes", 0.7, 3, 4,"Yes",29],
                          ["Examples/6.jpg", "Yes", 0.8, 3, 6,"Yes",29],
                          ["Examples/7.jpg", "Yes", 0.9, 1, 7,"Yes",29],
                          ["Examples/8.jpg", "Yes", 0.3, 1, 3,"Yes",29],
                          ["Examples/9.jpg", "Yes", 0.4, 3, 4,"Yes",29],
                          ["Examples/10.jpg", "Yes", 0.5, 2, 5,"Yes",29]
                          ],
                inputs=[input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes,
                        input_radio_misclassification, input_slider_misclassified],
                outputs=[output_classes, output_image, gallery],
                fn=inference,
                cache_examples=True,
            )
if __name__ == "__main__":
    # While developing locally, {"share": True, "debug": True} gives a public
    # link and verbose server logs.
    launch_options = {"debug": False}
    demo.launch(**launch_options)