Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| import torch | |
| from torchvision import datasets, transforms | |
| from custom_resnet import CustomResNet | |
| import random | |
# Instantiate the network and restore its trained weights onto the CPU.
model = CustomResNet()
model.load_state_dict(
    torch.load('CustomResNet.pth', map_location=torch.device('cpu')),
    # strict=False: keys that do not line up between checkpoint and model
    # are skipped instead of raising.
    strict=False,
)
# The app only runs inference, so put the model in evaluation mode.
model.eval()

# Human-readable class names, indexed by the model's output position.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
def inference(input_img, input_slider_grad_or_not, transparency=0.5, target_layer_number=3, topk=3):
    """Classify an image and optionally overlay a Grad-CAM heat map on it.

    Args:
        input_img: Image array with 0-255 pixel values (it is divided by 255
            before the overlay is drawn), or ``None`` when nothing has been
            uploaded yet.
        input_slider_grad_or_not: ``"No"`` returns the untouched input image;
            any other value produces the Grad-CAM overlay.
        transparency: Weight of the original image in the overlay, passed to
            ``show_cam_on_image`` as ``image_weight``.
        target_layer_number: ``1`` or ``2`` selects the last module of
            ``model.layer_1`` / ``model.layer_2``; any other value selects the
            last module of ``model.layer_3``.
        topk: Accepted so the signature matches the Gradio input list; it is
            not used inside this function.

    Returns:
        A ``(confidences, image)`` tuple. ``confidences`` maps every class
        name to its softmax probability; ``image`` is the original input or
        the Grad-CAM visualisation. ``(None, None)`` when ``input_img`` is
        ``None``.
    """
    # Nothing uploaded yet: hand back empty outputs instead of crashing in
    # the tensor conversion below.
    if input_img is None:
        return None, None

    # Per-channel normalisation constants (these look like the CIFAR-10
    # statistics -- confirm against the training code).
    mean = [0.49139968, 0.48215827, 0.44653124]
    std = [0.24703233, 0.24348505, 0.26158768]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    original_img = input_img
    input_tensor = transform(input_img).unsqueeze(0)

    # The class scores need no gradients; Grad-CAM runs its own forward and
    # backward pass further down.
    with torch.no_grad():
        outputs = model(input_tensor)
    probabilities = torch.softmax(outputs.flatten(), dim=0)
    confidences = {name: float(probabilities[i]) for i, name in enumerate(classes)}

    if input_slider_grad_or_not == "No":
        return confidences, original_img

    if target_layer_number == 1:
        block = model.layer_1
    elif target_layer_number == 2:
        block = model.layer_2
    else:
        block = model.layer_3

    # The context manager releases the hooks Grad-CAM registers on the model
    # as soon as the heat map has been computed.
    with GradCAM(model=model, target_layers=[block[-1]]) as cam:
        # targets=None leaves the choice of target class to the library.
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]

    visualization = show_cam_on_image(
        original_img / 255,
        grayscale_cam,
        use_rgb=True,
        image_weight=transparency,
    )
    return confidences, visualization
def show_gradcam_images(n, a, b):
    """Build the Grad-CAM gallery for the bundled example images.

    Args:
        n: Number of example images to show, taken from the start of the
            list. ``None`` shows all of them.
        a: Overlay opacity, forwarded to ``inference`` as ``transparency``.
        b: Layer selector, forwarded to ``inference`` as
            ``target_layer_number``.

    Returns:
        A Gradio update dict that makes ``grad1_block`` visible and fills
        ``gallery3`` with ``(visualisation, label)`` pairs.
    """
    images = [
        ('examples/car.jpg', 'car'),
        ('examples/cat.jpg', 'cat'),
        ('examples/dog.jpg', 'dog'),
        ('examples/horse.jpg', 'horse'),
        ('examples/ship.jpg', 'ship'),
        ('examples/bird.jpg', 'bird'),
        ('examples/frog.jpg', 'frog'),
        ('examples/plane.jpg', 'plane'),
        ('examples/truck.jpg', 'truck'),
        ('examples/deer.jpg', 'deer'),
    ]
    count = len(images) if n is None else int(n)

    images_with_gradcam = []
    # Only the images that will actually be displayed go through the model.
    for image_path, label in images[:count]:
        # Close the file promptly; force three channels because inference
        # normalises with three-channel statistics.
        with Image.open(image_path) as image:
            image_array = np.asarray(image.convert("RGB").resize((32, 32)))
        visualization = inference(image_array, "Yes", a, b)[-1]
        images_with_gradcam.append((visualization, label))

    return {
        grad1_block: gr.update(visible=True),
        gallery3: images_with_gradcam,
    }
