Spaces:
Runtime error
Runtime error
| # ------------------------------------------------------------------------ | |
| # Deformable DETR | |
| # Copyright (c) 2020 SenseTime. All Rights Reserved. | |
| # Licensed under the Apache License, Version 2.0 [see LICENSE for details] | |
| # ------------------------------------------------------------------------ | |
| # Modified from torchvision | |
| # ------------------------------------------------------------------------ | |
| """ | |
| Copy-paste from torchvision, with added support for caching images in memory. | |
| """ | |
| from torchvision.datasets.vision import VisionDataset | |
| from PIL import Image | |
| import os | |
| import os.path | |
| import tqdm | |
| from io import BytesIO | |
class CocoDetection(VisionDataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample
            and its target as entry and returns a transformed version.
        cache_mode (bool, optional): If True, keep the raw (still-encoded) bytes of
            image files in an in-memory dict keyed by the COCO ``file_name``. Files in
            this instance's shard are read eagerly at construction time; any other
            file is read and cached the first time it is requested.
        local_rank (int, optional): Shard index used for eager caching: the image at
            sorted position ``i`` is pre-cached only if ``i % local_size == local_rank``.
            Presumably the process's rank within its node -- confirm with the caller.
        local_size (int, optional): Number of shards eager caching is split across.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None,
                 cache_mode=False, local_rank=0, local_size=1):
        super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
        # Imported here rather than at module level, so importing this module does
        # not itself require pycocotools.
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        # Image ids in ascending order; dataset index i maps to self.ids[i].
        self.ids = sorted(self.coco.imgs.keys())
        self.cache_mode = cache_mode
        self.local_rank = local_rank
        self.local_size = local_size
        # Always defined (empty unless cache_mode), so the attribute exists regardless
        # of how the dataset was configured.
        self.cache = {}
        if cache_mode:
            self.cache_images()

    def _read_file(self, path):
        """Return the raw bytes of the file at ``os.path.join(self.root, path)``."""
        with open(os.path.join(self.root, path), 'rb') as f:
            return f.read()

    def cache_images(self):
        """Eagerly read this shard's image files into ``self.cache``.

        Any previously cached entries are discarded. The image at sorted position
        ``i`` is read only when ``i % local_size == local_rank``. A tqdm progress bar
        runs over all ids, including the ones that are skipped.
        """
        self.cache = {}
        for index, img_id in enumerate(tqdm.tqdm(self.ids)):
            if index % self.local_size != self.local_rank:
                continue
            path = self.coco.loadImgs(img_id)[0]['file_name']
            self.cache[path] = self._read_file(path)

    def get_image(self, path):
        """Load the image at ``path`` (relative to ``root``) and return it as RGB.

        In cache mode the encoded bytes are served from ``self.cache``, and are read
        from disk and added to it on a miss. Otherwise the file is opened from disk
        on every call.
        """
        if self.cache_mode:
            if path not in self.cache:
                self.cache[path] = self._read_file(path)
            source = BytesIO(self.cache[path])
        else:
            source = os.path.join(self.root, path)
        # convert() returns a new, fully loaded image, so the source can be closed
        # deterministically here instead of waiting for garbage collection.
        with Image.open(source) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by
            ``coco.loadAnns`` (passed through ``self.transforms`` when one is set).
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)
        path = coco.loadImgs(img_id)[0]['file_name']
        img = self.get_image(path)
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        """Return the number of samples (one per COCO image id)."""
        return len(self.ids)