Spaces:
Sleeping
Sleeping
| import torch, torchvision | |
| from torchvision import transforms | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| import gradio as gr | |
| import os | |
| from helper import CifarAlbumentations, get_train_transforms, get_test_transforms | |
| from resnet import CustomResNet | |
# Hyper-parameters and dataset settings; passed to CustomResNet when the model is built.
config = dict(
    batch_size=512,
    data_dir='./data',
    classes=[
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ],
    num_classes=10,
    lr=0.01,
    max_lr=0.1,
    max_lr_epoch=5,
    dropout=0.01,
    LEARNING_RATE=1e-5,
    WEIGHT_DECAY=1e-4,
    NUM_EPOCHS=100,
)
# Build the network and load the trained weights onto the CPU.
train_transforms = get_train_transforms()
test_transforms = get_test_transforms()
model = CustomResNet(config, config['dropout'], train_transforms, test_transforms)
_load_result = model.load_state_dict(
    torch.load("resnet_model_v7.pth", map_location=torch.device('cpu')),
    strict=False,
)
# strict=False silently tolerates mismatched parameter names; report them so a
# checkpoint that does not fit the architecture is not mistaken for a trained model.
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        f"Warning: checkpoint mismatch - missing keys: {_load_result.missing_keys}, "
        f"unexpected keys: {_load_result.unexpected_keys}"
    )
model.setup(stage="test")
# nn.Module instances start in training mode and nothing above changes that.
# eval() disables dropout and makes BatchNorm layers (if CustomResNet has any)
# use their stored running statistics instead of per-batch statistics.
model.eval()
# Undo a per-channel Normalize(mean=0.50, std=0.23):
# (x - (-0.50/0.23)) / (1/0.23) == x * 0.23 + 0.50 for each of the 3 channels.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23] * 3,
    std=[1 / 0.23] * 3,
)
# Short CIFAR-10 label names; position i labels output logit i in inference().
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)
# Index -> label mapping derived from the tuple above (0: 'plane', ..., 9: 'truck').
classes_for_categorize = dict(enumerate(classes))
def inference(input_img, transparency=0.5, target_layer_number=-1, top_classes=10):
    """Classify one image and overlay a GradCAM heat-map on it.

    Args:
        input_img: Image from the Gradio ``gr.Image`` input. Presumably an
            H x W x 3 uint8 numpy array (the overlay divides it by 255) -
            TODO confirm for the installed Gradio version.
        transparency: Weight of the original image in the overlay; passed to
            ``show_cam_on_image`` as ``image_weight``.
        target_layer_number: Accepted so the UI slider can be wired to this
            function, but currently unused: GradCAM always targets ``model.conv2``.
        top_classes: How many of the most probable classes to return. Converted
            to ``int`` because UI sliders may deliver floats, which cannot be
            used as slice indices.

    Returns:
        ``(confidences, visualization)``:
        - ``confidences`` maps class name to softmax probability for the
          ``top_classes`` most likely classes, highest first.
        - ``visualization`` is the CAM overlay from ``show_cam_on_image``.
    """
    # Dropout / BatchNorm must not run in training mode while serving requests,
    # regardless of whether module-level setup already switched the model.
    model.eval()

    num_top = int(top_classes)

    # ToTensor converts HWC [0, 255] -> CHW float [0, 1]; unsqueeze adds a batch dim.
    # NOTE(review): no Normalize is applied here, although inv_normalize implies
    # training used mean=0.50 / std=0.23 - confirm against get_test_transforms().
    input_tensor = transforms.ToTensor()(input_img).unsqueeze(0)

    # The class probabilities need no gradients; GradCAM does its own
    # forward/backward pass below.
    with torch.no_grad():
        logits = model(input_tensor)
    probs = torch.softmax(logits, dim=1)[0]

    confidences = {name: float(p) for name, p in zip(classes, probs)}
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    top_classes_dict = dict(ranked[:num_top])

    # NOTE(review): target_layer_number is ignored; map it to a list of layers
    # once the CustomResNet layer names are confirmed.
    cam = GradCAM(model=model, target_layers=[model.conv2], use_cuda=False)
    # targets=None: per the pytorch_grad_cam docs, the highest-scoring class is explained.
    grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]

    # show_cam_on_image expects a float RGB image scaled to [0, 1].
    visualization = show_cam_on_image(
        input_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency
    )
    return top_classes_dict, visualization
