| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
|
|
| import itertools |
| import os |
| import random |
|
|
| import numpy as np |
| import cv2 |
| import torch |
| import torch.nn as nn |
| import torch.utils.data.distributed |
| from zoedepth.utils.easydict import EasyDict as edict |
| from PIL import Image, ImageOps |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision import transforms |
|
|
| from zoedepth.utils.config import change_dataset |
|
|
| from .ddad import get_ddad_loader |
| from .diml_indoor_test import get_diml_indoor_loader |
| from .diml_outdoor_test import get_diml_outdoor_loader |
| from .diode import get_diode_loader |
| from .hypersim import get_hypersim_loader |
| from .ibims import get_ibims_loader |
| from .sun_rgbd_loader import get_sunrgbd_loader |
| from .vkitti import get_vkitti_loader |
| from .vkitti2 import get_vkitti2_loader |
|
|
| from .preprocess import CropParams, get_white_border, get_black_border |
|
|
|
|
def _is_pil_image(img):
    """Return True when ``img`` is a PIL ``Image`` instance."""
    return isinstance(img, Image.Image)
|
|
|
|
| def _is_numpy_image(img): |
| return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) |
|
|
|
|
def preprocessing_transforms(mode, **kwargs):
    """Build the default preprocessing pipeline: a single ToTensor step.

    ``kwargs`` (e.g. ``size=``) are forwarded to :class:`ToTensor`.
    """
    steps = [ToTensor(mode=mode, **kwargs)]
    return transforms.Compose(steps)
