# --- Imports (deduplicated: gradio/config/torch were imported repeatedly) ----
import os

import albumentations as A
import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

import config
import utils
from model import YOLOv3

# Anchors rescaled to each of the three prediction grid sizes in config.S;
# resulting shape is (3 scales, 3 anchors, 2).
scaled_anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to("cpu")

# Inference-time preprocessing: letterbox-resize/pad to IMAGE_SIZE and scale
# pixel values to [0, 1] (mean 0, std 1, max_pixel_value 255).
test_transforms_exp = A.Compose(
    [
        A.LongestMaxSize(max_size=config.IMAGE_SIZE),
        A.PadIfNeeded(min_height=config.IMAGE_SIZE, min_width=config.IMAGE_SIZE),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
        ToTensorV2(),
    ]
)

num_classes = 20
IMAGE_SIZE = 416

# Load trained weights on CPU and switch to eval mode for inference.
model = YOLOv3(num_classes=num_classes)
model.load_state_dict(torch.load(r"./model_40.pth", map_location=torch.device("cpu")))
model.eval()

classes = config.PASCAL_CLASSES


def plot_image(image, boxes):
    """Draw predicted bounding boxes on ``image`` and return an RGB array.

    Args:
        image: HxWxC array-like image (values in [0, 1] or uint8).
        boxes: iterable of ``[class_pred, confidence, x, y, w, h]`` with
            midpoint-format coordinates normalized to [0, 1].

    Returns:
        np.ndarray: the rendered figure as an (H, W, 3) uint8 array.
    """
    cmap = plt.get_cmap("tab20b")
    class_labels = (
        config.COCO_LABELS if config.DATASET == "COCO" else config.PASCAL_CLASSES
    )
    # One distinct color per class, spread across the colormap.
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    # Create figure and axes, then display the image.
    fig, ax = plt.subplots(1)
    ax.imshow(im)

    # box[0] is x midpoint, box[2] is width
    # box[1] is y midpoint, box[3] is height
    for box in boxes:
        assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
        class_pred = box[0]
        box = box[2:]
        # Convert midpoint format to the upper-left corner patches.Rectangle needs.
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle(
            (upper_left_x * width, upper_left_y * height),
            box[2] * width,
            box[3] * height,
            linewidth=2,
            edgecolor=colors[int(class_pred)],
            facecolor="none",
        )
        ax.add_patch(rect)
        plt.text(
            upper_left_x * width,
            upper_left_y * height,
            s=class_labels[int(class_pred)],
            color="white",
            verticalalignment="top",
            bbox={"color": colors[int(class_pred)], "pad": 0},
        )

    # Render the canvas into a numpy array instead of showing the figure.
    # NOTE(review): tostring_rgb is deprecated in newer matplotlib; buffer_rgba
    # (4 channels) is the forward-compatible replacement — confirm pinned version.
    fig.canvas.draw()
    image_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype="uint8")
    image_array = image_array.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)  # fix: release the figure; otherwise each request leaks one
    return image_array


class outCallBack:
    """Grad-CAM target: best class score at the most confident grid cell.

    ``out`` is one YOLO head output whose channel 0 is objectness and whose
    channels 5: are class scores. The call finds the cell with maximal
    objectness and returns the maximum class score there, which Grad-CAM
    backpropagates to produce the heatmap.
    """

    def __call__(self, out):
        # Multi-dimensional index (objectness channel included) of the
        # highest-objectness cell; drop the trailing channel axis when indexing.
        in_shape = np.unravel_index(torch.argmax(out[..., :1]), out[..., :1].shape)
        arg_max = out[in_shape[:-1]][5:].max()
        return arg_max


def inference(input_img, transparency=0.5, iou=0.5, threshold=0.5):
    """Run YOLOv3 on ``input_img``; return (detection plot, Grad-CAM overlay).

    Args:
        input_img: HxWx3 uint8 RGB image supplied by Gradio.
        transparency: image weight used when blending the Grad-CAM heatmap.
        iou: IoU threshold for non-max suppression.
        threshold: confidence threshold for keeping predicted boxes.

    Returns:
        tuple[np.ndarray, np.ndarray]: annotated detection image and CAM image.
    """
    transform = test_transforms_exp(image=input_img)
    trans_img = transform["image"].unsqueeze(0)

    # Plain forward pass needs no gradients; Grad-CAM below does, so it runs
    # outside this context.
    with torch.no_grad():
        out = model(trans_img)

    bboxes = [[] for _ in range(trans_img.shape[0])]
    for i in range(3):
        # fix: the anchor-count dim previously unpacked into ``A``, shadowing
        # the module-level albumentations alias.
        _, _, S, _, _ = out[i].shape
        anchor = scaled_anchors[i]
        boxes_scale_i = utils.cells_to_bboxes(out[i], anchor, S=S, is_preds=True)
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box

    nms_boxes = utils.non_max_suppression(
        bboxes[0],
        iou_threshold=iou,
        threshold=threshold,
        box_format="midpoint",
    )
    out_fig = plot_image(trans_img.squeeze().permute(1, 2, 0).detach().cpu(), nms_boxes)

    # Grad-CAM on the conv layer of block 12, blended over the raw input.
    cam = GradCAM(model, [model.layers[12].conv], use_cuda=False)
    grayscale_cam = cam(trans_img, targets=[outCallBack()])[0, :]
    cam_image = show_cam_on_image(
        input_img.astype(np.float32) / 255,
        grayscale_cam,
        use_rgb=True,
        image_weight=transparency,
    )
    return out_fig, cam_image


title = "YOLO V3 trained on PASCAL VOC Dataset"
description = "Gradio interface to show yoloV3 object detection and gradcam on outputs."
# Every file in ./examples becomes a one-element example row for the interface.
examples = [[f"examples/{filename}"] for filename in os.listdir("examples")]

# Wire the inference function to an image input plus three sliders
# (GradCAM opacity, NMS IoU threshold, confidence threshold).
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(shape=(416, 416), label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
        gr.Slider(0, 1, value=0.5, label="IOU Value"),
        gr.Slider(0, 1, value=0.5, label="Threshold Value"),
    ],
    outputs=[
        gr.Image(label="YoloV3 Output", shape=(416, 416)),
        gr.Image(label="GradCam Output", shape=(416, 416)),
    ],
    title=title,
    description=description,
    examples=examples,
)

demo.launch()