def change_grad_view(choice):
    """Return an update that shows ``grad_block`` only when ``choice`` is "Yes"."""
    wants_gradcam = choice == "Yes"
    return grad_block.update(visible=wants_gradcam)
def show_misclassified_images(n, grad_cam, a, b):
    """Build the gallery of misclassified images, optionally with Grad-CAM.

    Args:
        n: Number of images to show, taken from the start of the list.
            ``None`` shows all of them.
        grad_cam: ``"Yes"`` overlays Grad-CAM on every shown image; any other
            value shows the stored image files as they are.
        a: Overlay opacity, forwarded to ``inference`` as ``transparency``.
        b: Layer selector, forwarded to ``inference`` as
            ``target_layer_number``.

    Returns:
        A Gradio update dict that makes ``miscls1_block`` visible and fills
        ``gallery`` with ``(image, label)`` pairs. Labels follow the file
        names: ground truth first, prediction second.
    """
    images = [
        ('misclassified_images/misclassified_0_GT_bird_Pred_cat.jpg', 'bird/cat'),
        ('misclassified_images/misclassified_1_GT_car_Pred_truck.jpg', 'car/truck'),
        ('misclassified_images/misclassified_2_GT_plane_Pred_truck.jpg', 'plane/truck'),
        ('misclassified_images/misclassified_3_GT_deer_Pred_dog.jpg', 'deer/dog'),
        ('misclassified_images/misclassified_4_GT_frog_Pred_cat.jpg', 'frog/cat'),
        ('misclassified_images/misclassified_5_GT_cat_Pred_dog.jpg', 'cat/dog'),
        ('misclassified_images/misclassified_6_GT_cat_Pred_dog.jpg', 'cat/dog'),
        ('misclassified_images/misclassified_7_GT_dog_Pred_horse.jpg', 'dog/horse'),
        ('misclassified_images/misclassified_8_GT_bird_Pred_dog.jpg', 'bird/dog'),
        ('misclassified_images/misclassified_9_GT_ship_Pred_plane.jpg', 'ship/plane'),
    ]
    count = len(images) if n is None else int(n)
    selected = images[:count]

    if grad_cam == "Yes":
        shown = []
        # Run the model only when the overlay is requested, and only on the
        # images that will be displayed.
        for image_path, label in selected:
            with Image.open(image_path) as image:
                image_array = np.asarray(image.convert("RGB"))
            visualization = inference(image_array, "Yes", a, b)[-1]
            shown.append((visualization, label))
    else:
        # Without the overlay the gallery receives the file paths directly.
        shown = selected

    return {
        miscls1_block: gr.update(visible=True),
        gallery: shown,
    }
def change_miscls_view(choice):
    """Return an update that shows ``miscls_block`` only when ``choice`` is "Yes"."""
    wants_misclassified = choice == "Yes"
    return miscls_block.update(visible=wants_misclassified)
def change_textbox(choice):
    """Return two slider updates: both visible for "Yes", both hidden otherwise."""
    show_sliders = choice == "Yes"
    # One update per slider wired as an output of this handler.
    return [gr.Slider.update(visible=show_sliders) for _ in range(2)]
