Spaces:
Sleeping
Sleeping
| ''' | |
| Efficientdet demo | |
| ''' | |
| import argparse | |
| import cv2 | |
| import os | |
| import time | |
| from PIL import Image | |
| import PIL.ImageColor as ImageColor | |
| import requests | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torchvision.transforms as T | |
| from tqdm import tqdm | |
| from effdet import create_model | |
def get_args_parser():
    """Build the CLI argument parser for the EfficientDet demo.

    Returns:
        argparse.ArgumentParser: parser with image/video source, checkpoint,
        device, class names, confidence threshold and output options.
    """
    # BUG FIX: description said 'detr' although this demo runs EfficientDet.
    parser = argparse.ArgumentParser(
        'Test efficientdet on one image')
    parser.add_argument(
        '--img', metavar='IMG',
        help='path to image, could be url',
        default='https://www.fyidenmark.com/images/denmark-litter.jpg')
    parser.add_argument(
        '--save', metavar='OUTPUT',
        help='path to save image with predictions (if None show image)',
        default=None)
    parser.add_argument('--classes', nargs='+', default=['Litter'])
    parser.add_argument(
        '--checkpoint', type=str,
        help='path to checkpoint')
    parser.add_argument(
        '--device', type=str, default='cpu',
        help='device to evaluate model (default: cpu)')
    # BUG FIX: help text claimed the default was 0.5 while the actual
    # default is 0.3 — the help now matches the code.
    parser.add_argument(
        '--prob_threshold', type=float, default=0.3,
        help='probability threshold to show results (default: 0.3)')
    # BUG FIX: typo 'impute' -> 'input' in the help text.
    parser.add_argument(
        '--video', action='store_true', default=False,
        help="If true, we treat input as video (default: False)")
    parser.set_defaults(redundant_bias=None)
    return parser
# standard PyTorch mean-std input image normalization
def get_transforms(im, size=768):
    """Resize *im* to (size, size), convert to tensor and normalize with
    ImageNet statistics; returns a batch of one (shape [1, C, size, size])."""
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    pipeline = T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),
        normalize,
    ])
    batch = pipeline(im)
    return batch.unsqueeze(0)
def rescale_bboxes(out_bbox, size, resize):
    """Map [x0, y0, x1, y1] boxes from model input space back to the
    original image space.

    Args:
        out_bbox: tensor of boxes predicted at the network input resolution.
        size: (width, height) of the original image.
        resize: (width, height) the network input was resized to.

    Returns:
        float32 tensor of boxes scaled to original-image coordinates,
        on the same device as *out_bbox*.
    """
    orig_w, orig_h = size
    net_w, net_h = resize
    sx = orig_w / net_w
    sy = orig_h / net_h
    scale = torch.tensor([sx, sy, sx, sy],
                         dtype=torch.float32,
                         device=out_bbox.device)
    return out_bbox * scale