|
|
|
|
class DepthDataLoader(object):
    """Factory that builds a torch ``DataLoader`` for the configured dataset.

    Evaluation-only datasets (ibims, sunrgbd, diml_*, diode*, hypersim_test,
    vkitti, vkitti2, ddad) are delegated to dedicated loader functions with
    batch_size=1 / num_workers=1 and short-circuit the constructor. Every
    other dataset is served through :class:`DataLoadPreprocess`.
    The final loader is exposed as ``self.data``.
    """

    def __init__(self, config, mode, device='cpu', transform=None, **kwargs):
        """
        Data loader for depth datasets

        Args:
            config (dict): Config dictionary. Refer to utils/config.py
            mode (str): "train" or "online_eval"
            device (str, optional): Device to load the data on. Defaults to 'cpu'.
            transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None.
        """

        self.config = config

        # --- evaluation-only datasets: each branch returns early with a ready loader ---
        if config.dataset == 'ibims':
            self.data = get_ibims_loader(config, batch_size=1, num_workers=1)
            return

        if config.dataset == 'sunrgbd':
            self.data = get_sunrgbd_loader(
                data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'diml_indoor':
            self.data = get_diml_indoor_loader(
                data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'diml_outdoor':
            self.data = get_diml_outdoor_loader(
                data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1)
            return

        # Covers all diode variants; the root path key is derived from the
        # dataset name (e.g. "diode_indoor" -> config.diode_indoor_root).
        if "diode" in config.dataset:
            self.data = get_diode_loader(
                config[config.dataset+"_root"], batch_size=1, num_workers=1)
            return

        if config.dataset == 'hypersim_test':
            self.data = get_hypersim_loader(
                config.hypersim_test_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'vkitti':
            self.data = get_vkitti_loader(
                config.vkitti_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'vkitti2':
            self.data = get_vkitti2_loader(
                config.vkitti2_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'ddad':
            self.data = get_ddad_loader(config.ddad_root, resize_shape=(
                352, 1216), batch_size=1, num_workers=1)
            return

        # Resize inputs only when explicitly enabled; otherwise keep native size.
        img_size = self.config.get("img_size", None)
        img_size = img_size if self.config.get(
            "do_input_resize", False) else None

        if transform is None:
            transform = preprocessing_transforms(mode, size=img_size)

        if mode == 'train':

            Dataset = DataLoadPreprocess  # local alias for the dataset class
            self.training_samples = Dataset(
                config, mode, transform=transform, device=device)

            # Distributed training shards samples across ranks via a sampler.
            if config.distributed:
                self.train_sampler = torch.utils.data.distributed.DistributedSampler(
                    self.training_samples)
            else:
                self.train_sampler = None

            self.data = DataLoader(self.training_samples,
                                   batch_size=config.batch_size,
                                   # shuffle only when no sampler: a sampler
                                   # controls ordering itself
                                   shuffle=(self.train_sampler is None),
                                   num_workers=config.workers,
                                   pin_memory=True,
                                   persistent_workers=True,
                                   sampler=self.train_sampler)

        elif mode == 'online_eval':
            self.testing_samples = DataLoadPreprocess(
                config, mode, transform=transform)
            if config.distributed:
                # NOTE(review): eval is deliberately not sharded — the sampler
                # is None in both branches, so every rank evaluates the full
                # set. Confirm this is intended before enabling a
                # DistributedSampler here.
                self.eval_sampler = None
            else:
                self.eval_sampler = None
            self.data = DataLoader(self.testing_samples, 1,
                                   shuffle=kwargs.get("shuffle_test", False),
                                   num_workers=1,
                                   pin_memory=False,
                                   sampler=self.eval_sampler)

        elif mode == 'test':
            self.testing_samples = DataLoadPreprocess(
                config, mode, transform=transform)
            self.data = DataLoader(self.testing_samples,
                                   1, shuffle=False, num_workers=1)

        else:
            # Unknown mode: warn but leave the instance without .data set.
            print(
                'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
|
|
|
|
def repetitive_roundrobin(*iterables):
    """Interleave iterables sample-wise, repeating shorter ones.

    Yields the first item of each iterable in turn, then the second item of
    each, and so on. When an iterable runs out it is restarted (cycled) and
    keeps contributing until every iterable has been exhausted at least once;
    the round in which the *last* iterable exhausts is still completed, so a
    few extra recycled items are emitted at the end.

    repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E A D F
    """
    active = [iter(src) for src in iterables]
    done = [False for _ in iterables]
    while not all(done):
        for idx in range(len(active)):
            try:
                yield next(active[idx])
            except StopIteration:
                # Exhausted: mark it, restart from the beginning, and emit
                # the first recycled item immediately to keep the cadence.
                done[idx] = True
                active[idx] = itertools.cycle(iterables[idx])
                yield next(active[idx])
|
|
|
|
class RepetitiveRoundRobinDataLoader(object):
    """Wrap several dataloaders and iterate them in repetitive round-robin.

    Shorter loaders are recycled so every loader contributes one sample per
    round until all loaders have been exhausted at least once.
    """

    def __init__(self, *dataloaders):
        self.dataloaders = dataloaders

    def __iter__(self):
        return repetitive_roundrobin(*self.dataloaders)

    def __len__(self):
        # One extra round is emitted after the longest loader exhausts,
        # hence the +1 (matches repetitive_roundrobin's behavior).
        longest = max(len(dl) for dl in self.dataloaders)
        return (longest + 1) * len(self.dataloaders)
|
|
|
|
class MixedNYUKITTI(object):
    """Mixed NYU-Depth-v2 + KITTI data source.

    In train mode, builds one DepthDataLoader per dataset and interleaves
    them sample-wise via RepetitiveRoundRobinDataLoader (the shorter loader
    is recycled). In any other mode, only the NYU loader is used.
    The final iterable is exposed as ``self.data``.
    """

    def __init__(self, config, mode, device='cpu', **kwargs):
        config = edict(config)
        # Workers are halved because two underlying loaders run side by side.
        config.workers = config.workers // 2
        self.config = config
        nyu_conf = change_dataset(edict(config), 'nyu')
        kitti_conf = change_dataset(edict(config), 'kitti')

        # self.config is overwritten with the NYU-specialized config from
        # here on; img_size/do_input_resize are read from it.
        self.config = config = nyu_conf
        img_size = self.config.get("img_size", None)
        img_size = img_size if self.config.get(
            "do_input_resize", False) else None
        if mode == 'train':
            nyu_loader = DepthDataLoader(
                nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
            kitti_loader = DepthDataLoader(
                kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
            # Interleave: nyu batch, kitti batch, nyu batch, ...
            self.data = RepetitiveRoundRobinDataLoader(
                nyu_loader, kitti_loader)
        else:
            self.data = DepthDataLoader(nyu_conf, mode, device=device).data
|
|
|
|
def remove_leading_slash(s):
    """Strip a single leading '/' or '\\' from ``s``, if present.

    Used to make the paths listed in filenames files joinable with a data
    root via ``os.path.join`` (a leading slash would otherwise make the
    join discard the root).

    Fix: ``startswith`` with a tuple handles the empty string safely,
    where the original ``s[0]`` indexing raised IndexError.
    """
    if s.startswith(('/', '\\')):
        return s[1:]
    return s
|
|
|
|
class CachedReader:
    """Image reader that memoizes opened PIL images in a dict.

    The backing dict may be supplied externally (e.g. a multiprocessing
    shared dict) so the cache can be shared across workers.
    """

    def __init__(self, shared_dict=None):
        # A falsy shared_dict (None or empty) means: use a fresh private dict.
        self._cache = shared_dict if shared_dict else {}

    def open(self, fpath):
        """Return the cached image for ``fpath``, opening and caching it on miss."""
        cached = self._cache.get(fpath)
        if cached is not None:
            return cached
        opened = Image.open(fpath)
        self._cache[fpath] = opened
        return opened
|
|
|
|
class ImReader:
    """Stateless image reader: every call opens the file anew via PIL."""

    def __init__(self):
        # No state to initialize; exists for interface parity with CachedReader.
        pass

    def open(self, fpath):
        """Open and return the image at ``fpath`` (PIL lazy-loads the pixels)."""
        return Image.open(fpath)
|
|
|
|
class DataLoadPreprocess(Dataset):
    """Dataset reading (image, depth) pairs listed in a filenames file, with
    train-time augmentation and eval-time preprocessing.

    Each line of the filenames file is whitespace-separated; the indexing
    below reads field 0 (image path), field 1 (depth path) and field 2
    (focal length), plus fields 3 and 4 for KITTI with ``use_right``
    (presumably the right-camera pair — confirm against the filenames files).
    """

    def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs):
        self.config = config
        # Pick the eval or train split file depending on mode.
        if mode == 'online_eval':
            with open(config.filenames_file_eval, 'r') as f:
                self.filenames = f.readlines()
        else:
            with open(config.filenames_file, 'r') as f:
                self.filenames = f.readlines()

        self.mode = mode
        self.transform = transform
        self.to_tensor = ToTensor(mode)
        self.is_for_online_eval = is_for_online_eval
        # Optionally cache decoded images in a (possibly shared) dict.
        if config.use_shared_dict:
            self.reader = CachedReader(config.shared_dict)
        else:
            self.reader = ImReader()

    def postprocess(self, sample):
        # Hook for subclasses; identity by default.
        return sample

    def __getitem__(self, idx):
        sample_path = self.filenames[idx]
        # Third whitespace-separated field is the focal length.
        focal = float(sample_path.split()[2])
        sample = {}

        if self.mode == 'train':
            # KITTI stereo: with use_right, randomly (p=0.5) load the
            # alternate pair from fields 3/4 instead of 0/1.
            if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5:
                image_path = os.path.join(
                    self.config.data_path, remove_leading_slash(sample_path.split()[3]))
                depth_path = os.path.join(
                    self.config.gt_path, remove_leading_slash(sample_path.split()[4]))
            else:
                image_path = os.path.join(
                    self.config.data_path, remove_leading_slash(sample_path.split()[0]))
                depth_path = os.path.join(
                    self.config.gt_path, remove_leading_slash(sample_path.split()[1]))

            image = self.reader.open(image_path)
            depth_gt = self.reader.open(depth_path)
            w, h = image.size

            # KITTI benchmark ("kb") crop: bottom-centered 352x1216 window,
            # applied identically to image and depth.
            if self.config.do_kb_crop:
                height = image.height
                width = image.width
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                depth_gt = depth_gt.crop(
                    (left_margin, top_margin, left_margin + 1216, top_margin + 352))
                image = image.crop(
                    (left_margin, top_margin, left_margin + 1216, top_margin + 352))

            # NYU boundary avoidance: crop away the white border, then pad
            # back to the original (w, h) — reflect-pad the image, zero-pad
            # the depth — so spatial dimensions are unchanged.
            if self.config.dataset == 'nyu' and self.config.avoid_boundary:
                crop_params = get_white_border(np.array(image, dtype=np.uint8))
                image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))
                depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))

                image = np.array(image)
                image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect')
                image = Image.fromarray(image)

                depth_gt = np.array(depth_gt)
                depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0)
                depth_gt = Image.fromarray(depth_gt)

            # Random rotation in [-degree, +degree]; depth uses nearest-
            # neighbour resampling so no interpolated depth values appear.
            if self.config.do_random_rotate and (self.config.aug):
                random_angle = (random.random() - 0.5) * 2 * self.config.degree
                image = self.rotate_image(image, random_angle)
                depth_gt = self.rotate_image(
                    depth_gt, random_angle, flag=Image.NEAREST)

            image = np.asarray(image, dtype=np.float32) / 255.0
            depth_gt = np.asarray(depth_gt, dtype=np.float32)
            depth_gt = np.expand_dims(depth_gt, axis=2)

            # Convert stored integer depth to metres: NYU uses millimetres,
            # the other datasets 1/256 m per unit.
            if self.config.dataset == 'nyu':
                depth_gt = depth_gt / 1000.0
            else:
                depth_gt = depth_gt / 256.0

            if self.config.aug and (self.config.random_crop):
                image, depth_gt = self.random_crop(
                    image, depth_gt, self.config.input_height, self.config.input_width)

            if self.config.aug and self.config.random_translate:
                image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation)

            image, depth_gt = self.train_preprocess(image, depth_gt)
            # Validity mask (strict bounds) with a leading channel axis.
            mask = np.logical_and(depth_gt > self.config.min_depth,
                                  depth_gt < self.config.max_depth).squeeze()[None, ...]
            sample = {'image': image, 'depth': depth_gt, 'focal': focal,
                      'mask': mask, **sample}

        else:
            if self.mode == 'online_eval':
                data_path = self.config.data_path_eval
            else:
                data_path = self.config.data_path

            image_path = os.path.join(
                data_path, remove_leading_slash(sample_path.split()[0]))
            image = np.asarray(self.reader.open(image_path),
                               dtype=np.float32) / 255.0

            if self.mode == 'online_eval':
                gt_path = self.config.gt_path_eval
                depth_path = os.path.join(
                    gt_path, remove_leading_slash(sample_path.split()[1]))
                has_valid_depth = False
                try:
                    depth_gt = self.reader.open(depth_path)
                    has_valid_depth = True
                except IOError:
                    # Missing ground truth is tolerated during eval; it is
                    # flagged via has_valid_depth instead of raising.
                    depth_gt = False

                if has_valid_depth:
                    depth_gt = np.asarray(depth_gt, dtype=np.float32)
                    depth_gt = np.expand_dims(depth_gt, axis=2)
                    if self.config.dataset == 'nyu':
                        depth_gt = depth_gt / 1000.0
                    else:
                        depth_gt = depth_gt / 256.0

                    # Eval mask uses inclusive bounds (>=, <=), unlike the
                    # strict train mask above.
                    mask = np.logical_and(
                        depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...]
                else:
                    mask = False

            # Same KITTI benchmark crop as in train mode, on numpy arrays.
            if self.config.do_kb_crop:
                height = image.shape[0]
                width = image.shape[1]
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                image = image[top_margin:top_margin + 352,
                              left_margin:left_margin + 1216, :]
                if self.mode == 'online_eval' and has_valid_depth:
                    depth_gt = depth_gt[top_margin:top_margin +
                                        352, left_margin:left_margin + 1216, :]

            if self.mode == 'online_eval':
                sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth,
                          'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1],
                          'mask': mask}
            else:
                sample = {'image': image, 'focal': focal}

        # Recompute a strict-bounds mask for train samples and for eval
        # samples that have valid depth (overwrites the inclusive eval mask).
        if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']):
            mask = np.logical_and(depth_gt > self.config.min_depth,
                                  depth_gt < self.config.max_depth).squeeze()[None, ...]
            sample['mask'] = mask

        if self.transform:
            sample = self.transform(sample)

        sample = self.postprocess(sample)
        sample['dataset'] = self.config.dataset
        # NOTE(review): split()[1] assumes every filenames line has a depth
        # path, including pure test splits — confirm the split files always
        # provide one, or this raises IndexError.
        sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]}

        return sample

    def rotate_image(self, image, angle, flag=Image.BILINEAR):
        # PIL rotation; `flag` selects the resampling filter.
        result = image.rotate(angle, resample=flag)
        return result

    def random_crop(self, img, depth, height, width):
        """Crop the same random height x width window from img and depth."""
        assert img.shape[0] >= height
        assert img.shape[1] >= width
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        x = random.randint(0, img.shape[1] - width)
        y = random.randint(0, img.shape[0] - height)
        img = img[y:y + height, x:x + width, :]
        depth = depth[y:y + height, x:x + width, :]

        return img, depth

    def random_translate(self, img, depth, max_t=20):
        """With probability ``translate_prob``, shift img and depth by the same
        random (x, y) offset in [-max_t, max_t] pixels; borders fill with 0."""
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        p = self.config.translate_prob
        do_translate = random.random()
        if do_translate > p:
            return img, depth
        x = random.randint(-max_t, max_t)
        y = random.randint(-max_t, max_t)
        M = np.float32([[1, 0, x], [0, 1, y]])
        img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
        depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0]))
        # warpAffine drops the trailing channel axis of (H, W, 1) depth;
        # restore it so downstream indexing keeps working.
        depth = depth.squeeze()[..., None]

        return img, depth

    def train_preprocess(self, image, depth_gt):
        if self.config.aug:
            # Random horizontal flip (p=0.5), applied jointly to image+depth.
            do_flip = random.random()
            if do_flip > 0.5:
                image = (image[:, ::-1, :]).copy()
                depth_gt = (depth_gt[:, ::-1, :]).copy()

            # Random photometric augmentation of the image only (p=0.5).
            do_augment = random.random()
            if do_augment > 0.5:
                image = self.augment_image(image)

        return image, depth_gt

    def augment_image(self, image):
        # Random gamma.
        gamma = random.uniform(0.9, 1.1)
        image_aug = image ** gamma

        # Random brightness (wider range for NYU).
        if self.config.dataset == 'nyu':
            brightness = random.uniform(0.75, 1.25)
        else:
            brightness = random.uniform(0.9, 1.1)
        image_aug = image_aug * brightness

        # Random per-channel color jitter, then clamp back to [0, 1].
        colors = np.random.uniform(0.9, 1.1, size=3)
        white = np.ones((image.shape[0], image.shape[1]))
        color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
        image_aug *= color_image
        image_aug = np.clip(image_aug, 0, 1)

        return image_aug

    def __len__(self):
        return len(self.filenames)
