# cifar10_resnet / app.py
# Last commit: 25cc9d6 "Update app.py" (by rohithb)
#app.py
import itertools
import math
import pickle

import gradio as gr
import numpy as np
import torch
from PIL import Image
from pytorch_lightning import LightningModule

from helper_fn import *
# Load the trained checkpoint on the CPU and switch it to inference mode so any
# dropout / batch-norm layers do not run with training behaviour on each request
# (harmless if get_model already did this).
model = get_model('cifar10_resnet_epochs30.ckpt', map_location='cpu')
model.eval()

# Module-level Grad-CAM defaults. inference() chooses its own target layer per
# request; `targets` is passed straight to GradCAM (None — presumably "explain the
# top-scoring class", see pytorch-grad-cam docs).
target_layers = [model.layer3[-1]]
targets = None

# Load the (image, true_label, pred_label) triples shown in the misclassified gallery.
# NOTE: pickle can execute arbitrary code on load; this file ships with the app and
# must never be replaced by user-supplied data.
with open('required_data.pkl', 'rb') as f:
    loaded_pkl = pickle.load(f)
def _select_target_layers(grad_cam_layer):
    """Return the Grad-CAM target-layer list for the requested ResNet stage.

    1 -> last block of ``model.layer1``, 2 -> last block of ``model.layer2``;
    any other value (the UI offers 1-3) falls back to ``model.layer3``.
    """
    stage = {1: model.layer1, 2: model.layer2}.get(grad_cam_layer, model.layer3)
    return [stage[-1]]


def _to_hwc_numpy(chw):
    """Un-normalise a (C, H, W) image tensor and return it as an (H, W, C) numpy array."""
    chw = inv_normalize(chw.detach().to('cpu'))
    return chw.permute(1, 2, 0).numpy()


def _gradcam_overlay(batch, grad_cam_layer, opacity):
    """Blend a Grad-CAM heatmap for the single-image ``batch`` onto the de-normalised image.

    ``opacity`` is passed through as ``show_cam_on_image``'s ``image_weight``.
    """
    cam = GradCAM(model=model, target_layers=_select_target_layers(grad_cam_layer), use_cuda=False)
    # The batch holds exactly one image; keep its 2-D activation map.
    grayscale_cam = cam(input_tensor=batch, targets=targets)[0, :]
    rgb_img = _to_hwc_numpy(batch.squeeze(0))
    # Rescale so the brightest pixel is 1.0 (show_cam_on_image rejects values > 1);
    # skip for an all-black image instead of dividing by zero.
    max_value = np.max(rgb_img)
    if max_value > 0:
        rgb_img = rgb_img / max_value
    return show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=opacity)


def _misclassified_gallery(count):
    """Return up to ``count`` (image, caption) pairs from ``loaded_pkl``; [] when count <= 0."""
    count = int(count)
    if count <= 0:
        return []
    # islice converts only the samples that will actually be displayed.
    return [
        (_to_hwc_numpy(sample.squeeze()),
         f"predicted_label: {classes[pred_label]}, true_label: {classes[true_label]}")
        for sample, true_label, pred_label in itertools.islice(loaded_pkl, count)
    ]


def inference(image, show_grad_cam=True, opacity=0.7, grad_cam_layer=3, top_preds=10, show_misclassified_imgs=5):
    """Classify ``image`` and build the Grad-CAM view and the misclassified-sample gallery.

    Args:
        image: input image (PIL, per the Gradio input), or None if nothing was uploaded.
        show_grad_cam: truthy to compute the Grad-CAM overlay.
        opacity: image weight for the Grad-CAM blend.
        grad_cam_layer: ResNet stage (1, 2 or 3) whose last block Grad-CAM inspects.
        top_preds: number of highest-probability classes to return.
        show_misclassified_imgs: number of misclassified samples to return.

    Returns:
        (confidences, visualization, gallery): a dict of class name -> probability
        for the top ``top_preds`` classes; the overlay array or None; a list of
        (image, caption) pairs.
    """
    if image is None:
        return {}, None, []

    batch = img_test_transforms(image).unsqueeze(0)
    with torch.no_grad():
        scores = model(batch).squeeze()
    # gr.Label expects probabilities. softmax gives the same result whether the
    # model emits logits or log-probabilities.
    # NOTE(review): if get_model's output were already softmaxed this would
    # flatten it — confirm in helper_fn.
    probs = torch.softmax(scores, dim=-1)
    confidences = {label: float(p) for label, p in zip(classes, probs)}
    ranked = sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)
    f_confidences = dict(ranked[:int(top_preds)])

    visualization = _gradcam_overlay(batch, grad_cam_layer, opacity) if show_grad_cam else None
    return f_confidences, visualization, _misclassified_gallery(show_misclassified_imgs)
# Gradio UI. The order of `inputs` must match inference()'s positional parameters:
# image, show_grad_cam, opacity, grad_cam_layer, top_preds, show_misclassified_imgs.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type='pil', shape=(32, 32), label='Image'),
        # A real boolean toggle, on by default to match inference(show_grad_cam=True).
        gr.Checkbox(value=True, label='Show GradCAM'),
        # Default matches inference(opacity=0.7); the previous unset slider started at 0.
        gr.Slider(0, 1, value=0.7, label='GradCAM image weight (opacity)'),
        gr.Slider(1, 3, value=3, step=1, label='GradCAM layer'),
        gr.Slider(1, 10, value=10, step=1, label='Top predictions to show'),
        gr.Slider(1, 10, value=10, step=1, label='Misclassified images to show'),
    ],
    outputs=[
        gr.Label(num_top_classes=len(classes), label='Classification Outputs'),
        gr.Image(type='pil', label='GradCAM', shape=(32, 32)).style(width=500, height=400),
        gr.Gallery(shape=(32, 32)).style(columns=2, rows=5),
    ],
    # Examples only fill the image input; the other controls keep their current values.
    examples=[
        ['truck.jpg'], ['ship.jpg'], ['car.jpg'], ['horse.jpg'], ['frog.jpeg'],
        ['dog.jpg'], ['deer.jpg'], ['cat.jpg'], ['bird.jpg'], ['airplane.jpg'],
    ],
)
demo.launch()