# from https://deepdrive.pl/
def get_output(img, prob, boxes, classes=None, stat_text=None):
    """Draw detection boxes and labels onto an OpenCV (BGR) image.

    Args:
        img: numpy image; drawn on in place by cv2 calls.
        prob: iterable of per-detection rows where p[0] is the confidence
            score and p[1] the class id (appears to be 1-based, see the -1
            below).
        boxes: tensor of [x0, y0, x1, y1] boxes in image coordinates.
        classes: list of class names; defaults to ['Litter'].
        stat_text: optional status line drawn in the top-left corner.

    Returns:
        The annotated image (same object as *img*).
    """
    # BUG FIX: mutable default argument (classes=['Litter']) replaced with
    # the None-sentinel idiom; runtime behavior is unchanged.
    if classes is None:
        classes = ['Litter']
    # colors for visualization
    STANDARD_COLORS = [
        'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige',
        'Bisque', 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue',
        'AntiqueWhite', 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk',
        'Crimson', 'Cyan', 'DarkCyan', 'DarkGoldenRod', 'DarkGrey',
        'DarkKhaki', 'DarkOrange', 'DarkOrchid', 'DarkSalmon',
        'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
        'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
        'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold',
        'GoldenRod', 'Salmon', 'Tan', 'HoneyDew', 'HotPink',
        'IndianRed', 'Ivory', 'Khaki', 'Lavender', 'LavenderBlush',
        'LawnGreen', 'LemonChiffon', 'LightBlue',
        'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray',
        'LightGrey', 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen',
        'LightSkyBlue', 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue',
        'LightYellow', 'Lime', 'LimeGreen', 'Linen', 'Magenta',
        'MediumAquaMarine', 'MediumOrchid', 'MediumPurple',
        'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
        'MediumTurquoise', 'MediumVioletRed', 'MintCream',
        'MistyRose', 'Moccasin', 'NavajoWhite', 'OldLace', 'Olive',
        'OliveDrab', 'Orange', 'OrangeRed', 'Orchid', 'PaleGoldenRod',
        'PaleGreen', 'PaleTurquoise', 'PaleVioletRed', 'PapayaWhip',
        'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
        'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
        'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
        'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue',
        'GreenYellow', 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet',
        'Wheat', 'White', 'WhiteSmoke', 'Yellow', 'YellowGreen'
    ]
    palette = [ImageColor.getrgb(name) for name in STANDARD_COLORS]
    for p, (x0, y0, x1, y1) in zip(prob, boxes.tolist()):
        # class ids look 1-based here, hence the -1 to index `classes`
        cl = int(p[1] - 1)
        color = palette[cl]
        start_p, end_p = (int(x0), int(y0)), (int(x1), int(y1))
        cv2.rectangle(img, start_p, end_p, color, 2)
        # Thick black text first, colored text on top: keeps labels
        # readable on any background.
        text = "%s %.1f%%" % (classes[cl], p[0]*100)
        cv2.putText(img, text, start_p, cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (0, 0, 0), 10)
        cv2.putText(img, text, start_p, cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
    if stat_text is not None:
        cv2.putText(img, stat_text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (0, 0, 0), 10)
        cv2.putText(img, stat_text, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (255, 255, 255), 3)
    return img
# from https://deepdrive.pl/
def save_frames(args, num_iter=45913):
    """Run detection on every frame of a video and save annotated frames.

    Reads frames from args.img (a video path/URL openable by cv2),
    runs the EfficientDet model on each frame, and writes annotated
    frames as JPEGs into args.save.

    Args:
        args: parsed CLI namespace (img, save, classes, checkpoint,
            device, prob_threshold).
        num_iter: expected frame count, used only to size the progress bar.
    """
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    cap = cv2.VideoCapture(args.img)
    counter = 0
    pbar = tqdm(total=num_iter + 1)
    num_classes = len(args.classes)
    # checkpoint path is expected to encode the model name, e.g.
    # '...-tf_efficientdet_d2/...' -- TODO confirm against checkpoint naming
    model_name = args.checkpoint.split('-')[-1].split('/')[0]
    model = set_model(model_name, num_classes, args.checkpoint, args.device)
    model.eval()
    model.to(args.device)
    inference_size = (768, 768)
    # Hoisted out of the frame loop: the pipeline is loop-invariant.
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # BUG FIX: the original called torch.cuda.get_device_name(0)
    # unconditionally, which raises when running on CPU (the default device).
    device_name = (torch.cuda.get_device_name(0)
                   if torch.cuda.is_available() else str(args.device))
    while cap.isOpened():
        ret, img = cap.read()
        if img is None:
            print("END")
            break
        # scale + BGR to RGB
        scaled_img = cv2.resize(img[:, :, ::-1], inference_size)
        # mean-std normalize the input image (batch-size: 1)
        img_tens = transform(scaled_img).unsqueeze(0).to(args.device)
        # Inference
        t0 = time.time()
        with torch.no_grad():
            # propagate through the model
            output = model(img_tens)
        t1 = time.time()
        # keep only predictions above set confidence
        bboxes_keep = output[0, output[0, :, 4] > args.prob_threshold]
        probas = bboxes_keep[:, 4:]
        # convert boxes to image scales
        bboxes_scaled = rescale_bboxes(bboxes_keep[:, :4],
                                       (img.shape[1], img.shape[0]),
                                       inference_size)
        txt = "Detect-waste %s Threshold=%.2f " \
              "Inference %dx%d GPU: %s Inference time %.3fs" % \
              (model_name, args.prob_threshold, inference_size[0],
               inference_size[1], device_name, t1 - t0)
        result = get_output(img, probas, bboxes_scaled,
                            args.classes, txt)
        cv2.imwrite(os.path.join(args.save, 'img%08d.jpg' % counter), result)
        counter += 1
        pbar.update(1)
        # release the large per-frame buffers before the next iteration
        del img, img_tens, result
    # BUG FIX: close the progress bar instead of leaking it.
    pbar.close()
    cap.release()
def plot_results(pil_img, prob, boxes, classes=None,
                 save_path=None, colors=None):
    """Plot detections on a PIL image with matplotlib, then save or show.

    Args:
        pil_img: the source image.
        prob: iterable of rows with p[0] = score, p[1] = class id.
        boxes: list of [xmin, ymin, xmax, ymax] boxes in image coordinates.
        classes: list of class names; defaults to ['Litter'].
        save_path: if given, save the figure there; otherwise show it.
        colors: optional list of RGB triples, one per detection.
    """
    # BUG FIX: mutable default argument (classes=['Litter']) replaced with
    # the None-sentinel idiom; behavior is unchanged.
    if classes is None:
        classes = ['Litter']
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    if colors is None:
        # colors for visualization
        colors = 100 * [
            [0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
            [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes, colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        # BUG FIX (consistency): get_output() maps the same probas format
        # with int(p[1] - 1); without the -1 a single-class model
        # (classes=['Litter']) would raise IndexError here.
        cl = int(p[1] - 1)
        text = f'{classes[cl]}: {p[0]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight',
                    transparent=True, pad_inches=0)
        plt.close()
        print(f'Image saved at {save_path}')
    else:
        plt.show()
def set_model(model_type, num_classes, checkpoint_path, device):
    """Create an effdet prediction model and move it to *device*.

    Args:
        model_type: effdet model name (e.g. 'tf_efficientdet_d2').
        num_classes: number of detection classes.
        checkpoint_path: path to the weights checkpoint to load.
        device: torch device string, e.g. 'cpu' or 'cuda'.

    Returns:
        The model, in 'predict' bench mode, on the requested device.
    """
    # create model
    model = create_model(
        model_type,
        bench_task='predict',
        num_classes=num_classes,
        pretrained=False,
        redundant_bias=True,
        checkpoint_path=checkpoint_path
    )
    # Generator expression instead of a throwaway list inside sum().
    param_count = sum(m.numel() for m in model.parameters())
    print('Model %s created, param count: %d' % (model_type, param_count))
    model = model.to(device)
    return model
def main(args):
    """Run single-image inference and plot/save the detections.

    Loads the model from args.checkpoint, fetches args.img (local path or
    URL), runs detection, and hands the results to plot_results().
    """
    # prepare model for evaluation
    torch.set_grad_enabled(False)
    num_classes = len(args.classes)
    # checkpoint path is expected to encode the model name -- TODO confirm
    model_name = args.checkpoint.split('-')[-1].split('/')[0]
    model = set_model(model_name, num_classes, args.checkpoint, args.device)
    model.eval()
    # get image
    # BUG FIX: startswith('https') missed plain 'http://' URLs; accept both.
    if args.img.startswith(('http://', 'https://')):
        im = Image.open(requests.get(args.img, stream=True).raw).convert('RGB')
    else:
        im = Image.open(args.img).convert('RGB')
    # mean-std normalize the input image (batch-size: 1)
    img = get_transforms(im)
    # propagate through the model
    outputs = model(img.to(args.device))
    # keep only predictions above set confidence
    bboxes_keep = outputs[0, outputs[0, :, 4] > args.prob_threshold]
    probas = bboxes_keep[:, 4:]
    # convert boxes to image scales
    bboxes_scaled = rescale_bboxes(bboxes_keep[:, :4], im.size,
                                   tuple(img.size()[2:]))
    # plot and save demo image
    plot_results(im, probas, bboxes_scaled.tolist(), args.classes, args.save)
if __name__ == '__main__':
    # Parse CLI arguments and dispatch: --video runs frame-by-frame
    # detection over a video file, otherwise a single image is processed.
    parser = get_args_parser()
    args = parser.parse_args()
    if args.video:
        save_frames(args)
    else:
        main(args)