""" Detection dataset

Hacked together by Ross Wightman
"""
import torch.utils.data as data
import numpy as np
import albumentations as A
import torch
from PIL import Image

from .parsers import create_parser
class DetectionDatset(data.Dataset):
    """`Object Detection Dataset. Use with parsers for COCO, VOC, and OpenImages.

    Args:
        data_dir (Path): base directory containing the images; joined with each
            image's ``file_name`` via the ``/`` operator, so a ``pathlib.Path``
            (or compatible) is expected.
        parser (string, Parser): a pre-built annotation parser instance, or a
            parser name passed to ``create_parser``.
        parser_kwargs (dict, optional): extra kwargs forwarded to ``create_parser``
            when ``parser`` is given as a string.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``.
            Called as ``transform(img, target)`` and must return both.
        transforms (callable, optional): albumentations-style callable taking
            ``image=``, ``bboxes=`` and ``bbox_classes=`` keyword arguments.
    """

    def __init__(self, data_dir, parser=None, parser_kwargs=None, transform=None, transforms=None):
        super(DetectionDatset, self).__init__()
        parser_kwargs = parser_kwargs or {}
        self.data_dir = data_dir
        if isinstance(parser, str):
            self._parser = create_parser(parser, **parser_kwargs)
        else:
            # a pre-built parser must be supplied and must reference at least one image
            assert parser is not None and len(parser.img_ids)
            self._parser = parser
        self._transform = transform
        self._transforms = transforms

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, annotations (target)).
        """
        img_info = self._parser.img_infos[index]
        target = dict(img_idx=index, img_size=(img_info['width'], img_info['height']))
        if self._parser.has_labels:
            ann = self._parser.get_ann_info(index)
            target.update(ann)

        img_path = self.data_dir / img_info['file_name']
        img = Image.open(img_path).convert('RGB')

        if self.transforms is not None:
            img = torch.as_tensor(np.array(img), dtype=torch.uint8)
            # target['bbox'] stores boxes as [ymin, xmin, ymax, xmax]; convert to
            # VOC xyxy order, clamped to the image interior, for the
            # albumentations-style transform (which expects pascal_voc boxes).
            voc_boxes = []
            for coord in target['bbox']:
                xmin = coord[1]
                ymin = coord[0]
                xmax = coord[3]
                ymax = coord[2]
                if xmin < 1:
                    xmin = 1
                if ymin < 1:
                    ymin = 1
                if xmax >= img.shape[1] - 1:
                    xmax = img.shape[1] - 1
                if ymax >= img.shape[0] - 1:
                    ymax = img.shape[0] - 1
                voc_boxes.append([xmin, ymin, xmax, ymax])
            transformed = self.transforms(image=np.array(img), bbox_classes=target['cls'], bboxes=voc_boxes)
            img = torch.as_tensor(transformed['image'], dtype=torch.uint8)
            # convert the transformed xyxy boxes back to the yxyx layout used in target
            target['bbox'] = []
            for coord in transformed['bboxes']:
                ymin = int(coord[1])
                xmin = int(coord[0])
                ymax = int(coord[3])
                xmax = int(coord[2])
                target['bbox'].append([ymin, xmin, ymax, xmax])
            target['bbox'] = np.array(target['bbox'], dtype=np.float32)
            target['cls'] = np.array(transformed['bbox_classes'])
            img = Image.fromarray(np.array(img).astype('uint8'), 'RGB')
            target['img_size'] = img.size  # PIL size is (width, height)

        if self.transform is not None:
            img, target = self.transform(img, target)

        return img, target

    def __len__(self):
        return len(self._parser.img_ids)

    # NOTE: the property/setter decorators below are required — __getitem__ reads
    # self.transform / self.transforms as values; without @property the duplicate
    # defs shadow each other and resolve to bound methods.
    @property
    def parser(self):
        return self._parser

    @property
    def transform(self):
        return self._transform

    @transform.setter
    def transform(self, t):
        self._transform = t

    @property
    def transforms(self):
        return self._transforms

    @transforms.setter
    def transforms(self, t):
        self._transforms = t
class SkipSubset(data.Dataset):
    r"""
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        n (int): skip rate (select every nth)
    """

    def __init__(self, dataset, n=2):
        self.dataset = dataset
        assert n >= 1
        # keep every n-th index of the wrapped dataset
        self.indices = np.arange(len(dataset))[::n]

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

    # Pass-through properties: without @property/@setter the duplicate transform /
    # transforms defs shadow each other and accessing s.parser would yield a bound
    # method instead of the wrapped dataset's value.
    @property
    def parser(self):
        return self.dataset.parser

    @property
    def transform(self):
        return self.dataset.transform

    @transform.setter
    def transform(self, t):
        self.dataset.transform = t

    @property
    def transforms(self):
        return self.dataset.transforms

    @transforms.setter
    def transforms(self, t):
        self.dataset.transforms = t