| import json
| import os
| import random
| from concurrent.futures import ThreadPoolExecutor, as_completed
| from dataclasses import dataclass
| from functools import cached_property
| from io import BytesIO
| from pathlib import Path
| from typing import Literal
|
| import cv2
| import numpy as np
| import torch
| import torchvision.transforms as tf
| from einops import rearrange, repeat
| from jaxtyping import Float, UInt8
| from PIL import Image
| from torch import Tensor
| from torch.utils.data import Dataset
|
| from ..geometry.projection import get_fov
| from ..misc.cam_utils import camera_normalization
| from .dataset import DatasetCfgCommon
| from .shims.augmentation_shim import apply_augmentation_shim
| from .shims.crop_shim import apply_crop_shim
| from .shims.geometry_shim import depthmap_to_absolute_camera_coordinates
| from .shims.load_shim import imread_cv2
| from .types import Stage
| from .view_sampler import ViewSampler
|
|
| @dataclass |
| class DatasetScannetppCfg(DatasetCfgCommon): |
| name: str |
| roots: list[Path] |
| baseline_min: float |
| baseline_max: float |
| max_fov: float |
| make_baseline_1: bool |
| augment: bool |
| relative_pose: bool |
| skip_bad_shape: bool |
| metric_thre: float |
| intr_augment: bool |
| rescale_to_1cube: bool |
| normalize_by_pts3d: bool |
|
|
| @dataclass |
| class DatasetScannetppCfgWrapper: |
| scannetpp: DatasetScannetppCfg |
|
|
| class DatasetScannetpp(Dataset): |
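|     """ScanNet++ dataset yielding posed RGB-D context/target views.
|
|     Scene lists come from `valid.json` under the first root; each scene
|     directory provides `scene_metadata.npz` plus `images/` and `depth/`.
|     """
|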
|     cfg: DatasetScannetppCfg
| stage: Stage |
| view_sampler: ViewSampler |
|
|
| to_tensor: tf.ToTensor |
| chunks: list[Path] |
| near: float = 0.1 |
| far: float = 100.0 |
| |
| def __init__( |
| self, |
|         cfg: DatasetScannetppCfg,
| stage: Stage, |
| view_sampler: ViewSampler, |
| ) -> None: |
| super().__init__() |
| self.cfg = cfg |
| self.stage = stage |
| self.view_sampler = view_sampler |
| self.to_tensor = tf.ToTensor() |
|
|         self.data_root = cfg.roots[0]
|
|         # Scene ids come from valid.json under the dataset root: the first 10
|         # scenes are held out for validation/testing, the rest are for training.
|         if self.stage != "train":
|             with open(f"{self.data_root}/valid.json", "r") as file:
|                 data_index = json.load(file)[:10]
|             # Repeat the small held-out set so an epoch yields enough samples.
|             data_index = data_index * 100
|             random.shuffle(data_index)
|         else:
|             with open(f"{self.data_root}/valid.json", "r") as file:
|                 data_index = json.load(file)[10:]
|
|         self.data_list = [os.path.join(self.data_root, item) for item in data_index]
| |
|         # Load per-scene metadata in parallel. Note: `as_completed` yields in
|         # completion order, so the index -> scene_id mapping can vary between runs.
|         self.scene_ids = {}
|         self.scenes = {}
|         index = 0
|         with ThreadPoolExecutor(max_workers=32) as executor:
|             futures = [
|                 executor.submit(self.load_metadata, scene_path)
|                 for scene_path in self.data_list
|             ]
|             for future in as_completed(futures):
|                 scene_frames, scene_id = future.result()
|                 self.scenes[scene_id] = scene_frames
|                 self.scene_ids[index] = scene_id
|                 index += 1
|
|         print(f"Scannetpp: {self.stage}: loaded {len(self.scene_ids)} scenes")
| |
| def shuffle(self, lst: list) -> list: |
| indices = torch.randperm(len(lst)) |
| return [lst[x] for x in indices] |
|
|
| def load_metadata(self, scene_path): |
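|         """Read one scene's `scene_metadata.npz` and return (frames, scene_id),
|         where each frame carries image/depth paths plus its camera parameters.
|         """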
| metadata_path = os.path.join(scene_path, "scene_metadata.npz") |
| metadata = np.load(metadata_path, allow_pickle=True) |
| intrinsics = metadata["intrinsics"] |
| trajectories = metadata["trajectories"] |
| images = metadata["images"] |
|
|
| scene_id = scene_path.split("/")[-1].split(".")[0] |
| scene_frames = [ |
| { |
| "file_path": os.path.join(scene_path, "images", images[i].split(".")[0] + ".jpg"), |
| "depth_path": os.path.join(scene_path, "depth", images[i].split(".")[0] + ".png"), |
| "intrinsics": self.convert_intrinsics(intrinsics[i]), |
| "extrinsics": trajectories[i], |
| } |
| for i in range(len(images)) |
| ] |
| scene_frames.sort(key=lambda x: x["file_path"]) |
| return scene_frames, scene_id |
|
|
| def convert_intrinsics(self, intrinsics): |
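|         """Normalize pixel intrinsics to [0, 1], assuming the principal point is
|         at the image center (width = 2 * cx, height = 2 * cy). Mutates in place.
|         """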
| w = intrinsics[0, 2] * 2 |
| h = intrinsics[1, 2] * 2 |
| intrinsics[0, 0] = intrinsics[0, 0] / w |
| intrinsics[1, 1] = intrinsics[1, 1] / h |
| intrinsics[0, 2] = intrinsics[0, 2] / w |
| intrinsics[1, 2] = intrinsics[1, 2] / h |
| return intrinsics |
| |
| def blender2opencv_c2w(self, pose): |
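|         """Convert a Blender (OpenGL) camera-to-world pose to OpenCV convention
|         by flipping the y and z camera axes.
|         """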
| blender2opencv = np.array( |
| [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] |
| ) |
| opencv_c2w = np.array(pose) @ blender2opencv |
| return opencv_c2w.tolist() |
|
|
| def load_frames(self, frames): |
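|         """Load RGB images for `frames` as [0, 1] float tensors. The
|         single-worker executor makes loading effectively sequential; a stacked
|         tensor is returned only when all images share one shape.
|         """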
| with ThreadPoolExecutor(max_workers=1) as executor: |
| |
| futures_with_idx = [] |
| for idx, file_path in enumerate(frames): |
| file_path = file_path["file_path"] |
| futures_with_idx.append( |
| ( |
| idx, |
| executor.submit( |
| lambda p: self.to_tensor(Image.open(p).convert("RGB")), |
| file_path, |
| ), |
| ) |
| ) |
| |
| |
| torch_images = [None] * len(frames) |
| for idx, future in futures_with_idx: |
| torch_images[idx] = future.result() |
| |
| sizes = set(img.shape for img in torch_images) |
| if len(sizes) == 1: |
| torch_images = torch.stack(torch_images) |
| |
| return torch_images |
| |
| def load_depths(self, frames): |
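|         """Load 16-bit depth PNGs, convert millimeters to meters, and zero out
|         non-finite values.
|         """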
| torch_depths = [] |
| for idx, frame in enumerate(frames): |
| depthmap = imread_cv2(frame["depth_path"], cv2.IMREAD_UNCHANGED) |
| depthmap = depthmap.astype(np.float32) / 1000 |
| depthmap[~np.isfinite(depthmap)] = 0 |
| torch_depths.append(torch.from_numpy(depthmap)) |
| return torch.stack(torch_depths) |
|
|
| def getitem(self, index: int, num_context_views: int, patchsize: tuple) -> dict: |
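|         """Build one example: sample context/target views for scene `index`,
|         load images and depths, normalize cameras, and unproject depth to 3D.
|         `patchsize` is (patches_h, patches_w); crops are multiples of 14 px.
|         """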
|         scene = self.scene_ids[index]
|         scene_frames = self.scenes[scene]
|
|         # Collect per-frame camera parameters as float tensors.
|         extrinsics = []
|         intrinsics = []
|         for frame in scene_frames:
|             extrinsics.append(frame["extrinsics"])
|             intrinsics.append(frame["intrinsics"])
|
|         extrinsics = torch.tensor(np.array(extrinsics), dtype=torch.float32)
|         intrinsics = torch.tensor(np.array(intrinsics), dtype=torch.float32)
|
|
| try: |
| context_indices, target_indices, overlap = self.view_sampler.sample( |
| "scannetpp_"+scene, |
| num_context_views, |
| extrinsics, |
| intrinsics, |
| ) |
|         except ValueError:
|             # The view sampler could not find a valid context/target split.
|             raise Exception("Not enough frames")
|
|         # Reject examples whose field of view is too wide.
|         if (get_fov(intrinsics).rad2deg() > self.cfg.max_fov).any():
|             raise Exception("Field of view too wide")
|
|
|
|         # Load the images and depth maps for the sampled views.
|         input_frames = [scene_frames[i] for i in context_indices]
|         target_frames = [scene_frames[i] for i in target_indices]
|
|         context_images = self.load_frames(input_frames)
|         target_images = self.load_frames(target_frames)
|
|         context_depths = self.load_depths(input_frames)
|         target_depths = self.load_depths(target_frames)
| |
| |
|         # Skip the example if the images don't have the expected shape.
|         context_image_invalid = context_images.shape[1:] != (3, *self.cfg.original_image_shape)
|         target_image_invalid = target_images.shape[1:] != (3, *self.cfg.original_image_shape)
|         if self.cfg.skip_bad_shape and (context_image_invalid or target_image_invalid):
|             print(
|                 f"Skipped bad example {scene}. Context shape was "
|                 f"{context_images.shape} and target shape was "
|                 f"{target_images.shape}."
|             )
|             raise Exception("Bad example image shape")
| |
| context_extrinsics = extrinsics[context_indices] |
|
|
|         # Normalize so the first-to-last context baseline has unit length.
|         if self.cfg.make_baseline_1:
|             a, b = context_extrinsics[0, :3, 3], context_extrinsics[-1, :3, 3]
|             scale = (a - b).norm()
|             if scale < self.cfg.baseline_min or scale > self.cfg.baseline_max:
|                 print(
|                     f"Skipped {scene} because the baseline {scale:.6f} is out of range."
|                 )
|                 raise Exception("baseline out of range")
|             extrinsics[:, :3, 3] /= scale
|         else:
|             scale = 1
|
|
|         if self.cfg.relative_pose:
|             # Express all poses relative to the first context camera.
|             extrinsics = camera_normalization(extrinsics[context_indices][0:1], extrinsics)
|
|
|         if self.cfg.rescale_to_1cube:
|             # Rescale so the context camera centers fit within a unit cube.
|             scene_scale = torch.max(torch.abs(extrinsics[context_indices][:, :3, 3]))
|             extrinsics[:, :3, 3] /= scene_scale
|
|
|         if torch.isnan(extrinsics).any() or torch.isinf(extrinsics).any():
|             raise Exception("encountered nan or inf in input poses")
|
|
| example = { |
| "context": { |
| "extrinsics": extrinsics[context_indices], |
| "intrinsics": intrinsics[context_indices], |
| "image": context_images, |
| "depth": context_depths, |
| "near": self.get_bound("near", len(context_indices)) / scale, |
| "far": self.get_bound("far", len(context_indices)) / scale, |
| "index": context_indices, |
| "overlap": overlap, |
| }, |
| "target": { |
| "extrinsics": extrinsics[target_indices], |
| "intrinsics": intrinsics[target_indices], |
| "image": target_images, |
| "depth": target_depths, |
| "near": self.get_bound("near", len(target_indices)) / scale, |
| "far": self.get_bound("far", len(target_indices)) / scale, |
| "index": target_indices, |
| }, |
| "scene": f"Scannetpp {scene}", |
| } |
| if self.stage == "train" and self.cfg.augment: |
| example = apply_augmentation_shim(example) |
|
|
| if self.stage == "train" and self.cfg.intr_augment: |
| intr_aug = True |
| else: |
| intr_aug = False |
|
|
| example = apply_crop_shim(example, (patchsize[0] * 14, patchsize[1] * 14), intr_aug=intr_aug) |
| |
| |
| image_size = example["context"]["image"].shape[2:] |
| context_intrinsics = example["context"]["intrinsics"].clone().detach().numpy() |
| context_intrinsics[:, 0] = context_intrinsics[:, 0] * image_size[1] |
| context_intrinsics[:, 1] = context_intrinsics[:, 1] * image_size[0] |
|
|
| target_intrinsics = example["target"]["intrinsics"].clone().detach().numpy() |
| target_intrinsics[:, 0] = target_intrinsics[:, 0] * image_size[1] |
| target_intrinsics[:, 1] = target_intrinsics[:, 1] * image_size[0] |
|
|
|         # Unproject the depth maps to world-space 3D points.
|         context_pts3d_list, context_valid_mask_list = [], []
|         target_pts3d_list, target_valid_mask_list = [], []
|
|         for i in range(len(example["context"]["depth"])):
|             context_pts3d, context_valid_mask = depthmap_to_absolute_camera_coordinates(
|                 example["context"]["depth"][i].numpy(),
|                 context_intrinsics[i],
|                 example["context"]["extrinsics"][i].numpy(),
|             )
|             context_pts3d_list.append(torch.from_numpy(context_pts3d).to(torch.float32))
|             context_valid_mask_list.append(torch.from_numpy(context_valid_mask))
|
|         context_pts3d = torch.stack(context_pts3d_list, dim=0)
|         context_valid_mask = torch.stack(context_valid_mask_list, dim=0)
|
|         for i in range(len(example["target"]["depth"])):
|             target_pts3d, target_valid_mask = depthmap_to_absolute_camera_coordinates(
|                 example["target"]["depth"][i].numpy(),
|                 target_intrinsics[i],
|                 example["target"]["extrinsics"][i].numpy(),
|             )
|             target_pts3d_list.append(torch.from_numpy(target_pts3d).to(torch.float32))
|             target_valid_mask_list.append(torch.from_numpy(target_valid_mask))
|
|         target_pts3d = torch.stack(target_pts3d_list, dim=0)
|         target_valid_mask = torch.stack(target_valid_mask_list, dim=0)
| |
| |
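|         # Optionally normalize the scene by the mean distance of valid context points.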
| if self.cfg.normalize_by_pts3d: |
| transformed_pts3d = context_pts3d[context_valid_mask] |
| scene_factor = transformed_pts3d.norm(dim=-1).mean().clip(min=1e-8) |
| context_pts3d /= scene_factor |
| example["context"]["depth"] /= scene_factor |
| example["context"]["extrinsics"][:, :3, 3] /= scene_factor |
| |
| target_pts3d /= scene_factor |
| example["target"]["depth"] /= scene_factor |
| example["target"]["extrinsics"][:, :3, 3] /= scene_factor |
|
|
| example["context"]["pts3d"] = context_pts3d |
| example["target"]["pts3d"] = target_pts3d |
| example["context"]["valid_mask"] = context_valid_mask |
| example["target"]["valid_mask"] = target_valid_mask |
|
|
| if torch.isnan(example["context"]["depth"]).any() or torch.isinf(example["context"]["depth"]).any() or \ |
| torch.isnan(example["context"]["extrinsics"]).any() or torch.isinf(example["context"]["extrinsics"]).any() or \ |
| torch.isnan(example["context"]["intrinsics"]).any() or torch.isinf(example["context"]["intrinsics"]).any() or \ |
| torch.isnan(example["target"]["depth"]).any() or torch.isinf(example["target"]["depth"]).any() or \ |
| torch.isnan(example["target"]["extrinsics"]).any() or torch.isinf(example["target"]["extrinsics"]).any() or \ |
| torch.isnan(example["target"]["intrinsics"]).any() or torch.isinf(example["target"]["intrinsics"]).any(): |
| raise Exception("encounter nan or inf in context depth") |
| |
| for key in ["context", "target"]: |
| example[key]["valid_mask"] = (torch.ones_like(example[key]["valid_mask"]) * -1).type(torch.int32) |
| |
| return example |
| |
| |
| def __getitem__(self, index_tuple: tuple) -> dict: |
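|         """Fetch one example from an (index, num_context_views, patchsize_h)
|         tuple; on failure, retry with a random scene index.
|         """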
| index, num_context_views, patchsize_h = index_tuple |
| |
| patchsize_w = (self.cfg.input_image_shape[1] // 14) |
| try: |
| return self.getitem(index, num_context_views, (patchsize_h, patchsize_w)) |
| except Exception as e: |
| print(f"Error: {e}") |
| index = np.random.randint(len(self)) |
| return self.__getitem__((index, num_context_views, patchsize_h)) |
|
|
| def convert_poses( |
| self, |
| poses: Float[Tensor, "batch 18"], |
| ) -> tuple[ |
| Float[Tensor, "batch 4 4"], |
| Float[Tensor, "batch 3 3"], |
| ]: |
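|         """Unpack flattened 18-entry camera vectors into 4x4 c2w extrinsics and
|         normalized 3x3 intrinsics. Entries 0-3 are (fx, fy, cx, cy); entries
|         6-17 are the top three rows of the world-to-camera matrix.
|         """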
| b, _ = poses.shape |
|
|
| |
| intrinsics = torch.eye(3, dtype=torch.float32) |
| intrinsics = repeat(intrinsics, "h w -> b h w", b=b).clone() |
| fx, fy, cx, cy = poses[:, :4].T |
| intrinsics[:, 0, 0] = fx |
| intrinsics[:, 1, 1] = fy |
| intrinsics[:, 0, 2] = cx |
| intrinsics[:, 1, 2] = cy |
| |
| |
| w2c = repeat(torch.eye(4, dtype=torch.float32), "h w -> b h w", b=b).clone() |
| w2c[:, :3] = rearrange(poses[:, 6:], "b (h w) -> b h w", h=3, w=4) |
| return w2c.inverse(), intrinsics |
|
|
| def convert_images( |
| self, |
| images: list[UInt8[Tensor, "..."]], |
| ) -> Float[Tensor, "batch 3 height width"]: |
| torch_images = [] |
| for image in images: |
| image = Image.open(BytesIO(image.numpy().tobytes())) |
| torch_images.append(self.to_tensor(image)) |
| return torch.stack(torch_images) |
|
|
| def get_bound( |
| self, |
| bound: Literal["near", "far"], |
| num_views: int, |
| ) -> Float[Tensor, " view"]: |
| value = torch.tensor(getattr(self, bound), dtype=torch.float32) |
| return repeat(value, "-> v", v=num_views) |
|
|
| @property |
| def data_stage(self) -> Stage: |
| if self.cfg.overfit_to_scene is not None: |
| return "test" |
| if self.stage == "val": |
| return "test" |
| return self.stage |
|
|
| @cached_property |
| def index(self) -> dict[str, Path]: |
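|         """Merge the per-root `index.json` files for the active data stage(s)."""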
| merged_index = {} |
| data_stages = [self.data_stage] |
| if self.cfg.overfit_to_scene is not None: |
| data_stages = ("test", "train") |
|         for data_stage in data_stages:
|             for root in self.cfg.roots:
|                 # Load the root's index.
|                 with (root / data_stage / "index.json").open("r") as f:
|                     index = json.load(f)
|                 index = {k: Path(root / data_stage / v) for k, v in index.items()}
|
|                 # The constituent datasets should have unique keys.
|                 assert not (set(merged_index.keys()) & set(index.keys()))
|
|                 # Merge the root's index into the main index.
|                 merged_index = {**merged_index, **index}
|         return merged_index
| |
| def __len__(self) -> int: |
| return len(self.data_list) |
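|
|
| # Usage sketch (hypothetical; in practice construction is driven by the config
| # pipeline, and `view_sampler` is built from the view-sampler config):
| #
| #     cfg = DatasetScannetppCfg(..., roots=[Path("datasets/scannetpp")], ...)
| #     dataset = DatasetScannetpp(cfg, "train", view_sampler)
| #     example = dataset[(0, 2, 16)]  # scene 0, 2 context views, patch height 16 (224 px)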