"""Gradio demo for a CIFAR-10 DavidNet classifier (PyTorch Lightning).

Shows per-class confidences for an uploaded image, can overlay a GradCAM
heat-map computed on one of three selectable layers, and can display a gallery
of pre-rendered misclassified images from ``Misclassified_images/``.
"""
import random

import gradio as gr
import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import datasets, transforms

from model import LightningDavidNet

# `load_from_checkpoint` is a classmethod that *returns* the restored model;
# calling it on an instance and ignoring the result leaves the freshly
# initialised weights in place (recent Lightning releases reject the instance
# call outright). map_location keeps loading on the CPU, matching the
# `use_cuda=False` GradCAM used below.
model = LightningDavidNet.load_from_checkpoint('model.pt', map_location='cpu')
model.eval()

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Number of pre-rendered files Misclassified_images/1.jpg .. N.jpg.
NUM_MISCLASSIFIED = 29

# Cache of misclassified-image paths, filled on first use by get_images().
images = []

# Input preprocessing: HWC array -> CHW float tensor, then per-channel
# normalisation. NOTE(review): these are the commonly quoted CIFAR-10
# statistics; presumably they must match the training-time transform — confirm.
_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.49139968, 0.48215827, 0.44653124],
                         [0.24703233, 0.24348505, 0.26158768]),
])


def _select_target_layer(target_layer):
    """Return the module GradCAM should hook for the "Which Layer?" value.

    1 -> ``model.l2X[0]``, 2 -> ``model.l3X[0]``; 3 and any other value fall
    back to ``model.r2.block1[0]``. Only the selected attribute is looked up.
    """
    if target_layer == 1:
        return model.l2X[0]
    if target_layer == 2:
        return model.l3X[0]
    return model.r2.block1[0]


def run_model(input_img, input_radio_gradcam, transparency=0.5, target_layer=3,
              input_slider_classes=3):
    """Classify one image and optionally overlay a GradCAM heat-map on it.

    Args:
        input_img: image array as delivered by ``gr.Image``; assumed HxWx3
            uint8 RGB (the ``/ 255`` below relies on a 0-255 range) — TODO
            confirm for non-Gradio callers.
        input_radio_gradcam: ``"No"`` skips GradCAM; any other value runs it.
        transparency: forwarded as ``image_weight`` to ``show_cam_on_image``.
            NOTE(review): in pytorch-grad-cam that is the weight of the
            *original* image, so larger values show less heat-map, while the
            UI slider feeding it is labelled "Opacity of GradCAM" — confirm
            which reading is intended.
        target_layer: 1, 2 or 3; see ``_select_target_layer``.
        input_slider_classes: unused here; kept for call compatibility.

    Returns:
        ``(confidences, image)``: ``confidences`` maps each class name to its
        softmax probability; ``image`` is ``input_img`` unchanged when GradCAM
        is skipped, otherwise the GradCAM overlay.
    """
    input_tensor = _TRANSFORM(input_img).unsqueeze(0)

    # Scores only: GradCAM runs its own forward pass (with gradients) below,
    # so no autograd graph is needed for this one.
    with torch.no_grad():
        logits = model(input_tensor)
    probabilities = torch.softmax(logits.flatten(), dim=0)
    confidences = {name: float(p) for name, p in zip(classes, probabilities)}

    if input_radio_gradcam == "No":
        return confidences, input_img

    cam = GradCAM(model=model,
                  target_layers=[_select_target_layer(target_layer)],
                  use_cuda=False)
    # targets=None: per pytorch-grad-cam docs, explain the top-scoring class.
    grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
    visualization = show_cam_on_image(input_img / 255, grayscale_cam,
                                      use_rgb=True, image_weight=transparency)
    return confidences, visualization


