| | import os |
| | import sys |
| | import random |
| | import numpy as np |
| | from tqdm import tqdm, trange |
| | from PIL import Image, ImageOps, ImageFilter |
| |
|
| | import torch |
| | import torch.utils.data as data |
| | import torchvision.transforms as transform |
| |
|
| | from datasets.base import BaseDataset |
| |
|
class CitySegmentation(BaseDataset):
    """Cityscapes semantic-segmentation dataset with 19 training classes.

    Image/mask path pairs are collected by ``get_city_pairs`` from
    ``<root>/leftImg8bit/<split>`` and ``<root>/gtFine/<split>``; raw
    Cityscapes label ids in the masks are converted to train ids
    (0..18, -1 = ignore) by ``_class_to_index``.
    """

    NUM_CLASS = 19

    def __init__(self, root, split='val', mode='testval', transform=None,
                 target_transform=None, **kwargs):
        """Collect the image/mask pairs for ``split`` and build label tables.

        Args:
            root: dataset root directory, forwarded to ``BaseDataset``
                (which presumably stores it as ``self.root`` — it is read
                back below).
            split: split sub-directory name, e.g. ``'train'`` or ``'val'``.
            mode: selects the joint transform in ``__getitem__``
                (``'testval'``, ``'val'`` or ``'train'``).
            transform: optional callable applied to the image last.
            target_transform: optional callable applied to the mask last.
            **kwargs: forwarded unchanged to ``BaseDataset``.

        Raises:
            RuntimeError: if no image/mask pairs were found.
        """
        super(CitySegmentation, self).__init__(
            root, split, mode, transform, target_transform, **kwargs)
        self.images, self.mask_paths = get_city_pairs(self.root, self.split)
        assert len(self.images) == len(self.mask_paths)
        if not self.images:
            raise RuntimeError(
                "Found 0 images in subfolders of: {}\n".format(self.root))
        # Paired tables: _classes[i] is a raw Cityscapes label id and
        # _indices[i] the train id it corresponds to (raw id 0 <-> -1).
        # Not read inside this class; presumably used elsewhere — verify
        # before removing.
        self._indices = np.array(range(-1, 19))
        self._classes = np.array([0, 7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
                                  23, 24, 25, 26, 27, 28, 31, 32, 33])
        # _key[v + 1] is the train id for raw label id v in [-1, 33];
        # -1 marks labels that are ignored during training.
        self._key = np.array([-1, -1, -1, -1, -1, -1,
                              -1, -1, 0, 1, -1, -1,
                              2, 3, 4, -1, -1, -1,
                              5, -1, 6, 7, 8, 9,
                              10, 11, 12, 13, 14, 15,
                              -1, -1, 16, 17, 18])
        # Sorted bin edges [-1, 0, ..., 33]: every raw label id accepted
        # by _class_to_index.
        self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')

    def _class_to_index(self, mask):
        """Convert raw Cityscapes label ids in ``mask`` to train ids.

        Args:
            mask: integer numpy array of raw label ids; every value must be
                contained in ``self._mapping`` (i.e. lie in [-1, 33]).

        Returns:
            Array with the same shape as ``mask`` holding train ids in
            [0, 18], or -1 for ignored labels.

        Raises:
            ValueError: if ``mask`` contains a label id not in
                ``self._mapping``. (An explicit check rather than
                ``assert`` so it is not stripped under ``python -O``.)
        """
        values = np.unique(mask)
        unknown = values[~np.isin(values, self._mapping)]
        if unknown.size:
            raise ValueError(
                "mask contains unknown label ids: {}".format(unknown.tolist()))
        # With right=True and _mapping == [-1, 0, ..., 33], digitize maps a
        # label id v to bin index v + 1, which is exactly its slot in _key.
        index = np.digitize(mask.ravel(), self._mapping, right=True)
        return self._key[index].reshape(mask.shape)

    def __getitem__(self, index):
        """Load and transform the ``index``-th image/mask pair.

        The image is opened as RGB. The mode-specific joint transform
        (``_testval_transform`` / ``_val_transform`` / ``_train_transform``,
        presumably provided by ``BaseDataset``) is applied first; for any
        other ``mode`` both stay as loaded PIL images. ``self.transform``
        and ``self.target_transform`` are then applied when not ``None``.

        Returns:
            ``(img, mask)`` tuple.
        """
        img = Image.open(self.images[index]).convert('RGB')
        mask = Image.open(self.mask_paths[index])
        if self.mode == 'testval':
            img, mask = self._testval_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_transform(img, mask)
        elif self.mode == 'train':
            img, mask = self._train_transform(img, mask)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return img, mask

    def _mask_transform(self, mask):
        """Convert a raw label-id mask to a ``torch.LongTensor`` of train ids.

        ``mask`` is anything ``np.array`` accepts (e.g. a PIL image); it is
        cast to int32 before the label-id -> train-id lookup.
        """
        target = self._class_to_index(np.array(mask).astype('int32'))
        return torch.from_numpy(target).long()

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.images)
| |
|
| |
|
def get_city_pairs(folder, split='val'):
    """Collect matching Cityscapes image / label-id mask paths for a split.

    Recursively walks ``<folder>/leftImg8bit/<split>`` for ``.png`` files.
    For an image ``<dir>/<name>``, the mask is expected at
    ``<folder>/gtFine/<split>/<basename(dir)>/<name'>``, where ``name'`` is
    ``name`` with ``'leftImg8bit'`` replaced by ``'gtFine_labelIds'``.
    Pairs where either file is not a regular file are reported on stdout
    and skipped.

    Args:
        folder: dataset root directory.
        split: split sub-directory name, e.g. ``'train'`` or ``'val'``.

    Returns:
        ``(img_paths, mask_paths)``: equally long lists where
        ``mask_paths[i]`` is the mask for ``img_paths[i]``. Traversal is
        sorted, so the order is deterministic across file systems.
    """
    def get_path_pairs(img_folder, mask_folder):
        # Walk img_folder and pair every .png image with its mask path.
        img_paths = []
        mask_paths = []
        for root, directories, files in os.walk(img_folder):
            # os.walk yields entries in arbitrary (file-system) order; sorting
            # the sub-directories in place (top-down walk) and the file names
            # makes the resulting dataset order reproducible.
            directories.sort()
            for filename in sorted(files):
                if not filename.endswith(".png"):
                    continue
                imgpath = os.path.join(root, filename)
                foldername = os.path.basename(os.path.dirname(imgpath))
                maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
                maskpath = os.path.join(mask_folder, foldername, maskname)
                if os.path.isfile(imgpath) and os.path.isfile(maskpath):
                    img_paths.append(imgpath)
                    mask_paths.append(maskpath)
                else:
                    print('cannot find the mask or image:', imgpath, maskpath)
        print('Found {} images in the folder {}'.format(len(img_paths), img_folder))
        return img_paths, mask_paths

    img_folder = os.path.join(folder, 'leftImg8bit', split)
    mask_folder = os.path.join(folder, 'gtFine', split)
    return get_path_pairs(img_folder, mask_folder)
| |
|