def show_misclassified_images_wrap(num_images=10, use_gradcam=False, gradcam_layer=-1, transparency=0.5):
    """Coerce the UI inputs and delegate to ``model.show_misclassified_images``.

    Args:
        num_images: How many misclassified samples to show. Coerced to ``int``
            because ``gr.Number`` may deliver a float.
        use_gradcam: ``"Yes"`` from the UI radio button, or a bool when called
            directly. Any other value disables GradCAM.
        gradcam_layer: GradCAM layer selector from the slider. Coerced to
            ``int`` for the same reason as ``num_images``.
        transparency: CAM opacity. Coerced to ``float``.

    Returns:
        Whatever ``model.show_misclassified_images`` returns. The interface
        built in this file renders it through a ``gr.Plot`` output.
    """
    # The default is a bool, so honour True as well as the radio's "Yes".
    # Previously, a direct call with use_gradcam=True silently disabled GradCAM.
    enable_gradcam = use_gradcam is True or use_gradcam == "Yes"
    return model.show_misclassified_images(
        int(num_images), enable_gradcam, int(gradcam_layer), float(transparency)
    )
title = "CIFAR10 Image Classification"
description = "Upload an Image or Choose from Examples Below"
images_folder = "examples"
# Example rows: [image path, transparency, GradCAM layer, number of top classes],
# i.e. the parameter order of inference().
# NOTE(review): this list does not appear to be passed to any interface; the
# input interface builds its own single-column examples.
examples = [
    [os.path.join(images_folder, file_name), 0.5, -1, top_k]
    for file_name, top_k in (
        ("plane.jpg", 10),
        ("car.jpg", 5),
        ("bird.jpg", 3),
        ("cat.jpg", 5),
        ("deer.jpg", 7),
        ("dog.jpg", 6),
        ("frog.jpg", 2),
        ("horse.jpg", 10),
        ("ship.jpg", 10),
        ("truck.jpeg", 10),
    )
]
# Single-image tab: upload an image (or pick an example) and get the top-k
# class probabilities plus a GradCAM overlay. The four inputs map positionally
# onto inference(input_img, transparency, target_layer_number, top_classes).
def _build_input_interface():
    """Assemble the Gradio interface wired to ``inference``."""
    image_in = gr.Image(shape=(32, 32), label="Input Image")
    opacity = gr.Slider(0, 1, value=0.5, label="Transparency", info="Set the Opacity of CAM")
    layer = gr.Slider(-2, -1, value=-2, step=1, label="Network Layer", info="GradCAM Network Layer")
    top_k = gr.Slider(1, 10, step=1, value=10, label="Top Classes", info="How many top classes do you want to view")
    label_out = gr.Label(num_top_classes=10)
    cam_out = gr.Image(shape=(32, 32), label="Model Prediction").style(width=300, height=300)
    # One single-column example per class label: examples/<label>.jpg
    sample_rows = [[f'examples/{name}.jpg'] for name in classes_for_categorize.values()]
    return gr.Interface(
        inference,
        inputs=[image_in, opacity, layer, top_k],
        outputs=[label_out, cam_out],
        description=description,
        examples=sample_rows,
    )


input_interface = _build_input_interface()
mislclassified_description = "Misclassified Image for Custom Resnet"
icon_html = '<i class="fas fa-chart-bar"></i>'
# HTML banner used as the tabbed app's title: a chart icon followed by the project name.
title_with_icon = "\n".join([
    "",
    '<div style="background-color: #f1f4f0; padding: 10px; display: flex; align-items: center;">',
    icon_html + ' <span style="margin-left: 10px;">Custom Resnet on CIFAR10 using PyTorch Lightning and GradCAM</span>',
    "</div>",
    "",
])
# Second tab: browse misclassified samples, optionally with GradCAM overlays.
# The four inputs map positionally onto show_misclassified_images_wrap's parameters.
def _build_misclassified_interface():
    """Assemble the Gradio interface wired to ``show_misclassified_images_wrap``."""
    count_in = gr.Number(value=10, label="Misclassified Inputs", info="Set the Number of Misclassifed Outputs to be Shown")
    gradcam_toggle = gr.Radio(["Yes", "No"], value="No", label="Enable GradCAM", info="Do you want to see GradCAM")
    layer = gr.Slider(-2, -1, value=-1, step=1, label="Network Layer", info="GradCAM Network Layer")
    opacity = gr.Slider(0, 1, value=0.5, label="Transparency", info="Set the Opacity of CAM")
    plot_out = gr.Plot()
    return gr.Interface(
        show_misclassified_images_wrap,
        inputs=[count_in, gradcam_toggle, layer, opacity],
        outputs=plot_out,
        description=mislclassified_description,
    )


misclassified_interface = _build_misclassified_interface()
# Combine both interfaces into one tabbed app.
demo = gr.TabbedInterface(
    [input_interface, misclassified_interface],
    tab_names=["Top Classes and Prediction", "Misclassified Images"],
    title=title_with_icon,
)

# Only start the server when executed as a script, so importing this module
# (tests, reload tooling that looks up `demo`) does not block inside launch().
if __name__ == "__main__":
    demo.launch()