def inference(input_img, input_radio_gradcam, transparency=0.5, target_layer=3,
              input_slider_classes=3, input_radio_misclassification="No",
              input_slider_misclassified=NUM_MISCLASSIFIED):
    """Handler for the Submit button and the cached examples.

    Returns:
        ``(confidences, image, gallery)`` — the first two from ``run_model``;
        ``gallery`` is the first ``input_slider_misclassified`` misclassified
        image paths when ``input_radio_misclassification == "Yes"``, else None.

    Raises:
        gr.Error: if no image has been uploaded.
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")
    confidences, visualization = run_model(input_img, input_radio_gradcam,
                                           transparency, target_layer,
                                           input_slider_classes)
    if input_radio_misclassification != "Yes":
        return confidences, visualization, None
    # The gallery column is created visible, so nothing needs toggling here.
    # int(): a slider value is used as a slice bound.
    count = int(input_slider_misclassified)
    return confidences, visualization, get_images()[:count]


def change_gradcam_view(choice):
    """Show the GradCAM options only while the GradCAM radio is "Yes"."""
    return gradcam_dialog_box.update(visible=choice == "Yes")


def update_top_classes(input_img, input_slider_gradcam, transparency,
                       target_layer_number, topk):
    """Classes-slider handler: set how many labels to show and re-score.

    NOTE: this mutates the single module-level ``output_classes`` component,
    so the new top-k applies to every later result, in every session.
    Presumably it works because ``gr.Label`` trims to ``num_top_classes`` in
    its server-side postprocess — confirm against the installed Gradio.

    Returns the confidences dict, or None when no image is loaded yet.
    """
    output_classes.num_top_classes = int(topk)
    if input_img is None:
        # e.g. the slider is moved (or reset by Clear) before any upload.
        return None
    # Only the confidences are needed, so skip the GradCAM pass entirely.
    return run_model(input_img, "No")[0]


def change_missclassified_view(choice):
    """Show the misclassified-gallery options only while the radio is "Yes"."""
    return misclassified_dialog_box.update(visible=choice == "Yes")


def get_images():
    """Return the misclassified-image paths, ordered 29.jpg down to 1.jpg.

    The module-level ``images`` list is filled on the first call and the same
    list object is returned afterwards.
    """
    if not images:
        images.extend(f'Misclassified_images/{index}.jpg'
                      for index in range(NUM_MISCLASSIFIED, 0, -1))
    return images


def show_misclassified_images(number_of_missclassified, gradcam, transparency,
                              target_layer):
    """Build a gallery of (optionally GradCAM-overlaid) misclassified images.

    Returns a Gradio update dict that shows ``misclassified_output_box`` and
    fills ``gallery`` with the first ``number_of_missclassified`` images.
    """
    output_gallery = []
    for image_path in get_images()[:int(number_of_missclassified)]:
        # Copy the pixels out and close the file; force RGB so the /255
        # overlay in run_model always sees three channels.
        with Image.open(image_path) as image:
            image_array = np.asarray(image.convert("RGB"))
        # Index 1 is the image output; the tuple's last slot is the gallery.
        output_gallery.append(
            run_model(image_array, gradcam, transparency, target_layer)[1])
    return {
        misclassified_output_box: gr.update(visible=True),
        gallery: output_gallery,
    }


with gr.Blocks() as demo:
    gr.Markdown("# Lighting DavidNet")
    gr.Markdown("### CIFAR 10 Classifier with GradCAM with DavidNet")
    gr.Markdown("## Classification")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(shape=(32, 32), label="Input Image")
            with gr.Row():
                clear_btn_main = gr.ClearButton()
                submit_btn_main = gr.Button("Submit")
            with gr.Accordion("Advanced options", open=False):
                input_radio_gradcam = gr.Radio(
                    choices=["Yes", "No"], value="No",
                    label="Do you want to overlay GradCAM output")
                with gr.Column(visible=False) as gradcam_dialog_box:
                    input_slider1 = gr.Slider(0, 1, value=0.5,
                                              label="Opacity of GradCAM")
                    input_slider2 = gr.Slider(1, 3, value=3, step=1,
                                              label="Which Layer?")
                input_slider_classes = gr.Slider(
                    1, 10, value=3, step=1,
                    label="How Many Classes you want to see?")
                input_radio_misclassification = gr.Radio(
                    choices=["Yes", "No"], value="No",
                    label="Do you want to see misclassified images?")
                with gr.Column(visible=False) as misclassified_dialog_box:
                    input_slider_misclassified = gr.Slider(
                        1, NUM_MISCLASSIFIED, value=NUM_MISCLASSIFIED, step=1,
                        label="Number of misclassified images to view?")
        with gr.Column(scale=1):
            output_classes = gr.Label(num_top_classes=3,
                                      label="Output Labels(Default: 3)")
            output_image = gr.Image(
                shape=(32, 32),
                label="Classification Output(Default: Without GradCAM)",
            ).style(width=512, height=512)
    with gr.Column(visible=True) as misclassified_output_box:
        gallery = gr.Gallery(
            label="Misclassified Gallery", show_label=False, elem_id="gallery",
        ).style(columns=[5], rows=[6], object_fit="contain", height="auto")

    submit_btn_main.click(
        fn=inference,
        inputs=[input_image, input_radio_gradcam, input_slider1, input_slider2,
                input_slider_classes, input_radio_misclassification,
                input_slider_misclassified],
        outputs=[output_classes, output_image, gallery],
    )
    clear_btn_main.click(
        # One value per output, in order: inputs go back to their initial
        # values and the three result components are emptied.
        lambda: [None, "No", 0.5, 3, 3, "No", NUM_MISCLASSIFIED,
                 None, None, None],
        outputs=[input_image, input_radio_gradcam, input_slider1, input_slider2,
                 input_slider_classes, input_radio_misclassification,
                 input_slider_misclassified,
                 output_classes, output_image, gallery],
    )
    input_slider_classes.change(
        update_top_classes,
        inputs=[input_image, input_radio_gradcam, input_slider1, input_slider2,
                input_slider_classes],
        outputs=[output_classes],
    )
    input_radio_gradcam.change(fn=change_gradcam_view,
                               inputs=input_radio_gradcam,
                               outputs=[gradcam_dialog_box])
    input_radio_misclassification.change(fn=change_missclassified_view,
                                         inputs=input_radio_misclassification,
                                         outputs=[misclassified_dialog_box])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Examples")
            # Row layout: image, GradCAM?, opacity, layer, top-k,
            # show misclassified?, number of misclassified images.
            gr.Examples(
                examples=[
                    ["Examples/1.jpg", "Yes", 0.5, 3, 3, "Yes", NUM_MISCLASSIFIED],
                    ["Examples/2.jpg", "Yes", 0.7, 2, 5, "Yes", NUM_MISCLASSIFIED],
                    ["Examples/3.jpg", "Yes", 0.9, 1, 4, "Yes", NUM_MISCLASSIFIED],
                    ["Examples/4.jpg", "Yes", 0.3, 1, 7, "Yes", NUM_MISCLASSIFIED],
                    ["Examples/5.jpg", "Yes", 0.7, 3, 4, "Yes", NUM_MISCLASSIFIED],
                    ["Examples/6.jpg", "Yes", 0.8, 3, 6, "Yes", NUM_MISCLASSIFIED],
                    ["Examples/7.jpg", "Yes", 0.9, 1, 7, "Yes", NUM_MISCLASSIFIED],
                    ["Examples/8.jpg", "Yes", 0.3, 1, 3, "Yes", NUM_MISCLASSIFIED],
                    ["Examples/9.jpg", "Yes", 0.4, 3, 4, "Yes", NUM_MISCLASSIFIED],
                    ["Examples/10.jpg", "Yes", 0.5, 2, 5, "Yes", NUM_MISCLASSIFIED],
                ],
                inputs=[input_image, input_radio_gradcam, input_slider1,
                        input_slider2, input_slider_classes,
                        input_radio_misclassification,
                        input_slider_misclassified],
                outputs=[output_classes, output_image, gallery],
                fn=inference,
                cache_examples=True,
            )


if __name__ == "__main__":
    demo.launch(debug=False)
    # For a public share link while debugging:
    # demo.launch(share=True, debug=True)