Spaces:
Running
Running
| from __future__ import print_function, division | |
| import os | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from .mypath_pascal import Path | |
class VOCSegmentation(Dataset):
    """
    Pascal dataset.

    Pairs every image ``<base_dir>/JPEGImages/<id>.jpg`` with the annotation
    ``<base_dir>/SegmentationPart/<id>.png``. The ids are read from
    ``<base_dir>/list/<split>_id.txt`` (one id per line) for each requested
    split.
    """

    def __init__(self,
                 base_dir=Path.db_root_dir('pascal'),
                 split='train',
                 transform=None
                 ):
        """
        :param base_dir: path to PASCAL dataset directory
        :param split: a split name (e.g. train/val) or an iterable of split names
        :param transform: callable applied to each ``{'image': ..., 'label': ...}``
            sample dict, or None to return the sample untouched
        :raises AssertionError: if an image or annotation file listed in a
            split file does not exist on disk
        """
        # Two-argument form: ``super(VOCSegmentation)`` alone yields an unbound
        # super object, so the base-class initializer would never be called.
        super(VOCSegmentation, self).__init__()
        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
        self._cat_dir = os.path.join(self._base_dir, 'SegmentationPart')

        if isinstance(split, str):
            self.split = [split]
        else:
            # Sort a copy so the caller's sequence is not reordered in place.
            self.split = sorted(split)

        self.transform = transform

        _splits_dir = os.path.join(self._base_dir, 'list')

        self.im_ids = []
        self.images = []
        self.categories = []

        for splt in self.split:
            ids_file = os.path.join(_splits_dir, splt + '_id.txt')
            with open(ids_file, "r") as f:
                lines = f.read().splitlines()

            for line in lines:
                # Tolerate blank lines and stray whitespace around an id.
                line = line.strip()
                if not line:
                    continue

                _image = os.path.join(self._image_dir, line + '.jpg')
                _cat = os.path.join(self._cat_dir, line + '.png')
                assert os.path.isfile(_image), \
                    'Image file not found: {}'.format(_image)
                assert os.path.isfile(_cat), \
                    'Annotation file not found: {}'.format(_cat)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)

        assert (len(self.images) == len(self.categories))

        # Display stats. A single split name is shown as given; a collection
        # of names is shown in its sorted order.
        shown_split = split if isinstance(split, str) else self.split
        print('Number of images in {}: {:d}'.format(shown_split, len(self.images)))

    def __len__(self):
        """Return the number of image/annotation pairs collected."""
        return len(self.images)

    def __getitem__(self, index):
        """
        Return the sample at ``index`` as ``{'image': ..., 'label': ...}``,
        passed through ``self.transform`` when one is set.
        """
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def _make_img_gt_point_pair(self, index):
        """
        Open the image and annotation files for ``index``.

        :return: ``(image, target)`` PIL images; the image is converted to
            RGB, the annotation is returned as opened.
        """
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.categories[index])

        return _img, _target

    def __str__(self):
        """Return a short description naming the loaded split(s)."""
        return 'PASCAL(split=' + str(self.split) + ')'
class test_segmentation(VOCSegmentation):
    """
    VOCSegmentation variant that additionally records a ``flip`` flag.

    The flag is stored on the instance as ``_flip_flag``; nothing in this
    class reads it.
    """

    def __init__(self, base_dir=Path.db_root_dir('pascal'),
                 split='train',
                 transform=None,
                 flip=True):
        """
        :param base_dir: path to PASCAL dataset directory
        :param split: forwarded unchanged to VOCSegmentation
        :param transform: forwarded unchanged to VOCSegmentation
        :param flip: value stored as ``self._flip_flag``
        """
        super(test_segmentation, self).__init__(
            base_dir=base_dir,
            split=split,
            transform=transform,
        )
        self._flip_flag = flip

    def __getitem__(self, index):
        """
        Return ``{'image': ..., 'label': ...}`` for ``index``, run through
        ``self.transform`` when one is set.
        """
        image, label = self._make_img_gt_point_pair(index)
        sample = {'image': image, 'label': label}

        # Without a transform the raw pair is returned as-is.
        if self.transform is None:
            return sample
        return self.transform(sample)