# Spaces status banner captured with this file: Sleeping
| # Imports | |
| import os | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| import torch, torchvision | |
| from torchvision import transforms | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| import gradio as gr | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| import config | |
| import utils | |
| import config | |
| from torchvision import transforms | |
| import torch.optim as optim | |
# Anchors from config.ANCHORS multiplied by the matching grid size from
# config.S. The grid sizes are expanded to one value per (anchor, coordinate)
# before the element-wise product; the result lives on the CPU.
scaled_anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S)[:, None, None].repeat(1, 3, 2)
).to('cpu')

# Pre-processing applied to every image before it is fed to the model:
# resize so the longest side equals config.IMAGE_SIZE, pad up to a
# config.IMAGE_SIZE square, normalise with mean 0 / std 1 over a 0-255 pixel
# range, then convert to a tensor.
test_transforms_exp = A.Compose([
    A.LongestMaxSize(max_size=config.IMAGE_SIZE),
    A.PadIfNeeded(min_height=config.IMAGE_SIZE, min_width=config.IMAGE_SIZE),
    A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
    ToTensorV2(),
])

# Number of classes the network head is built for.
num_classes = 20
# NOTE(review): the pipeline above reads config.IMAGE_SIZE, not this constant
# -- confirm the two agree.
IMAGE_SIZE = 416

from model import YOLOv3
import config
import torch

# Build the network, load the checkpoint onto the CPU and switch to eval mode.
model = YOLOv3(num_classes=num_classes)
model.load_state_dict(
    torch.load(r"./model_40.pth", map_location=torch.device('cpu'))
)
model.eval()

classes = config.PASCAL_CLASSES
def plot_image(image, boxes):
    """Draw predicted bounding boxes on an image and return the rendering.

    Args:
        image: Image data that ``np.array`` turns into a 3-D array laid out
            as (height, width, channels).
        boxes: Iterable of 6-element sequences
            ``[class_pred, confidence, x, y, width, height]``. ``x``/``y`` are
            the box midpoint and all four coordinates are multiplied by the
            image width/height, i.e. they are treated as fractions of the
            image size.

    Returns:
        A writable ``uint8`` array of shape (canvas_height, canvas_width, 3)
        containing the RGB pixels of the rendered Matplotlib figure.

    Raises:
        AssertionError: If a box does not have exactly 6 elements.
    """
    cmap = plt.get_cmap("tab20b")
    class_labels = (
        config.COCO_LABELS if config.DATASET == 'COCO' else config.PASCAL_CLASSES
    )
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(im)

        for box in boxes:
            assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
            class_idx = int(box[0])
            x_mid, y_mid, box_w, box_h = box[2:]

            # Convert midpoint form to the upper-left corner in pixels.
            left = (x_mid - box_w / 2) * width
            top = (y_mid - box_h / 2) * height

            ax.add_patch(
                patches.Rectangle(
                    (left, top),
                    box_w * width,
                    box_h * height,
                    linewidth=2,
                    edgecolor=colors[class_idx],
                    facecolor="none",
                )
            )
            # Draw on this figure's axes explicitly instead of relying on
            # pyplot's global "current axes".
            ax.text(
                left,
                top,
                s=class_labels[class_idx],
                color="white",
                verticalalignment="top",
                bbox={"color": colors[class_idx], "pad": 0},
            )

        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb(). Drop the
        # alpha channel and copy, because the buffer belongs to the canvas
        # that is closed below.
        rendered = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Without this every call leaves an open figure behind.
        plt.close(fig)
    return rendered
class outCallBack:
    """Callable target that reduces a prediction tensor to one scalar score.

    The last axis of the tensor is read as one prediction vector per cell:
    element 0 is used to choose the cell (the one with the largest value) and
    elements 5 onwards of that cell are the candidate scores.
    """

    def __call__(self, out):
        """Return the largest ``[5:]`` entry of the cell whose element 0 is maximal.

        Args:
            out: Tensor whose last axis holds at least 6 values per cell.

        Returns:
            A 0-dim tensor; it stays attached to ``out``'s autograd graph.
        """
        # Flatten every leading axis so each row is one cell. This is done in
        # torch only, so no tensor -> NumPy conversion is needed and the
        # tensor may live on any device.
        cells = out.reshape(-1, out.shape[-1])
        best_cell = torch.argmax(cells[:, 0])
        return cells[best_cell, 5:].max()
def inference(input_img, transparency=0.5, iou=0.5, threshold=0.5):
    """Run detection and Grad-CAM on one image.

    Args:
        input_img: Image array with 0-255 pixel values, as accepted by
            ``test_transforms_exp(image=...)``.
        transparency: Passed to ``show_cam_on_image`` as ``image_weight``.
        iou: Passed to ``utils.non_max_suppression`` as ``iou_threshold``.
        threshold: Passed to ``utils.non_max_suppression`` as ``threshold``.

    Returns:
        ``(out_fig, cam_image)``: the rendering produced by ``plot_image``
        for the boxes kept by non-max suppression, and the Grad-CAM overlay
        produced by ``show_cam_on_image``.

    Raises:
        ValueError: If ``input_img`` is ``None``.
    """
    if input_img is None:
        raise ValueError("inference() needs an input image, got None")

    trans_img = test_transforms_exp(image=input_img)["image"].unsqueeze(0)

    with torch.no_grad():
        out = model(trans_img)

    # Collect the boxes predicted at every scale, per image in the batch.
    bboxes = [[] for _ in range(trans_img.shape[0])]
    for scale_pred, anchor in zip(out, scaled_anchors):
        grid_size = scale_pred.shape[2]
        boxes_scale_i = utils.cells_to_bboxes(
            scale_pred, anchor, S=grid_size, is_preds=True
        )
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box

    nms_boxes = utils.non_max_suppression(
        bboxes[0], iou_threshold=iou, threshold=threshold, box_format="midpoint",
    )

    # (height, width, channels) float view of exactly what the model saw. It
    # is used for both outputs, so the Grad-CAM map always matches the image
    # it is blended onto, whatever size input_img had.
    model_view = trans_img[0].permute(1, 2, 0).detach().cpu().numpy()
    out_fig = plot_image(model_view, nms_boxes)

    # `use_cuda` is not passed: older pytorch-grad-cam releases default it to
    # False and newer ones no longer accept the argument.
    cam = GradCAM(model=model, target_layers=[model.layers[12].conv])
    grayscale_cam = cam(trans_img, targets=[outCallBack()])[0, :]
    cam_image = show_cam_on_image(
        model_view, grayscale_cam, use_rgb=True, image_weight=transparency
    )
    return out_fig, cam_image
# Text shown at the top of the Gradio page.
title = "YOLO V3 trained on PASCAL VOC Dataset"
description = "Gradio interface to show yoloV3 object detection and gradcam on outputs."

# One example row per entry found in the local "examples" directory.
examples = [[f'examples/{name}'] for name in os.listdir("examples")]

# Wire `inference` to an image input plus three 0-1 sliders (Grad-CAM
# opacity, IOU value, threshold value) and two image outputs.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(shape=(416, 416), label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
        gr.Slider(0, 1, value=0.5, label="IOU Value"),
        gr.Slider(0, 1, value=0.5, label="Threshold Value"),
    ],
    outputs=[
        gr.Image(label="YoloV3 Output", shape=(416, 416)),
        gr.Image(label="GradCam Output", shape=(416, 416)),
    ],
    title=title,
    description=description,
    examples=examples,
)

# Start the web server as soon as the module is executed.
demo.launch()