| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import io |
| import numpy as np |
| import os |
| import random |
| import tarfile |
| import torch |
| from PIL import Image |
| import torchvision.transforms.functional as TF |
| from torch.utils.data import Dataset, get_worker_info |
| from torchvision.transforms import InterpolationMode, Resize, ColorJitter |
| from .base_depth_dataset import DatasetMode |
|
|
|
|
class BaseNormalsDataset(Dataset):
    """Dataset of (RGB image, surface-normal map) pairs listed in a text file.

    Each non-blank line of ``filename_ls_path`` holds whitespace-separated
    relative paths: the RGB image first, then the normals file (loaded with
    ``np.load``). ``dataset_dir`` is either a directory or a tar archive that
    contains those paths (tar members are looked up as ``"./" + rel_path``).
    """

    # Gaussian and motion blur are only applied to images of exactly this
    # height. NOTE(review): the reason for this restriction is not visible
    # here -- confirm before changing.
    _BLUR_IMAGE_HEIGHT = 768

    def __init__(
        self,
        mode: DatasetMode,
        filename_ls_path: str,
        dataset_dir: str,
        disp_name: str,
        augmentation_args: dict = None,
        resize_to_hw=None,
        **kwargs,
    ) -> None:
        """
        Args:
            mode: Dataset mode. ``DatasetMode.TRAIN`` enables augmentation and
                resizing; ``DatasetMode.RGB_ONLY`` skips loading normals.
            filename_ls_path: Text file listing the relative sample paths.
            dataset_dir: Directory or tar archive holding the samples.
            disp_name: Display name, stored as ``self.disp_name``.
            augmentation_args: Augmentation settings read via attribute access
                (e.g. ``augm_args.lr_flip_p``), so presumably an OmegaConf-style
                config rather than a plain ``dict``; ``None`` disables
                augmentation.
            resize_to_hw: ``size`` for torchvision ``Resize`` applied in
                training mode; ``None`` disables resizing.
            **kwargs: Accepted and ignored here (presumably so a shared config
                can pass extra keys).

        Raises:
            AssertionError: If ``dataset_dir`` does not exist.
        """
        super().__init__()
        self.mode = mode

        # Dataset location and metadata.
        self.filename_ls_path = filename_ls_path
        self.dataset_dir = dataset_dir
        assert os.path.exists(
            self.dataset_dir
        ), f"Dataset does not exist at: {self.dataset_dir}"
        self.disp_name = disp_name

        # Training-time options.
        self.augm_args = augmentation_args
        self.resize_to_hw = resize_to_hw

        # One list of path tokens per sample. Blank lines (e.g. an extra empty
        # line at the end of the file) are skipped; they would otherwise become
        # empty samples that fail with IndexError in _get_data_path.
        with open(self.filename_ls_path, "r") as f:
            self.filenames = [
                parts for parts in (line.split() for line in f) if parts
            ]

        # Tar archive support. The handle is tied to the process that opened it
        # (see _get_tar_obj) so forked DataLoader workers never share one.
        self.tar_obj = None
        self._tar_pid = None
        self.is_tar = bool(
            os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)
        )
        if self.is_tar:
            self._get_tar_obj()

    def _get_tar_obj(self) -> tarfile.TarFile:
        """Return a tar handle opened by the current process, reopening if needed.

        A ``TarFile`` inherited through ``fork`` (e.g. by DataLoader workers)
        shares its OS file offset with the parent and sibling processes, so
        concurrent seek/read calls interleave and can return corrupted bytes.
        The handle is therefore reopened whenever the current pid differs from
        the pid that opened it.
        """
        pid = os.getpid()
        if self.tar_obj is None or getattr(self, "_tar_pid", None) != pid:
            if self.tar_obj is not None:
                # Closing the inherited copy only releases this process's file
                # descriptor; the parent's handle is unaffected.
                self.tar_obj.close()
            self.tar_obj = tarfile.open(self.dataset_dir)
            self._tar_pid = pid
        return self.tar_obj

    def __getstate__(self):
        """Return picklable state without the tar handle.

        ``TarFile`` wraps an OS file object that cannot be pickled, which
        matters when the dataset is sent to DataLoader workers by pickling
        (e.g. the ``spawn`` start method). Each process reopens it lazily.
        """
        state = self.__dict__.copy()
        state["tar_obj"] = None
        state["_tar_pid"] = None
        return state

    def __len__(self):
        """Return the number of samples listed in the filename file."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load sample ``index`` as one dict of tensors plus metadata.

        In ``DatasetMode.TRAIN`` the rasters go through
        ``_training_preprocess`` (augmentation and resizing) first.
        """
        rasters, other = self._get_data_item(index)
        if DatasetMode.TRAIN == self.mode:
            rasters = self._training_preprocess(rasters)

        outputs = rasters
        outputs.update(other)
        return outputs

    def _get_data_item(self, index):
        """Return ``(rasters, other)`` for sample ``index``.

        ``rasters`` holds ``rgb_int`` and ``rgb_norm`` and, unless the mode is
        ``RGB_ONLY``, ``normals``. ``other`` holds the index and the RGB path.

        Raises:
            ValueError: If normals are required but the listing line has no
                normals path.
        """
        rgb_rel_path, normals_rel_path = self._get_data_path(index=index)

        rasters = {}

        # RGB data.
        rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))

        # Normals data (not loaded in RGB-only mode).
        if DatasetMode.RGB_ONLY != self.mode:
            if normals_rel_path is None:
                raise ValueError(
                    f"No normals path listed for sample {index} "
                    f"({rgb_rel_path}) in {self.filename_ls_path}"
                )
            normals_data = self._load_normals_data(normals_rel_path=normals_rel_path)
            rasters.update(normals_data)

        other = {"index": index, "rgb_relative_path": rgb_rel_path}
        return rasters, other

    def _load_rgb_data(self, rgb_rel_path):
        """Load an RGB image as integer values and as values scaled to [-1, 1].

        Returns:
            dict with ``"rgb_int"`` (int32 tensor, channel-first) and
            ``"rgb_norm"`` (float32 tensor, ``rgb / 255 * 2 - 1``).
        """
        rgb = self._read_rgb_file(rgb_rel_path)
        rgb_norm = rgb / 255.0 * 2.0 - 1.0

        outputs = {
            "rgb_int": torch.from_numpy(rgb).int(),
            "rgb_norm": torch.from_numpy(rgb_norm).float(),
        }
        return outputs

    def _load_normals_data(self, normals_rel_path):
        """Load a normals map as a float32 channel-first tensor under ``"normals"``."""
        outputs = {}
        normals = torch.from_numpy(self._read_normals_file(normals_rel_path)).float()
        outputs["normals"] = normals
        return outputs

    def _get_data_path(self, index):
        """Return ``(rgb_rel_path, normals_rel_path)`` for sample ``index``.

        ``normals_rel_path`` is ``None`` when the listing line has only one
        column, which is sufficient for ``RGB_ONLY`` mode.
        """
        filename_line = self.filenames[index]

        rgb_rel_path = filename_line[0]
        normals_rel_path = filename_line[1] if len(filename_line) > 1 else None
        return rgb_rel_path, normals_rel_path

    def _read_tar_member(self, rel_path) -> bytes:
        """Return the raw bytes of archive member ``"./" + rel_path``.

        Raises:
            KeyError: If the member does not exist (raised by ``tarfile``).
            ValueError: If the member is not a regular file.
        """
        member_name = "./" + rel_path
        member = self._get_tar_obj().extractfile(member_name)
        if member is None:
            raise ValueError(f"Tar member is not a regular file: {member_name}")
        with member:
            return member.read()

    def _read_image(self, img_rel_path) -> np.ndarray:
        """Read an image from the directory or tar archive as a numpy array."""
        if self.is_tar:
            image_to_read = io.BytesIO(self._read_tar_member(img_rel_path))
        else:
            image_to_read = os.path.join(self.dataset_dir, img_rel_path)
        # np.asarray loads the pixels into its own buffer, so the array stays
        # valid after the image (and its file) is closed.
        with Image.open(image_to_read) as image:
            return np.asarray(image)

    def _read_rgb_file(self, rel_path) -> np.ndarray:
        """Read an image and return it channel-first as an integer array.

        Assumes the decoded image is (H, W, C); a 2-D (grayscale) image will
        fail the transpose.
        """
        rgb = self._read_image(rel_path)
        rgb = np.transpose(rgb, (2, 0, 1)).astype(int)
        return rgb

    def _read_normals_file(self, rel_path):
        """Load a normals array with ``np.load`` and move the last axis first.

        Assumes the stored array is (H, W, 3) -- TODO confirm for each dataset.
        """
        if self.is_tar:
            normal = np.load(io.BytesIO(self._read_tar_member(rel_path)))
        else:
            normal = np.load(os.path.join(self.dataset_dir, rel_path))
        normal = np.transpose(normal, (2, 0, 1))
        return normal

    def _training_preprocess(self, rasters):
        """Apply optional augmentation, then optional bilinear resize, to all rasters."""
        # Augmentation.
        if self.augm_args is not None:
            rasters = self._augment_data(rasters)

        # Resize every raster (RGB and normals) to resize_to_hw.
        # NOTE(review): bilinear interpolation does not keep normals unit
        # length; confirm whether downstream code renormalizes them.
        if self.resize_to_hw is not None:
            resize_transform = Resize(
                size=self.resize_to_hw, interpolation=InterpolationMode.BILINEAR
            )
            rasters = {k: resize_transform(v) for k, v in rasters.items()}

        return rasters

    def _augment_data(self, rasters):
        """Randomly flip all rasters, then blur/jitter ``rgb_int``.

        Each augmentation is applied with its configured probability.
        Horizontal flipping also negates the first normals channel. Blur and
        color jitter modify only ``rgb_int``; ``rgb_norm`` is recomputed from
        it at the end. Requires a ``"normals"`` entry when flipping.
        """
        # Horizontal (left-right) flip.
        if random.random() < self.augm_args.lr_flip_p:
            rasters = {k: v.flip(-1) for k, v in rasters.items()}
            rasters["normals"][0, :, :] *= -1

        # Outside DataLoader workers (main process), run the image-space
        # augmentations on the GPU -- but only if CUDA is actually available.
        use_gpu = get_worker_info() is None and torch.cuda.is_available()
        if use_gpu:
            rasters = {k: v.cuda() for k, v in rasters.items()}

        # Gaussian blur.
        if (
            random.random() < self.augm_args.gaussian_blur_p
            and rasters["rgb_int"].shape[-2] == self._BLUR_IMAGE_HEIGHT
        ):
            random_rgb_sigma = random.uniform(0.0, self.augm_args.gaussian_blur_sigma)
            # gaussian_blur rejects sigma <= 0; a zero-sigma blur is a no-op.
            if random_rgb_sigma > 0:
                rasters["rgb_int"] = TF.gaussian_blur(
                    rasters["rgb_int"], kernel_size=33, sigma=random_rgb_sigma
                ).int()

        # Motion blur: a line kernel of random odd size rotated by a random angle.
        if (
            random.random() < self.augm_args.motion_blur_p
            and rasters["rgb_int"].shape[-2] == self._BLUR_IMAGE_HEIGHT
        ):
            rgb_int = rasters["rgb_int"]
            kernel_size = random.choice(
                range(3, self.augm_args.motion_blur_kernel_size + 1, 2)
            )
            kernel = torch.zeros(
                kernel_size, kernel_size, dtype=torch.float32, device=rgb_int.device
            )
            kernel[kernel_size // 2, :] = 1.0
            kernel = TF.rotate(
                kernel.unsqueeze(0),
                random.uniform(0.0, self.augm_args.motion_blur_angle_range),
            )
            # Normalize so the blur preserves overall brightness.
            kernel = kernel / kernel.sum()
            channels = rgb_int.shape[0]
            kernel = kernel.expand(channels, 1, kernel_size, kernel_size)
            blurred = torch.conv2d(
                rgb_int.unsqueeze(0).float(),
                kernel,
                stride=1,
                padding=kernel_size // 2,
                groups=channels,
            )
            # Round rather than truncate to avoid a systematic darkening bias.
            rasters["rgb_int"] = blurred.squeeze(0).round().int()

        # Color jitter.
        if random.random() < self.augm_args.color_jitter_p:
            color_jitter = ColorJitter(
                brightness=self.augm_args.jitter_brightness_factor,
                contrast=self.augm_args.jitter_contrast_factor,
                saturation=self.augm_args.jitter_saturation_factor,
                hue=self.augm_args.jitter_hue_factor,
            )
            jittered = color_jitter(rasters["rgb_int"].float() / 255.0)
            rasters["rgb_int"] = (jittered * 255.0).round().int()

        # Keep the normalized RGB consistent with the augmented integer RGB.
        rasters["rgb_norm"] = rasters["rgb_int"].float() / 255.0 * 2.0 - 1.0
        return rasters

    def __del__(self):
        """Close this object's tar handle, if one is open."""
        tar_obj = getattr(self, "tar_obj", None)
        if tar_obj is not None:
            tar_obj.close()
            self.tar_obj = None
|
|