Spaces:
Build error
Build error
| #app.py | |
| import gradio as gr | |
| from pytorch_lightning import LightningModule | |
| import math | |
| import numpy as np | |
| from PIL import Image | |
| from helper_fn import * | |
| import pickle | |
# Load the trained CIFAR-10 ResNet checkpoint on the CPU.
model = get_model('cifar10_resnet_epochs30.ckpt', map_location='cpu')
# This app only runs inference: switch BatchNorm/Dropout to inference behaviour.
# eval() is idempotent, so this is harmless if get_model() already did it.
model.eval()
# Default GradCAM target layer (inference() selects its own per request).
target_layers = [model.layer3[-1]]
# None lets GradCAM use its default target for each image.
targets = None
# Load the pre-computed misclassified samples shown in the gallery.
# NOTE: pickle can execute arbitrary code; only load this trusted, locally bundled file.
with open('required_data.pkl', 'rb') as f:
    loaded_pkl = pickle.load(f)
def _denormalize_to_hwc(tensor):
    """Move a 3-D channel-first image tensor to the CPU, undo normalization and return it as an HWC numpy array."""
    return inv_normalize(tensor.to('cpu')).permute(1, 2, 0).numpy()


def inference(image, show_grad_cam=True, opacity=0.7, grad_cam_layer=3, top_preds=10, show_misclassified_imgs=5):
    """Classify an image and optionally build a GradCAM overlay and a gallery of misclassified samples.

    Args:
        image: Input image (the UI supplies a PIL image); preprocessed with ``img_test_transforms``.
        show_grad_cam: Truthy to compute the GradCAM overlay. The UI passes the
            CheckboxGroup's list of selected options, so ``[]`` disables it.
        opacity: Forwarded to ``show_cam_on_image`` as ``image_weight``.
        grad_cam_layer: 1 or 2 selects ``model.layer1[-1]`` / ``model.layer2[-1]``;
            any other value falls back to ``model.layer3[-1]``.
        top_preds: How many of the highest-scoring classes to return (``None`` = all).
        show_misclassified_imgs: How many stored misclassified samples to return;
            0 / ``None`` returns an empty gallery.

    Returns:
        ``(confidences, visualization, gallery)`` where ``confidences`` maps class
        name to score (highest first), ``visualization`` is the GradCAM overlay or
        ``None``, and ``gallery`` is a list of ``(image_array, caption)`` tuples.
    """
    from itertools import islice

    # Gradio sliders may deliver floats (e.g. 10.0); slicing requires ints.
    if top_preds is not None:
        top_preds = int(top_preds)
    n_misclassified = max(0, int(show_misclassified_imgs)) if show_misclassified_imgs else 0

    input_tensor = img_test_transforms(image).unsqueeze(0)

    # NOTE(review): scores are the raw model outputs; whether they are probabilities,
    # logits or log-probabilities depends on the model's final layer — confirm.
    preds = model(input_tensor).squeeze()
    confidences = {classes[i]: float(preds[i]) for i in range(len(classes))}
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    f_confidences = dict(ranked[:top_preds])

    visualization = None
    if show_grad_cam:
        layer_name = {1: 'layer1', 2: 'layer2'}.get(grad_cam_layer, 'layer3')
        cam = GradCAM(model=model, target_layers=[getattr(model, layer_name)[-1]], use_cuda=False)
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

        # Recover the displayable original image to mix the activations onto.
        rgb_img = _denormalize_to_hwc(input_tensor.squeeze(0))
        # Rescale so the brightest value is 1 (presumably show_cam_on_image wants [0, 1]);
        # skip when the max is not positive to avoid dividing by zero or flipping signs.
        max_value = np.max(rgb_img)
        if max_value > 0:
            rgb_img = rgb_img / max_value
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=opacity)

    # Only denormalize the samples that will actually be displayed.
    grid = []
    for sample, true_label, pred_label in islice(loaded_pkl, n_misclassified):
        rgb_sample = _denormalize_to_hwc(sample.squeeze())
        grid.append((rgb_sample, f"predicted_label: {classes[pred_label]}, true_label: {classes[true_label]}"))
    return f_confidences, visualization, grid
# Gradio UI; the order of `inputs` must match inference()'s positional parameters.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type='pil', shape=(32, 32), label='Image'),
        # CheckboxGroup yields a list of selected options; an empty list disables GradCAM.
        # Pre-selected to mirror inference(show_grad_cam=True).
        gr.CheckboxGroup(['True'], value=['True'], label='Show GradCAM'),
        # Forwarded to show_cam_on_image as image_weight; default mirrors inference(opacity=0.7).
        gr.Slider(0, 1, value=0.7, label='Opacity'),
        gr.Slider(1, 3, value=3, step=1, label='GradCAM Layer'),
        gr.Slider(1, 10, value=10, step=1, label='Top Predictions'),
        gr.Slider(1, 10, value=10, step=1, label='Misclassified Images'),
    ],
    outputs=[
        gr.Label(num_top_classes=len(classes), label='Classification Outputs'),
        gr.Image(type='pil', label='GradCAM', shape=(32, 32)).style(width=500, height=400),
        gr.Gallery(shape=(32, 32)).style(columns=2, rows=5),
    ],
    examples=[
        ['truck.jpg'], ['ship.jpg'], ['car.jpg'], ['horse.jpg'], ['frog.jpeg'],
        ['dog.jpg'], ['deer.jpg'], ['cat.jpg'], ['bird.jpg'], ['airplane.jpg'],
    ],
)
demo.launch()