def update_num_top_classes(input_img, input_slider_grad_or_not, transparency, target_layer_number, topk):
    """Apply a new "Top N" setting to the label output and refresh it.

    Args:
        input_img: Current input image, or ``None`` if none is loaded.
        input_slider_grad_or_not: Grad-CAM "Yes"/"No" choice, forwarded to
            ``inference``.
        transparency: Overlay opacity, forwarded to ``inference``.
        target_layer_number: Layer selector, forwarded to ``inference``.
        topk: Number of classes the label component should display.

    Returns:
        The class-confidence dict from ``inference``, or ``None`` when there
        is no image to classify.
    """
    # Record the setting even without an image so the next prediction uses it.
    output_classes.num_top_classes = topk
    # The slider can be moved before any image is uploaded; skip the model
    # call instead of failing on a missing input.
    if input_img is None:
        return None
    return inference(input_img, input_slider_grad_or_not, transparency, target_layer_number, topk)[0]
def change_mygrad_view(choice):
    """Return an update that shows ``grad_or_not`` only when ``choice`` is "Yes"."""
    wants_overlay = choice == "Yes"
    return grad_or_not.update(visible=wants_overlay)
# NOTE(review): the container nesting below was reconstructed from a listing
# that had lost its indentation -- confirm the layout against the running app.
with gr.Blocks(theme='xiaobaiyuan/theme_brief') as demo:
    gr.Markdown("""
    # CustomResNet model with GradCAM
    ### A simple Gradio interface to infer on CustomResNet model and get GradCAM results
    """)

    # ------------------------------------------------------------------
    # Section 1: Grad-CAM gallery for the bundled example images
    # ------------------------------------------------------------------
    gr.Markdown("## Grad-CAM Images")
    with gr.Row():
        grad_yes_no = gr.Radio(choices=["Yes", "No"], value="No", label="Do you want to see GradCAM images")
    with gr.Row(visible=False) as grad_block:
        with gr.Column(scale=1):
            input_grad = gr.Slider(1, 10, value=5, step=1, label="Number of GradCAM images to view")
            input_overlay = gr.Radio(choices=["Yes", "No"], value="No", label="Do you want to configure gradcam")
            with gr.Row():
                clear_btn3 = gr.ClearButton()
                submit_btn3 = gr.Button("Submit")
        with gr.Column(scale=1):
            input_slider1 = gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM", interactive=True, visible=False)
            input_slider2 = gr.Slider(1, 3, value=3, step=1, label="Which Layer?", interactive=True, visible=False)
    with gr.Row(visible=False) as grad1_block:
        gallery3 = gr.Gallery(
            label="GradCAM images", show_label=True, elem_id="gallery3"
        ).style(columns=[4], rows=[3], object_fit="contain", height="auto")

    submit_btn3.click(
        fn=show_gradcam_images,
        inputs=[input_grad, input_slider1, input_slider2],
        outputs=[grad1_block, gallery3],
    )
    # Reset every control of this section to its initial value; each output
    # component appears exactly once and matches one returned value.
    clear_btn3.click(
        lambda: [5, "No", 0.5, 3, None],
        outputs=[input_grad, input_overlay, input_slider1, input_slider2, gallery3],
    )
    input_overlay.change(fn=change_textbox, inputs=input_overlay, outputs=[input_slider1, input_slider2])
    grad_yes_no.change(fn=change_grad_view, inputs=grad_yes_no, outputs=[grad_block])

    # ------------------------------------------------------------------
    # Section 2: misclassified images
    # ------------------------------------------------------------------
    gr.Markdown("## Misclassification Images")
    with gr.Row():
        miscls_yes_no = gr.Radio(choices=["Yes", "No"], value="No", label="Do you want to see misclassified images")
    with gr.Row(visible=False) as miscls_block:
        with gr.Column(scale=1):
            input_miscn = gr.Slider(1, 10, value=3, step=1, label="Number of misclassified images to view")
        with gr.Column(scale=1):
            input_grad2 = gr.Radio(choices=["Yes", "No"], value="No", label="Do you want to overlay gradcam")
            input_slider21 = gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM", interactive=True, visible=False)
            input_slider22 = gr.Slider(1, 3, value=3, step=1, label="Which Layer?", interactive=True, visible=False)
            with gr.Row():
                clear_btn2 = gr.ClearButton()
                submit_btn2 = gr.Button("Submit")
    with gr.Column(visible=False) as miscls1_block:
        gallery = gr.Gallery(
            label="Misclassified images", show_label=True, elem_id="gallery"
        ).style(columns=[4], rows=[3], object_fit="contain", height="auto")

    submit_btn2.click(
        fn=show_misclassified_images,
        inputs=[input_miscn, input_grad2, input_slider21, input_slider22],
        outputs=[miscls1_block, gallery],
    )
    # Clears this section's own controls (input_grad2, not the Grad-CAM
    # section's input_grad slider).
    clear_btn2.click(
        lambda: [3, "No", 0.5, 3, None],
        outputs=[input_miscn, input_grad2, input_slider21, input_slider22, gallery],
    )
    input_grad2.change(fn=change_textbox, inputs=input_grad2, outputs=[input_slider21, input_slider22])
    miscls_yes_no.change(fn=change_miscls_view, inputs=miscls_yes_no, outputs=[miscls_block])

    # ------------------------------------------------------------------
    # Section 3: classify a user-supplied image
    # ------------------------------------------------------------------
    gr.Markdown("## Input Interface ")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(shape=(32, 32), label="Input Image")
            input_topk = gr.Slider(1, 10, value=3, step=1, label="Top N Classes")
            input_slider_grad_or_not = gr.Radio(choices=["Yes", "No"], value="No", label="Do you want to overlay GradCAM output")
            with gr.Row():
                clear_btn = gr.ClearButton()
                submit_btn = gr.Button("Submit")
            with gr.Column(visible=False) as grad_or_not:
                # Named 31/32 so they no longer rebind the Grad-CAM section's
                # input_slider1 / input_slider2.
                input_slider31 = gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM")
                input_slider32 = gr.Slider(1, 3, value=3, step=1, label="Which Layer?")
        with gr.Column(scale=1):
            output_classes = gr.Label(num_top_classes=3)
            output_image = gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)

    gr.Markdown("## Examples")
    # Each row: image path, Grad-CAM yes/no, opacity, layer, top-N.
    gr.Examples(
        examples=[
            ["examples/car.jpg", "Yes", 0.5, 3, 3],
            ["examples/cat.jpg", "Yes", 0.7, 2, 5],
            ["examples/dog.jpg", "Yes", 0.9, 1, 4],
            ["examples/truck.jpg", "Yes", 0.3, 1, 7],
            ["examples/horse.jpg", "Yes", 0.7, 3, 4],
            ["examples/frog.jpg", "Yes", 0.8, 3, 6],
            ["examples/bird.jpg", "Yes", 0.9, 1, 7],
            ["examples/deer.jpg", "Yes", 0.3, 1, 3],
            ["examples/plane.jpg", "Yes", 0.4, 3, 4],
            ["examples/ship.jpg", "Yes", 0.5, 2, 5],
        ],
        inputs=[input_image, input_slider_grad_or_not, input_slider31, input_slider32, input_topk],
        outputs=[output_classes, output_image],
        fn=inference,
        cache_examples=True,
    )

    submit_btn.click(
        fn=inference,
        inputs=[input_image, input_slider_grad_or_not, input_slider31, input_slider32, input_topk],
        outputs=[output_classes, output_image],
    )
    # Seven values for seven outputs: the trailing 3 resets the Top-N slider.
    clear_btn.click(
        lambda: [None, "No", 0.5, 3, None, None, 3],
        outputs=[
            input_image,
            input_slider_grad_or_not,
            input_slider31,
            input_slider32,
            output_classes,
            output_image,
            input_topk,
        ],
    )
    input_topk.change(
        update_num_top_classes,
        inputs=[input_image, input_slider_grad_or_not, input_slider31, input_slider32, input_topk],
        outputs=[output_classes],
    )
    input_slider_grad_or_not.change(fn=change_mygrad_view, inputs=input_slider_grad_or_not, outputs=[grad_or_not])

if __name__ == "__main__":
    demo.launch(debug=True)