| | from concurrent.futures import ThreadPoolExecutor, as_completed |
| | import json |
| | from dataclasses import dataclass |
| | from functools import cached_property |
| | from pathlib import Path |
| | import random |
| | from typing import Literal |
| | import os |
| | import numpy as np |
| | import torch |
| | import torchvision.transforms as tf |
| | from einops import rearrange, repeat |
| | from jaxtyping import Float, UInt8 |
| | from PIL import Image |
| | from torch import Tensor |
| | from torch.utils.data import Dataset |
| | import os.path as osp |
| | import cv2 |
| | from ..geometry.projection import get_fov |
| | from .dataset import DatasetCfgCommon |
| | from .shims.augmentation_shim import apply_augmentation_shim |
| | from .shims.crop_shim import apply_crop_shim |
| | from .types import Stage |
| | from .view_sampler import ViewSampler |
| | from ..misc.cam_utils import camera_normalization |
| |
|
| | from .shims.geometry_shim import depthmap_to_absolute_camera_coordinates |
| |
|
# CO3Dv2 category folders (joined onto the dataset root) loaded per data stage.
# NOTE(review): 'teddybear' appears in both splits, and the CO3Dv2 'apple'
# category is not listed -- confirm both are intentional.
CATEGORY = {
    "train": """
        backpack ball banana baseballbat baseballglove
        bench bicycle book bottle bowl broccoli cake car carrot
        cellphone chair couch cup donut frisbee hairdryer handbag
        hotdog hydrant keyboard kite laptop microwave
        motorcycle
        mouse orange parkingmeter pizza plant remote sandwich
        skateboard stopsign
        suitcase teddybear toaster toilet toybus
        toyplane toytrain toytruck tv
        umbrella vase wineglass
    """.split(),
    "test": ["teddybear"],
}
| |
|
@dataclass
class DatasetCo3dCfg(DatasetCfgCommon):
    """Configuration consumed by ``DatasetCo3d`` (CO3Dv2).

    Fields inherited from ``DatasetCfgCommon`` are declared elsewhere; the
    comments below describe how ``DatasetCo3d`` reads each field declared here.
    Field order matters: it defines the generated ``__init__`` signature.
    """

    name: str
    # Dataset roots. ``DatasetCo3d.__init__`` loads scenes from ``roots[0]``
    # only; the ``index`` property iterates over every root.
    roots: list[Path]
    # Accepted range (inclusive) for the distance between the translation
    # vectors of the first and last context cameras; only checked when
    # ``make_baseline_1`` is set, and out-of-range samples are rejected.
    baseline_min: float
    baseline_max: float
    # Field-of-view limit in degrees; a scene is rejected if any of its
    # frames exceeds it.
    max_fov: float
    # Divide all camera translations by the context baseline above.
    make_baseline_1: bool
    # Apply ``apply_augmentation_shim`` to examples during training.
    augment: bool
    # Re-express all poses via ``camera_normalization`` using the first
    # context pose as reference.
    relative_pose: bool
    # NOTE(review): not read by ``DatasetCo3d`` in this file.
    skip_bad_shape: bool
    # Divide points, depths and camera translations by the mean norm of the
    # valid context 3D points.
    normalize_by_pts3d: bool
    # Pass ``intr_aug=True`` to ``apply_crop_shim`` during training.
    intr_augment: bool
    # Divide camera translations by the largest absolute translation
    # component among the context cameras.
    rescale_to_1cube: bool
    # Multiply depth maps by foreground masks: always (True), never (False),
    # or for a random half of the samples ('rand').
    mask_bg: Literal['rand', True, False] = True
| | |
@dataclass
class DatasetCo3dCfgWrapper:
    """Holds a ``DatasetCo3dCfg`` in a field named ``co3d``.

    Presumably mirrors a ``co3d`` section of a larger config -- confirm
    against the config loader.
    """

    co3d: DatasetCo3dCfg
