| | |
| | |
| |
|
| | import io |
| | import os |
| | import random |
| | import tarfile |
| | from enum import Enum |
| |
|
| | import numpy as np |
| | import cv2 |
| | import torch |
| | from PIL import Image |
| | from torch.utils.data import Dataset |
| | from torchvision.transforms import InterpolationMode, Resize |
| |
|
| |
|
| | class DatasetMode(Enum): |
| | RGB_ONLY = "rgb_only" |
| | EVAL = "evaluate" |
| | TRAIN = "train" |
| |
|
| |
|
| | def read_image_from_tar(tar_obj, img_rel_path): |
| | image = tar_obj.extractfile("./" + img_rel_path) |
| | image = image.read() |
| | image = Image.open(io.BytesIO(image)) |
| |
|
| |
|
| | class BaseDepthDataset(Dataset): |
| | def __init__( |
| | self, |
| | mode: DatasetMode, |
| | filename_ls_path: str, |
| | dataset_dir: str, |
| | disp_name: str, |
| | min_depth, |
| | max_depth, |
| | has_filled_depth, |
| | name_mode, |
| | depth_transform=None, |
| | augmentation_args: dict = None, |
| | resize_to_hw=None, |
| | move_invalid_to_far_plane: bool = True, |
| | rgb_transform=lambda x: x / 255.0 * 2 - 1, |
| | **kwargs, |
| | ) -> None: |
| | super().__init__() |
| | self.mode = mode |
| | |
| | self.filename_ls_path = filename_ls_path |
| | self.dataset_dir = dataset_dir |
| | self.disp_name = disp_name |
| | self.has_filled_depth = has_filled_depth |
| | self.name_mode: DepthFileNameMode = name_mode |
| | self.min_depth = min_depth |
| | self.max_depth = max_depth |
| |
|
| | |
| | self.depth_transform = depth_transform |
| | self.augm_args = augmentation_args |
| | self.resize_to_hw = resize_to_hw |
| | self.rgb_transform = rgb_transform |
| | self.move_invalid_to_far_plane = move_invalid_to_far_plane |
| |
|
| | |
| | with open(self.filename_ls_path, "r") as f: |
| | self.filenames = [ |
| | s.split() for s in f.readlines() |
| | ] |
| |
|
| | |
| | self.tar_obj = None |
| | self.is_tar = ( |
| | True |
| | if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir) |
| | else False |
| | ) |
| |
|
| | def __len__(self): |
| | return len(self.filenames) |
| |
|
| | def __getitem__(self, index): |
| | rasters, other = self._get_data_item(index) |
| | if DatasetMode.TRAIN == self.mode: |
| | rasters = self._training_preprocess(rasters) |
| | |
| | outputs = rasters |
| | outputs.update(other) |
| | return outputs |
| |
|
| | def _get_data_item(self, index): |
| | rgb_rel_path, depth_rel_path, filled_rel_path = self._get_data_path(index=index) |
| |
|
| | rasters = {} |
| |
|
| | |
| | rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path)) |
| |
|
| | |
| | if DatasetMode.RGB_ONLY != self.mode: |
| | |
| | depth_data = self._load_depth_data( |
| | depth_rel_path=depth_rel_path, filled_rel_path=filled_rel_path |
| | ) |
| | rasters.update(depth_data) |
| | |
| | rasters["valid_mask_raw"] = self._get_valid_mask( |
| | rasters["depth_raw_linear"] |
| | ).clone() |
| | rasters["valid_mask_filled"] = self._get_valid_mask( |
| | rasters["depth_filled_linear"] |
| | ).clone() |
| |
|
| | other = {"index": index, "rgb_relative_path": rgb_rel_path} |
| |
|
| | return rasters, other |
| |
|
| | def _load_rgb_data(self, rgb_rel_path): |
| | |
| | rgb = self._read_rgb_file(rgb_rel_path) |
| |
|
| | outputs = { |
| | "rgb_int": torch.from_numpy(rgb).int(), |
| | } |
| | return outputs |
| |
|
| | def _load_depth_data(self, depth_rel_path, filled_rel_path): |
| | |
| | outputs = {} |
| | depth_raw = self._read_depth_file(depth_rel_path).squeeze() |
| | depth_raw_linear = torch.from_numpy(depth_raw).float().unsqueeze(0) |
| | outputs["depth_raw_linear"] = depth_raw_linear.clone() |
| |
|
| | if self.has_filled_depth: |
| | depth_filled = self._read_depth_file(filled_rel_path).squeeze() |
| | depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0) |
| | outputs["depth_filled_linear"] = depth_filled_linear |
| | else: |
| | outputs["depth_filled_linear"] = depth_raw_linear.clone() |
| |
|
| | return outputs |
| |
|
| | def _get_data_path(self, index): |
| | filename_line = self.filenames[index] |
| |
|
| | |
| | rgb_rel_path = filename_line[0] |
| |
|
| | depth_rel_path, filled_rel_path = None, None |
| | if DatasetMode.RGB_ONLY != self.mode: |
| | depth_rel_path = filename_line[1] |
| | if self.has_filled_depth: |
| | filled_rel_path = filename_line[2] |
| | return rgb_rel_path, depth_rel_path, filled_rel_path |
| |
|
| | def _read_image(self, img_rel_path) -> np.ndarray: |
| | if self.is_tar: |
| | if self.tar_obj is None: |
| | self.tar_obj = tarfile.open(self.dataset_dir) |
| | image = self.tar_obj.extractfile("./" + img_rel_path) |
| | image = image.read() |
| | image = Image.open(io.BytesIO(image)) |
| | else: |
| | img_path = os.path.join(self.dataset_dir, img_rel_path) |
| | image = Image.open(img_path).convert('RGB') |
| | image = np.asarray(image) |
| | return image |
| |
|
| | def _read_depth_cv2(self, img_rel_path) -> np.ndarray: |
| | depth_path = os.path.join(self.dataset_dir, img_rel_path) |
| | depth_in = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) |
| | if depth_in.shape[2] == 3: |
| | depth_in = depth_in[..., 0] |
| | depth_in = depth_in.astype(np.float32) |
| | return depth_in |
| |
|
| | def _read_rgb_file(self, rel_path) -> np.ndarray: |
| | rgb = self._read_image(rel_path) |
| | |
| | if rgb.shape[2] == 4: |
| | rgb = rgb[:, :, :3] |
| | rgb = np.transpose(rgb, (2, 0, 1)).astype(int) |
| | return rgb |
| |
|
| | def _read_depth_file(self, rel_path): |
| | depth_in = self._read_image(rel_path) |
| | |
| | depth_decoded = depth_in |
| |
|
| | return depth_decoded |
| |
|
| | def _get_valid_mask(self, depth: torch.Tensor): |
| | valid_mask = torch.logical_and( |
| | (depth > self.min_depth), (depth < self.max_depth) |
| | ).bool() |
| | return valid_mask |
| |
|
| | def _training_preprocess(self, rasters): |
| | |
| | if self.augm_args is not None: |
| | rasters = self._augment_data(rasters) |
| |
|
| | |
| | rasters["depth_raw_norm"] = self.depth_transform( |
| | rasters["depth_raw_linear"], rasters["valid_mask_raw"] |
| | ).clone() |
| | rasters["depth_filled_norm"] = self.depth_transform( |
| | rasters["depth_filled_linear"], rasters["valid_mask_filled"] |
| | ).clone() |
| |
|
| | |
| | if self.move_invalid_to_far_plane: |
| | if self.depth_transform.far_plane_at_max: |
| | rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( |
| | self.depth_transform.norm_max |
| | ) |
| | else: |
| | rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( |
| | self.depth_transform.norm_min |
| | ) |
| |
|
| | |
| | if self.resize_to_hw is not None: |
| | resize_transform = Resize( |
| | size=self.resize_to_hw, interpolation=InterpolationMode.NEAREST_EXACT |
| | ) |
| | rasters = {k: resize_transform(v) for k, v in rasters.items()} |
| |
|
| | return rasters |
| |
|
| | def _augment_data(self, rasters_dict): |
| | |
| | lr_flip_p = self.augm_args.lr_flip_p |
| | if random.random() < lr_flip_p: |
| | rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()} |
| |
|
| | return rasters_dict |
| |
|
| | def __del__(self): |
| | if self.tar_obj is not None: |
| | self.tar_obj.close() |
| | self.tar_obj = None |
| |
|
| |
|
| | |
| | class DepthFileNameMode(Enum): |
| | id = 1 |
| | rgb_id = 2 |
| | i_d_rgb = 3 |
| | rgb_i_d = 4 |
| |
|
| |
|
| | def get_pred_name(rgb_basename, name_mode, suffix=".png"): |
| | if DepthFileNameMode.rgb_id == name_mode: |
| | pred_basename = "pred_" + rgb_basename.split("_")[1] |
| | elif DepthFileNameMode.i_d_rgb == name_mode: |
| | pred_basename = rgb_basename.replace("_rgb.", "_pred.") |
| | elif DepthFileNameMode.id == name_mode: |
| | pred_basename = "pred_" + rgb_basename |
| | elif DepthFileNameMode.rgb_i_d == name_mode: |
| | pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:]) |
| | else: |
| | raise NotImplementedError |
| | |
| | pred_basename = os.path.splitext(pred_basename)[0] + suffix |
| |
|
| | return pred_basename |
| |
|