Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| import torch | |
| from torchvision import datasets, transforms | |
| from model import LightningDavidNet | |
| import random | |
# Load the trained classifier.
# `load_from_checkpoint` is a classmethod that *returns* a new, weight-loaded
# instance; calling it on an existing instance and discarding the result leaves
# `model` with freshly initialised weights. map_location="cpu" because GradCAM
# below is constructed with use_cuda=False.
model = LightningDavidNet.load_from_checkpoint('model.pt', map_location='cpu')
model.eval()

# Class names indexed by the model's output position (assumed to match the
# CIFAR-10 label order used at training time — TODO confirm).
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Module-level cache of misclassified-image paths; filled lazily by get_images().
images = []
def run_model(input_img, input_radio_gradcam, transparency = 0.5, target_layer = 3, input_slider_classes = 3):
    """Classify one image and optionally overlay a GradCAM heat-map on it.

    Args:
        input_img: image as delivered by ``gr.Image`` (assumed to be an
            H x W x 3 uint8 NumPy array with values 0..255 — TODO confirm).
        input_radio_gradcam: ``"No"`` skips GradCAM; any other value runs it.
        transparency: weight of the original image in the overlay, passed to
            ``show_cam_on_image`` as ``image_weight``.
        target_layer: 1 -> ``model.l2X[0]``, 2 -> ``model.l3X[0]``, anything
            else (including 3) -> ``model.r2.block1[0]``.
        input_slider_classes: unused; kept so existing callers keep working.

    Returns:
        ``(confidences, image)``: ``confidences`` maps every class name to its
        softmax probability; ``image`` is the untouched input when GradCAM is
        off, otherwise the overlay produced by ``show_cam_on_image``.

    Raises:
        gr.Error: if no image was supplied.
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")

    # Per-channel normalisation statistics used by this model.
    mean = [0.49139968, 0.48215827, 0.44653124]
    std = [0.24703233, 0.24348505, 0.26158768]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    original_img = input_img
    input_tensor = transform(input_img).unsqueeze(0)  # add a batch dimension

    # Probabilities only need a plain forward pass; GradCAM below does its own
    # forward/backward pass, so no autograd graph is needed here.
    with torch.no_grad():
        outputs = model(input_tensor)
    probs = torch.softmax(outputs.flatten(), dim=0)
    confidences = {classes[i]: float(probs[i]) for i in range(len(classes))}

    if input_radio_gradcam == "No":
        return confidences, original_img

    if target_layer == 1:
        target_layers = [model.l2X[0]]
    elif target_layer == 2:
        target_layers = [model.l3X[0]]
    else:
        target_layers = [model.r2.block1[0]]

    # NOTE(review): newer pytorch-grad-cam releases may no longer accept
    # `use_cuda`; confirm against the pinned library version.
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    # targets=None lets GradCAM explain the highest-scoring class.
    grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
    # show_cam_on_image expects a float image scaled to [0, 1].
    base_img = np.asarray(original_img).astype(np.float32) / 255
    visualization = show_cam_on_image(base_img, grayscale_cam, use_rgb=True, image_weight=transparency)
    return confidences, visualization
def inference(input_img, input_radio_gradcam, transparency = 0.5, target_layer = 3, input_slider_classes = 3, input_radio_misclassification="No",input_slider_misclassified=29):
    """Main Gradio handler: classify, optionally GradCAM, optionally list misclassified images.

    Args:
        input_img, input_radio_gradcam, transparency, target_layer,
        input_slider_classes: forwarded unchanged to ``run_model``.
        input_radio_misclassification: ``"Yes"`` to also return gallery images.
        input_slider_misclassified: how many misclassified image paths to return.

    Returns:
        ``(confidences, visualization, gallery)``. ``gallery`` is a list of at
        most ``input_slider_misclassified`` file paths when misclassification
        viewing is ``"Yes"``, otherwise ``None``.
    """
    confidences, visualization = run_model(input_img, input_radio_gradcam, transparency, target_layer, input_slider_classes)
    if input_radio_misclassification != "Yes":
        return confidences, visualization, None
    # Slider values may arrive as floats; list slicing requires an int.
    count = int(input_slider_misclassified)
    return confidences, visualization, get_images()[:count]
def change_gradcam_view(choice):
    """Return a visibility update that shows the GradCAM options only for "Yes"."""
    should_show = choice == "Yes"
    return gradcam_dialog_box.update(visible=should_show)
def update_top_classes(input_img, input_slider_gradcam, transparency, target_layer_number, topk):
    """Recompute the class-probability label after the "how many classes" slider moves.

    Args:
        input_img: current input image, or ``None`` if nothing is uploaded yet.
        input_slider_gradcam: GradCAM radio value; not needed for the label.
        transparency, target_layer_number: forwarded to ``run_model``.
        topk: number of classes ``output_classes`` should display.

    Returns:
        The confidences dict for ``output_classes``, or ``None`` (clears the
        label) when there is no image to classify.
    """
    # NOTE(review): this mutates a shared component attribute, so it affects
    # every session. It relies on gr.Label applying num_top_classes when it
    # renders the value; confirm for the pinned Gradio version.
    output_classes.num_top_classes = int(topk)
    if input_img is None:
        # The slider can change (e.g. via the Clear button) before an image exists.
        return None
    # Only the confidences are needed, so skip the GradCAM pass and the gallery.
    confidences, _ = run_model(input_img, "No", transparency, target_layer_number, topk)
    return confidences
def change_missclassified_view(choice):
    """Return a visibility update that shows the misclassified-images options only for "Yes"."""
    should_show = choice == "Yes"
    return misclassified_dialog_box.update(visible=should_show)
def get_images():
    """Return the cached misclassified-image paths, building them on first use.

    The module-level ``images`` list is filled in place (only while it is
    empty) with ``Misclassified_images/29.jpg`` down to ``.../1.jpg``. That
    same list object is returned to every caller.
    """
    if not images:
        images.extend(f'Misclassified_images/{counter}.jpg' for counter in range(29, 0, -1))
    return images
def show_misclassified_images(number_of_missclassified, gradcam, transparency, target_layer):
    """Run the model on the first N misclassified images and fill the gallery.

    Args:
        number_of_missclassified: how many images to process (may be a float
            coming from a slider).
        gradcam, transparency, target_layer: forwarded to ``inference``.

    Returns:
        A Gradio update dict that makes ``misclassified_output_box`` visible and
        sets ``gallery`` to the per-image outputs (GradCAM overlays when
        ``gradcam`` is not ``"No"``, otherwise the raw images).
    """
    count = int(number_of_missclassified)
    output_gallery = []
    # Slice first so only the requested images are loaded and run.
    for image_path in get_images()[:count]:
        # The context manager closes the file handle; np.array copies the pixels.
        # Assumes the saved images are already model-sized (32x32) — TODO confirm.
        with Image.open(image_path) as image:
            image_array = np.array(image.convert("RGB"))
        # inference returns (confidences, visualization, gallery); the image is
        # element [1]. The last element is the gallery slot, which is None here.
        _, visualization, _ = inference(image_array, gradcam, transparency, target_layer)
        output_gallery.append(visualization)
    return {
        misclassified_output_box: gr.update(visible=True),
        gallery: output_gallery,
    }
# NOTE(review): this file's original indentation was lost; the layout nesting
# below is a best-effort reconstruction — confirm against the intended UI.
with gr.Blocks() as demo:
    gr.Markdown("# Lighting DavidNet")
    gr.Markdown("### CIFAR 10 Classifier with GradCAM with DavidNet")
    gr.Markdown("## Classification")
    with gr.Row():
        # Left column: input image, action buttons and advanced options.
        with gr.Column(scale=1):
            input_image = gr.Image(shape=(32, 32), label="Input Image")
            with gr.Row():
                clear_btn_main = gr.ClearButton()
                submit_btn_main = gr.Button("Submit")
            with gr.Accordion("Advanced options", open=False):
                input_radio_gradcam = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to overlay GradCAM output")
                # Toggled by change_gradcam_view.
                with gr.Column(visible=False) as gradcam_dialog_box:
                    input_slider1 = gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM")
                    input_slider2 = gr.Slider(1, 3, value = 3, step=1, label="Which Layer?")
                input_slider_classes = gr.Slider(1, 10, value = 3, step=1, label="How Many Classes you want to see?")
                input_radio_misclassification = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to see misclassified images?")
                # Toggled by change_missclassified_view.
                with gr.Column(visible=False) as misclassified_dialog_box:
                    input_slider_misclassified = gr.Slider(1, 29, value = 29, step=1, label="Number of misclassified images to view?")
        # Right column: predictions, (optionally GradCAM) image and gallery.
        with gr.Column(scale=1):
            output_classes = gr.Label(num_top_classes=3,label="Output Labels(Default: 3)")
            output_image = gr.Image(shape=(32, 32), label="Classification Output(Default: Without GradCAM)").style(width=512, height=512)
            with gr.Column(visible=True) as misclassified_output_box:
                gallery = gr.Gallery(label="Misclassified Gallery", show_label=False, elem_id="gallery").style(columns=[5], rows=[6], object_fit="contain", height="auto")

    submit_btn_main.click(
        fn=inference,
        inputs=[
            input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes,
            input_radio_misclassification, input_slider_misclassified,
        ],
        outputs=[
            output_classes,
            output_image,
            gallery,
        ],
    )
    # One reset value per output component, in the same order as `outputs`:
    # inputs go back to their declared defaults (misclassified slider -> 29)
    # and the three outputs are cleared.
    clear_btn_main.click(
        lambda: [None, "No", 0.5, 3, 3, "No", 29, None, None, None],
        outputs=[
            input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes,
            input_radio_misclassification, input_slider_misclassified,
            output_classes, output_image, gallery,
        ],
    )
    input_slider_classes.change(update_top_classes, inputs=[input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes], outputs=[output_classes])
    input_radio_gradcam.change(fn=change_gradcam_view, inputs=input_radio_gradcam, outputs=[gradcam_dialog_box])
    input_radio_misclassification.change(fn=change_missclassified_view, inputs=input_radio_misclassification, outputs=[misclassified_dialog_box])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Examples")
            # Each row: image, GradCAM?, opacity, layer, top-k, misclassified?, count.
            gr.Examples(
                examples=[["Examples/1.jpg", "Yes", 0.5, 3, 3,"Yes",29],
                          ["Examples/2.jpg", "Yes", 0.7, 2, 5,"Yes",29],
                          ["Examples/3.jpg", "Yes", 0.9, 1, 4,"Yes",29],
                          ["Examples/4.jpg", "Yes", 0.3, 1, 7,"Yes",29],
                          ["Examples/5.jpg", "Yes", 0.7, 3, 4,"Yes",29],
                          ["Examples/6.jpg", "Yes", 0.8, 3, 6,"Yes",29],
                          ["Examples/7.jpg", "Yes", 0.9, 1, 7,"Yes",29],
                          ["Examples/8.jpg", "Yes", 0.3, 1, 3,"Yes",29],
                          ["Examples/9.jpg", "Yes", 0.4, 3, 4,"Yes",29],
                          ["Examples/10.jpg", "Yes", 0.5, 2, 5,"Yes",29]
                          ],
                inputs=[input_image, input_radio_gradcam, input_slider1, input_slider2, input_slider_classes,
                        input_radio_misclassification, input_slider_misclassified],
                outputs=[output_classes, output_image, gallery],
                fn=inference,
                cache_examples=True,
            )
if __name__ == "__main__":
    # For local debugging a public link can be used instead:
    # demo.launch(share=True, debug=True)
    demo.launch(debug=False)