| |
|
| |
|
class DatasetCo3d(Dataset):
    """CO3Dv2 dataset producing context/target multi-view examples.

    Scenes are the ``<category>/<scene>`` entries of each category's
    ``valid_seq.json`` under ``cfg.roots[0]``; the categories come from
    ``CATEGORY[self.data_stage]``. Items are requested with an
    ``(index, num_context_views, patchsize_h)`` tuple (see ``__getitem__``).
    """

    cfg: DatasetCo3dCfg
    stage: Stage
    view_sampler: ViewSampler

    to_tensor: tf.ToTensor
    chunks: list[Path]
    near: float = 0.1
    far: float = 100.0
    # Consecutive failed samples tolerated by ``__getitem__`` before it raises.
    max_retries: int = 1000

    def __init__(
        self,
        cfg: DatasetCo3dCfg,
        stage: Stage,
        view_sampler: ViewSampler,
    ) -> None:
        """Index every scene of the categories selected for this stage.

        Raises:
            ValueError: if ``cfg.mask_bg`` is not ``'rand'``, ``True`` or ``False``.
        """
        super().__init__()
        self.cfg = cfg
        self.stage = stage
        self.view_sampler = view_sampler
        self.to_tensor = tf.ToTensor()

        self.root = cfg.roots[0]
        self.mask_bg = cfg.mask_bg
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if self.mask_bg not in ("rand", True, False):
            raise ValueError(
                f"mask_bg must be 'rand', True or False, got {self.mask_bg!r}"
            )

        self.categories = CATEGORY[self.data_stage]
        self.scene_seq_dict = {}
        self.scene_ids = []
        for category in self.categories:
            with open(osp.join(self.root, f"{category}/valid_seq.json"), "r") as f:
                scene_seq_dict = json.load(f)
            for scene, seqs in scene_seq_dict.items():
                scene_id = f"{category}/{scene}"
                self.scene_seq_dict[scene_id] = seqs
                self.scene_ids.append(scene_id)

        print(f"CO3Dv2 {self.stage}: loaded {len(self.scene_seq_dict)} scenes")

    def load_frames(self, scene_id, frame_ids):
        """Load ``<root>/<scene_id>/images/frame{id:06d}.jpg`` as RGB tensors.

        Images are decoded in a thread pool and converted with
        ``self.to_tensor``. Returns one stacked tensor when all frames share a
        shape, otherwise a list of per-frame tensors; either way the order
        matches ``frame_ids``.
        """
        def _read_rgb(path):
            # ``with`` releases the file handle as soon as decoding is done.
            with Image.open(path) as img:
                return self.to_tensor(img.convert("RGB"))

        paths = [
            os.path.join(self.root, f"{scene_id}/images/frame{frame_id:06d}.jpg")
            for frame_id in frame_ids
        ]
        # ``Executor.map`` yields results in input order and re-raises
        # worker exceptions on iteration.
        with ThreadPoolExecutor(max_workers=32) as executor:
            torch_images = list(executor.map(_read_rgb, paths))

        if len({img.shape for img in torch_images}) == 1:
            torch_images = torch.stack(torch_images)
        return torch_images

    def load_npz(self, scene_id, frame_id):
        """Read the camera parameters of one frame.

        Returns ``(intrinsics, camera_pose, maximum_depth)`` as read from
        ``frame{id:06d}.npz``. The intrinsics have fx and cx divided by the
        image width, and fy and cy divided by the image height; the size is
        taken from the matching ``.jpg``.
        """
        stem = os.path.join(self.root, f"{scene_id}/images/frame{frame_id:06d}")
        # Only the header is needed for the size; ``with`` closes the file.
        with Image.open(f"{stem}.jpg") as img:
            W, H = img.size
        # ``NpzFile`` keeps the archive open until closed. Arrays read from it
        # are independent copies, so they stay valid after the ``with`` exits.
        with np.load(f"{stem}.npz") as npzdata:
            intri = npzdata['camera_intrinsics']
            extri = npzdata['camera_pose']
            md = npzdata['maximum_depth']
        intri[0, 0] /= float(W)
        intri[1, 1] /= float(H)
        intri[0, 2] /= float(W)
        intri[1, 2] /= float(H)
        return intri, extri, md

    def load_depth(self, scene_id, frame_ids, mds):
        """Load ``depths/frame{id:06d}.jpg.geometric.png`` maps as tensors.

        Raw values are divided by 65535 (presumably a 16-bit encoding) and
        multiplied by the frame's ``maximum_depth`` from ``mds``, with NaN
        treated as 0. Remaining NaN/inf pixels are replaced by
        ``np.nan_to_num`` defaults.

        Raises:
            FileNotFoundError: if a depth image cannot be read.
        """
        torch_depths = []
        for frame_id in frame_ids:
            depthpath = os.path.join(
                self.root, f"{scene_id}/depths/frame{frame_id:06d}.jpg.geometric.png"
            )
            raw = cv2.imread(depthpath, cv2.IMREAD_UNCHANGED)
            if raw is None:
                # cv2.imread reports failure by returning None, not by raising.
                raise FileNotFoundError(f"could not read depth map {depthpath}")
            depth = raw / 65535 * np.nan_to_num(mds[frame_id])
            torch_depths.append(torch.from_numpy(np.nan_to_num(depth)))
        return torch_depths

    def load_masks(self, scene_id, frame_ids):
        """Load ``masks/frame{id:06d}.png`` as boolean tensors.

        A pixel is True where its value divided by 255 exceeds 0.1.

        Raises:
            FileNotFoundError: if a mask image cannot be read.
        """
        masks = []
        for frame_id in frame_ids:
            maskpath = os.path.join(self.root, f"{scene_id}/masks/frame{frame_id:06d}.png")
            raw = cv2.imread(maskpath, cv2.IMREAD_UNCHANGED)
            if raw is None:
                raise FileNotFoundError(f"could not read mask {maskpath}")
            maskmap = (raw.astype(np.float32) / 255.0) > 0.1
            masks.append(torch.from_numpy(maskmap))
        return masks

    def _load_cameras(self, scene_id, seq):
        """Return float32 (extrinsics, intrinsics) stacks and a frame_id -> max-depth dict."""
        extrinsics, intrinsics, mds = [], [], {}
        for frame_id in seq:
            intri, extri, md = self.load_npz(scene_id, frame_id)
            extrinsics.append(extri)
            intrinsics.append(intri)
            mds[frame_id] = md
        extrinsics = torch.tensor(np.array(extrinsics), dtype=torch.float32)
        intrinsics = torch.tensor(np.array(intrinsics), dtype=torch.float32)
        return extrinsics, intrinsics, mds

    def _sample_view_indices(self, num_views, num_context_views):
        """Draw disjoint random context and target frame indices."""
        context_indices = torch.tensor(random.sample(range(num_views), num_context_views))
        # A set keeps the exclusion O(V) instead of scanning the tensor per frame.
        chosen = set(context_indices.tolist())
        remaining = [i for i in range(num_views) if i not in chosen]
        target_indices = torch.tensor(
            random.sample(remaining, self.view_sampler.num_target_views)
        )
        return context_indices, target_indices

    @staticmethod
    def _unproject_depths(view: dict, image_size) -> tuple[Tensor, Tensor]:
        """Lift each depth map of ``view`` to 3D with pixel-unit intrinsics.

        ``image_size`` is ``(H, W)``. The normalized intrinsics have row 0
        scaled by W and row 1 by H before calling
        ``depthmap_to_absolute_camera_coordinates`` for every frame. Returns
        the stacked float32 points and the stacked validity masks.
        """
        intrinsics = view["intrinsics"].clone().detach().numpy()
        intrinsics[:, 0] = intrinsics[:, 0] * image_size[1]
        intrinsics[:, 1] = intrinsics[:, 1] * image_size[0]

        pts3d_list, mask_list = [], []
        for depth, K, pose in zip(view["depth"], intrinsics, view["extrinsics"]):
            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(
                depth.numpy(), K, pose.numpy()
            )
            pts3d_list.append(torch.from_numpy(pts3d).to(torch.float32))
            mask_list.append(torch.from_numpy(valid_mask))
        return torch.stack(pts3d_list, dim=0), torch.stack(mask_list, dim=0)

    @staticmethod
    def _check_finite(example: dict) -> None:
        """Raise if any depth/extrinsics/pts3d/intrinsics entry holds NaN or inf."""
        for view in ("context", "target"):
            for key in ("depth", "extrinsics", "pts3d", "intrinsics"):
                if not torch.isfinite(example[view][key]).all():
                    raise Exception(f"encounter nan or inf in {view} {key}")

    def getitem(self, index: int, num_context_views: int, patchsize: tuple) -> dict:
        """Build one example for scene ``self.scene_ids[index]``.

        Args:
            index: position in ``self.scene_ids``.
            num_context_views: number of context frames sampled at random;
                ``view_sampler.num_target_views`` targets are drawn from the rest.
            patchsize: ``(h, w)`` in 14-pixel patches; the crop shim targets
                ``(h * 14, w * 14)`` pixels.

        Returns:
            ``{"context": {...}, "target": {...}, "scene": str}``. Each view
            dict holds extrinsics, intrinsics, image, depth, near, far, index,
            pts3d and valid_mask.

        Raises:
            Exception: when the sample is rejected (field of view too wide,
                baseline out of range, non-finite values). ``__getitem__``
                relies on this to resample. ``random.sample`` raises
                ValueError when the scene has too few frames.
        """
        scene_id = self.scene_ids[index]
        seq = self.scene_seq_dict[scene_id]
        frame_ids = list(seq)
        extrinsics, intrinsics, mds = self._load_cameras(scene_id, seq)

        context_indices, target_indices = self._sample_view_indices(
            extrinsics.shape[0], num_context_views
        )

        # The FOV limit is applied to every frame of the scene, not just the sampled ones.
        if (get_fov(intrinsics).rad2deg() > self.cfg.max_fov).any():
            raise Exception("Field of view too wide")

        input_frames = [frame_ids[i] for i in context_indices.tolist()]
        target_frame = [frame_ids[i] for i in target_indices.tolist()]

        context_images = self.load_frames(scene_id, input_frames)
        target_images = self.load_frames(scene_id, target_frame)
        context_depths = self.load_depth(scene_id, input_frames, mds)
        target_depths = self.load_depth(scene_id, target_frame, mds)

        # 'rand' masks a random half of the samples; True/False always/never.
        if self.mask_bg == "rand":
            use_mask = np.random.random() < 0.5
        else:
            use_mask = bool(self.mask_bg)
        if use_mask:
            context_masks = self.load_masks(scene_id, input_frames)
            target_masks = self.load_masks(scene_id, target_frame)
            # Zero depth outside the foreground. Presumably such pixels are then
            # treated as invalid by the unprojection -- confirm.
            context_depths = [d * m for d, m in zip(context_depths, context_masks)]
            target_depths = [d * m for d, m in zip(target_depths, target_masks)]

        if self.cfg.make_baseline_1:
            context_extrinsics = extrinsics[context_indices]
            a, b = context_extrinsics[0, :3, 3], context_extrinsics[-1, :3, 3]
            scale = (a - b).norm()
            if scale < self.cfg.baseline_min or scale > self.cfg.baseline_max:
                print(
                    f"Skipped {scene_id} because of baseline out of range: "
                    f"{scale:.6f}"
                )
                raise Exception("baseline out of range")
            extrinsics[:, :3, 3] /= scale

        if self.cfg.relative_pose:
            extrinsics = camera_normalization(extrinsics[context_indices][0:1], extrinsics)

        if self.cfg.rescale_to_1cube:
            scene_scale = torch.max(torch.abs(extrinsics[context_indices][:, :3, 3]))
            extrinsics[:, :3, 3] /= scene_scale

        example = {
            "context": {
                "extrinsics": extrinsics[context_indices],
                "intrinsics": intrinsics[context_indices],
                "image": context_images,
                "depth": context_depths,
                "near": self.get_bound("near", len(context_indices)),
                "far": self.get_bound("far", len(context_indices)),
                "index": context_indices,
            },
            "target": {
                "extrinsics": extrinsics[target_indices],
                "intrinsics": intrinsics[target_indices],
                "image": target_images,
                "depth": target_depths,
                "near": self.get_bound("near", len(target_indices)),
                "far": self.get_bound("far", len(target_indices)),
                "index": target_indices,
            },
            "scene": f"CO3Dv2 {scene_id}",
        }

        intr_aug = self.stage == "train" and bool(self.cfg.intr_augment)
        example = apply_crop_shim(example, (patchsize[0] * 14, patchsize[1] * 14), intr_aug=intr_aug)

        if self.stage == "train" and self.cfg.augment:
            example = apply_augmentation_shim(example)

        # The context image size is used for both view sets. This assumes the
        # crop shim yields one common resolution -- TODO confirm.
        image_size = example["context"]["image"].shape[2:]
        pts3d, valid = {}, {}
        for view in ("context", "target"):
            pts3d[view], valid[view] = self._unproject_depths(example[view], image_size)

        if self.cfg.normalize_by_pts3d:
            # Scale by the mean distance of the valid context points; clipped to
            # avoid division by zero.
            scene_factor = pts3d["context"][valid["context"]].norm(dim=-1).mean().clip(min=1e-8)
            for view in ("context", "target"):
                pts3d[view] /= scene_factor
                example[view]["depth"] /= scene_factor
                example[view]["extrinsics"][:, :3, 3] /= scene_factor

        for view in ("context", "target"):
            example[view]["pts3d"] = pts3d[view]
            example[view]["valid_mask"] = valid[view]

        self._check_finite(example)

        # NOTE(review): the computed validity masks are discarded here and
        # replaced by all -1 int32 tensors of the same shape. This is kept
        # as-is; confirm intent.
        for view in ("context", "target"):
            example[view]["valid_mask"] = torch.full_like(
                example[view]["valid_mask"], -1, dtype=torch.int32
            )

        return example

    def __getitem__(self, index_tuple: tuple) -> dict:
        """Return ``getitem`` for an ``(index, num_context_views, patchsize_h)`` tuple.

        The patch width is ``cfg.input_image_shape[1] // 14``. Any exception
        while building a sample is printed, and a uniformly random scene is
        tried instead with the same view count and patch size. Retries run in
        a loop rather than by recursion, so long failure streaks cannot
        exhaust the call stack.

        Raises:
            RuntimeError: after ``max_retries`` consecutive failures, chained
                from the last error.
        """
        index, num_context_views, patchsize_h = index_tuple
        patchsize_w = self.cfg.input_image_shape[1] // 14
        last_error = None
        for _ in range(self.max_retries):
            try:
                return self.getitem(index, num_context_views, (patchsize_h, patchsize_w))
            except Exception as e:
                print(f"Error: {e}")
                last_error = e
                index = np.random.randint(len(self))
        raise RuntimeError(
            f"failed to load a CO3Dv2 sample after {self.max_retries} attempts"
        ) from last_error

    def get_bound(
        self,
        bound: Literal["near", "far"],
        num_views: int,
    ) -> Float[Tensor, " view"]:
        """Return a length-``num_views`` float32 tensor filled with ``self.near`` or ``self.far``."""
        value = torch.tensor(getattr(self, bound), dtype=torch.float32)
        return repeat(value, "-> v", v=num_views)

    @property
    def data_stage(self) -> Stage:
        """Split whose categories are loaded: "test" when overfitting or validating, else the stage."""
        if self.cfg.overfit_to_scene is not None:
            return "test"
        if self.stage == "val":
            return "test"
        return self.stage

    @cached_property
    def index(self) -> dict[str, Path]:
        """Merge ``<root>/<stage>/index.json`` maps across all roots (and stages).

        Values become ``root / stage / value`` paths. When
        ``overfit_to_scene`` is set, both "test" and "train" are merged. Keys
        must be unique across every merged file.
        """
        merged_index = {}
        data_stages = [self.data_stage]
        if self.cfg.overfit_to_scene is not None:
            data_stages = ("test", "train")
        for data_stage in data_stages:
            for root in self.cfg.roots:
                with (root / data_stage / "index.json").open("r") as f:
                    index = json.load(f)
                index = {k: Path(root / data_stage / v) for k, v in index.items()}

                # A key appearing in two index files would be silently shadowed.
                assert not (merged_index.keys() & index.keys())
                merged_index.update(index)
        return merged_index

    def __len__(self) -> int:
        """Number of indexed scenes."""
        return len(self.scene_ids)