|
|
|
|
class ToTensor(object):
    """Convert a sample dict's image (and, in train mode, depth) to tensors,
    with optional ImageNet normalization and resizing.

    Args:
        mode (str): 'train', 'test' or eval; controls which keys are returned.
        do_normalize (bool): apply ImageNet mean/std normalization when True.
        size: target size for transforms.Resize, or None for no resizing.
    """

    def __init__(self, mode, do_normalize=False, size=None):
        self.mode = mode
        # ImageNet normalization when requested; otherwise a no-op module.
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity()
        self.size = size
        if size is not None:
            self.resize = transforms.Resize(size=size)
        else:
            self.resize = nn.Identity()

    def __call__(self, sample):
        image, focal = sample['image'], sample['focal']
        image = self.to_tensor(image)
        image = self.normalize(image)
        image = self.resize(image)

        if self.mode == 'test':
            return {'image': image, 'focal': focal}

        depth = sample['depth']
        if self.mode == 'train':
            depth = self.to_tensor(depth)
            return {**sample, 'image': image, 'depth': depth, 'focal': focal}
        else:
            # Eval path: depth may be False when no valid ground truth exists,
            # so it is passed through untouched.
            # (fix) removed a redundant second self.resize(image) call that
            # re-resized the already-resized image.
            has_valid_depth = sample['has_valid_depth']
            return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth,
                    'image_path': sample['image_path'], 'depth_path': sample['depth_path']}

    def to_tensor(self, pic):
        """Convert a PIL Image or HWC ndarray to a CHW torch tensor.

        Byte images are returned as float tensors (note: values are NOT
        scaled to [0, 1] here); other dtypes keep their values as-is.
        """
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            # HWC -> CHW without copying the data.
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # PIL path: decode raw bytes according to the image mode.
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))

        # Channel count from the PIL mode string.
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        # HWC -> CHW.
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img
|
|