import json import torchvision.transforms as T import torchvision.transforms.functional as TF import torch from tqdm import tqdm import os from glob import glob from torch.utils.data import Dataset from must3r.tools.image import get_resize_function from PIL import Image import numpy as np from einops import rearrange from typing import List, Dict, Optional, Tuple from pycocotools import mask as mask_utils import random, cv2 from scipy.spatial.transform import Rotation SAV_ANNOT_RATE = 4 # SA-V: annotations at 6 fps, video at 24 fps def load_images(folder_content, size, patch_size = 16, verbose = True): imgs = [] transform = ImgNorm = T.Compose([T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) resize_funcs = [] for content in folder_content: if isinstance(content, str): if verbose: print(f'Loading image from {content} ', end = '') rgb_image = Image.open(content).convert('RGB') elif isinstance(content, Image.Image): rgb_image = content else: raise ValueError(f'Unknown content type: {type(content)}') rgb_image.load() W, H = rgb_image.size resize_func, _, to_orig = get_resize_function(size, patch_size, H, W) resize_funcs.append(resize_func) rgb_tensor = resize_func(transform(rgb_image)) imgs.append(dict(img=rgb_tensor, true_shape=np.int32([rgb_tensor.shape[-2], rgb_tensor.shape[-1]]))) if verbose: print(f'with resolution {W}x{H} --> {rgb_tensor.shape[-1]}x{rgb_tensor.shape[-2]}') return imgs, resize_funcs def _decode_rle(rle: Dict, h: int, w: int) -> np.ndarray: if not rle or "counts" not in rle: return np.zeros((h, w), dtype=np.uint8) counts = rle["counts"] if isinstance(counts, str): counts = counts.encode("utf-8") m = mask_utils.decode({"size": [h, w], "counts": counts}) return (np.asarray(m).squeeze() > 0) def _read_frame_rgb(cap: cv2.VideoCapture, idx: int, fallback_hw: Optional[Tuple[int,int]]=None) -> np.ndarray: ok = cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx)) if not ok: raise RuntimeError(f"cv2.VideoCapture.set({idx}) failed") else: ok, bgr = cap.read() return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) class SAVTrainDataset(Dataset): """ SA-V train Dataset (mp4 + {video_id}_{manual|auto}.json). Scans JSON with pattern: root/*/*.json (non-recursive). __getitem__ follows the requested 1–5 procedure. """ def __init__( self, data_root: str, mask_type: Optional[str] = None, # None | "manual" | "auto" img_mean = (0.485, 0.456, 0.406), img_std = (0.229, 0.224, 0.225), N: int = 8, image_size: int = 1024, verbose: bool = False, max_stride: int = 1, # kept for parity, not used in this flow dataset_scale: int = 32, area_thresh: float = 0.01, # area ratio threshold at original HxW valid_must3r_sizes = [224, 512] ): assert mask_type in (None, "manual", "auto") assert N >= 1 self.verbose = verbose self.data_root = data_root self.dataset_scale = int(dataset_scale) self.N = int(N) self.mask_type = mask_type self.area_thresh = float(area_thresh) self.max_stride = int(max_stride) self.valid_must3r_sizes = valid_must3r_sizes self.image_transform = T.Compose([ T.Resize((image_size, image_size), interpolation=T.InterpolationMode.NEAREST_EXACT), T.Normalize(mean=img_mean, std=img_std), ]) self.instance_transform = T.Compose([ T.Resize((image_size, image_size), interpolation=T.InterpolationMode.NEAREST_EXACT), ]) # --- collect through JSONs (non-recursive) --- json_paths = glob(os.path.join(data_root, "*", "*.json")) self.items: List[Tuple[str, str]] = [] # (vpath, jpath) for jpath in tqdm(json_paths, desc="scanning jsons"): base = os.path.splitext(os.path.basename(jpath))[0] # filter by mask_type if specified if self.mask_type is not None and not base.endswith(f"_{self.mask_type}"): continue if base.endswith("_manual"): vid = base[:-7] elif base.endswith("_auto"): vid = base[:-5] else: # strictly require suffix continue vpath = os.path.join(os.path.dirname(jpath), f"{vid}.mp4") if os.path.isfile(vpath): self.items.append((vpath, jpath)) print(f"Collected {len(self.items)} video-json pairs") self._log_path = "./sav_dataset_resample.log" def __len__(self): return self.dataset_scale * len(self.items) def _resample(self): return self[random.randrange(len(self))] def _log(self, msg: str): try: with open(self._log_path, "a") as f: f.write(msg.rstrip() + "\n") except Exception: pass def __getitem__(self, idx: int): vpath, jpath = self.items[idx % len(self.items)] # 1) load json with open(jpath, "r") as f: meta = json.load(f) masklet: List[List[Dict]] = meta.get("masklet", []) if not isinstance(masklet, list) or len(masklet) < self.N: self._log(f"[short_json] {jpath}: len(masklet)={len(masklet)} < N={self.N}") return self._resample() H, W = int(meta["video_height"]), int(meta["video_width"]) # 2) randomly sample a center frame idx in masklet, build sample_indices = [idx-N, idx+N] center = random.randrange(len(masklet)) left = max(0, center - self.N * self.max_stride) right = min(len(masklet), center + self.N * self.max_stride) sample_indices = list(range(left, right)) if len(sample_indices) < self.N: self._log(f"[short_span] {jpath}: span={len(sample_indices)} < N={self.N}") return self._resample() obj_order = None while True: if len(sample_indices) < self.N: self._log(f"[exhausted_span] {jpath}: remaining span < N; resample") return self._resample() f0 = sample_indices[0] rles = masklet[f0] if isinstance(masklet[f0], list) else [] if len(rles) == 0: # no objects at this frame, pop and continue sample_indices.pop(0) continue obj_order = list(range(len(rles))) random.shuffle(obj_order) has_valid_id = False for oid in obj_order: m = _decode_rle(rles[oid], H, W) area = int(m.sum()) if area <= 0: continue ratio = area / float(H * W + 1e-6) if ratio >= self.area_thresh: has_valid_id = True break if has_valid_id: break else: # tried all object indices, none passed; pop first frame and continue sample_indices.pop(0) # downsample sample_indices to exactly N sample_indices = sample_indices[::min(len(sample_indices) // self.N, self.max_stride)][:self.N] assert len(sample_indices) == self.N # 5) similar to MOSE dataset: read frames, build masks only at anchor frame cap = cv2.VideoCapture(vpath) frames_rgb = [] frame_indices_24 = [] for f_annot in sample_indices: f24 = int(f_annot * SAV_ANNOT_RATE) frames_rgb.append(_read_frame_rgb(cap, f24, fallback_hw=(H, W))) frame_indices_24.append(f24) cap.release() # build original_images tensor [N, 3, H, W] original_imgs_pil = [Image.fromarray(fr) for fr in frames_rgb] # must3r parity fields must3r_size = np.random.choice(self.valid_must3r_sizes).item() views, resize_funcs = load_images(original_imgs_pil, size = must3r_size, patch_size = 16, verbose = self.verbose) original_instances = [] original_imgs = [] for frame_idx, (resize_func, sample_idx) in enumerate(zip(resize_funcs, sample_indices)): assert len(resize_func.transforms) == 2, f'Expected 2 transforms, got {len(resize_func.transforms)}' # assert resize_func.transforms[0].size[0] > resize_func.transforms[1].size[0], f'Expected first transform to be larger than second, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}' # assert resize_func.transforms[0].size[1] / resize_func.transforms[1].size[1] == resize_func.transforms[0].size[0] / resize_func.transforms[1].size[0], f'Expected aspect ratio to be preserved, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}' if frame_idx == 0: for instance_id in obj_order + [None]: if instance_id is None: return self._resample() if (resize_func.transforms[0](torch.from_numpy(_decode_rle(masklet[sample_idx][instance_id], H, W))).sum() > (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * self.area_thresh)): break original_instances.append(resize_func.transforms[0](torch.from_numpy(_decode_rle(masklet[sample_idx][instance_id], H, W)))) original_imgs.append(resize_func.transforms[0](TF.to_tensor(original_imgs_pil[frame_idx]))) original_instances = torch.stack(original_instances).squeeze()[:, None] instances = self.instance_transform(original_instances) assert instances[0].sum() > 0 and instances.ndim == 4, f'{instances.shape=}, {instances[0].sum()=}' original_imgs = torch.stack(original_imgs) imgs = self.image_transform(original_imgs) return { "original_images": original_imgs, # [N,3,H,W] "images": imgs, # [N,3,S,S] "original_masks": original_instances, # [N,1,H,W] "masks": instances, # [N,1,S,S] "filelist": sample_indices, "must3r_views": views, "video": os.path.splitext(os.path.basename(vpath))[0], "instance_id": int(instance_id), "dataset": "sav", "valid_masks": torch.ones_like(instances), # [N,1,S,S] "must3r_size": must3r_size } class MOSEDataset(Dataset): def __init__( self, data_root: str, img_mean = (0.485, 0.456, 0.406), img_std = (0.229, 0.224, 0.225), N: int = 8, image_size: int = 1024, verbose = False, max_stride = 2, dataset_scale = 1, valid_must3r_sizes = [224, 512] ): self.verbose = verbose self.data_root = data_root self.dataset_scale = dataset_scale self.N = N self.max_stride = max_stride self.image_transform = T.Compose([ T.Resize((image_size, image_size), interpolation = T.InterpolationMode.NEAREST_EXACT), T.Normalize(mean = img_mean, std = img_std) ]) self.instance_transform = T.Compose([ T.Resize((image_size, image_size), interpolation = T.InterpolationMode.NEAREST_EXACT), ]) self.valid_must3r_sizes = valid_must3r_sizes self.videos = os.listdir(os.path.join(data_root, 'JPEGImages')) self.frames = {} self.masks = {} self.indices = [] for video in tqdm(self.videos): if not os.path.isdir(os.path.join(data_root, 'JPEGImages', video)): continue frames = sorted(glob(os.path.join(data_root, 'JPEGImages', video, '*.jpg')), key = lambda x: int(os.path.basename(x).split('.')[0])) masks = sorted(glob(os.path.join(data_root, 'Annotations', video, '*.png')), key = lambda x: int(os.path.basename(x).split('.')[0])) if len(frames) < self.N: if self.verbose: print(f"skip video {video} as not enough frames") continue assert len(frames) == len(masks) and len(frames) >= self.N, f'{len(frames)=}, {len(masks)=} in {video}' self.frames[video] = frames self.masks[video] = masks self.indices += [(video, idx) for idx in range(len(frames))] print(f'Found {len(self.indices)} frames, and {len(self.frames)} videos, with min length {min([len(self.frames[video]) for video in self.frames])} and max length {max([len(self.frames[video]) for video in self.frames])}') def __len__(self): return len(self.indices) * self.dataset_scale def __getitem__(self, idx): idx = idx % len(self.indices) video, idx = self.indices[idx] sampled_indices = np.arange(max(0, idx - self.N), idx).tolist() + np.arange(idx, min(len(self.frames[video]), idx + self.N * self.max_stride)).tolist() unique_ids = None while unique_ids is None or len(unique_ids) == 0: if unique_ids is not None: sampled_indices.pop(0) if len(sampled_indices) < self.N: return self[np.random.randint(len(self))] unique_ids, counts = np.unique(np.array(Image.open(self.masks[video][sampled_indices[0]])), return_counts = True) unique_ids = unique_ids[(unique_ids != 0) & (counts > counts.sum() * 0.01)] sampled_indices = sampled_indices[::len(sampled_indices) // self.N][:self.N] assert len(unique_ids) > 0 and len(sampled_indices) == self.N filelist = [self.frames[video][idx] for idx in sampled_indices] must3r_size = np.random.choice(self.valid_must3r_sizes).item() views, resize_funcs = load_images(filelist, size = must3r_size, patch_size = 16, verbose = self.verbose) original_instances = [] original_imgs = [] for frame_idx, (resize_func, sample_idx) in enumerate(zip(resize_funcs, sampled_indices)): assert len(resize_func.transforms) == 2, f'Expected 2 transforms, got {len(resize_func.transforms)}' # assert resize_func.transforms[0].size[0] > resize_func.transforms[1].size[0], f'Expected first transform to be larger than second, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}' # assert resize_func.transforms[0].size[1] / resize_func.transforms[1].size[1] == resize_func.transforms[0].size[0] / resize_func.transforms[1].size[0], f'Expected aspect ratio to be preserved, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}' if frame_idx == 0: for instance_id in np.random.permutation(unique_ids).tolist() + [None]: if instance_id is None: return self[np.random.randint(len(self))] if (resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx]))) == instance_id)).sum() > (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * 0.01): break original_instances.append(resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx]))) == instance_id)) original_imgs.append(resize_func.transforms[0](TF.to_tensor(Image.open(self.frames[video][sample_idx])))) original_instances = torch.stack(original_instances).squeeze()[:, None] instances = self.instance_transform(original_instances) assert instances[0].sum() > 0 and instances.ndim == 4, f'{instances.shape=}, {instances[0].sum()=}' original_imgs = torch.stack(original_imgs) imgs = self.image_transform(original_imgs) return { 'original_images': original_imgs, 'images': imgs, 'original_masks': original_instances, 'masks': instances, 'filelist': filelist, 'must3r_views': views, 'video': video, 'instance_id': instance_id, 'dataset': 'mose', 'valid_masks': torch.ones_like(instances), 'must3r_size': must3r_size, } # Reads a Ground truth trajectory file def read_trajectory_file(filepath): def _transform_from_Rt(R, t): M = np.identity(4) M[:3, :3] = R M[:3, 3] = t return M # Reads a Ground truth trajectory line def _read_trajectory_line(line): line = line.rstrip().split(",") pose = {} pose["timestamp"] = int(line[1]) translation = np.array([float(p) for p in line[3:6]]) quat_xyzw = np.array([float(o) for o in line[6:10]]) rot_matrix = Rotation.from_quat(quat_xyzw).as_matrix() rot_matrix = np.array(rot_matrix) pose["position"] = translation pose["rotation"] = rot_matrix pose["transform"] = _transform_from_Rt(rot_matrix, translation) return pose assert os.path.exists(filepath), f"Could not find trajectory file: {filepath}" with open(filepath, "r") as f: _ = f.readline() # header positions = [] rotations = [] transforms = [] timestamps = [] for line in f.readlines(): pose = _read_trajectory_line(line) positions.append(pose["position"]) rotations.append(pose["rotation"]) transforms.append(pose["transform"]) timestamps.append(pose["timestamp"]) positions = np.stack(positions) rotations = np.stack(rotations) transforms = np.stack(transforms) timestamps = np.array(timestamps) return { "ts": positions, "Rs": rotations, "Ts_world_from_device": transforms, "timestamps": timestamps, } from projectaria_tools.core import calibration from projectaria_tools.core.image import InterpolationMethod class ASEDataset(Dataset): def __init__( self, data_root: str, img_mean = (0.485, 0.456, 0.406), img_std = (0.229, 0.224, 0.225), N: int = 8, image_size: int = 1024, verbose = False, dataset_scale = 1, continuous_prob = 0, invalid_classes = ['ceiling', 'wall', 'empty_space', 'background', 'floor', 'window'], valid_must3r_sizes = [224, 512] ): self.verbose = verbose self.data_root = data_root self.dataset_scale = dataset_scale self.continuous_prob = continuous_prob self.N = N self.image_transform = T.Compose([ T.Resize((image_size, image_size), interpolation = T.InterpolationMode.NEAREST_EXACT), T.Normalize(mean = img_mean, std = img_std) ]) self.instance_transform = T.Compose([ T.Resize((image_size, image_size), interpolation = T.InterpolationMode.NEAREST_EXACT), ]) self.valid_must3r_sizes = valid_must3r_sizes from projectaria_tools.projects import ase from projectaria_tools.core import calibration self.ase_device = ase.get_ase_rgb_calibration() self.ase_width, self.ase_height = self.ase_device.get_image_size() assert self.ase_width == self.ase_height, f"Expected square images, got {self.ase_width}x{self.ase_height}" self.ase_pinhole = calibration.get_linear_camera_calibration( self.ase_width, self.ase_height, 320, "camera-rgb", self.ase_device.get_transform_device_camera() ) self.fx, self.fy = self.ase_pinhole.get_focal_lengths() self.cx, self.cy = self.ase_pinhole.get_principal_point() self.K = np.array([[self.fx, 0, self.cx], [0, self.fy, self.cy], [0, 0, 1 ]], dtype = np.float32) self.videos = os.listdir(os.path.join(data_root)) self.frames = {} self.masks = {} self.must3r_feats = {} self.appearances = {} self.mask2indices = {} self.validindices = {} self.indices = [] for video in tqdm(self.videos, desc='Loading ASE videos'): if not os.path.isdir(os.path.join(data_root, video)): print(f"skip {video} as not a directory") continue frames = sorted(glob(os.path.join(data_root, video, 'undistorted', '*.jpg'))) masks = sorted(glob(os.path.join(data_root, video, 'undistorted-instances', '*.png'))) must3r_feats = sorted(glob(os.path.join(data_root, video, 'must3r-features', '*.pt'))) if not (len(must3r_feats) == len(frames) == len(masks)): if self.verbose: print(f"skip {video} as {len(must3r_feats)=}, {len(frames)=}, {len(masks)=} in {video}") continue assert all([os.path.splitext(os.path.basename(must3r_feat))[0] == os.path.splitext(os.path.basename(frame))[0] for must3r_feat, frame in zip(must3r_feats, frames)]), f'Must3r features and frames do not match in {video}' if len(frames) < self.N: if self.verbose: print(f"skip video {video} as not enough frames") continue self.frames[video] = frames self.masks[video] = masks self.must3r_feats[video] = must3r_feats self.appearances[video] = json.load(open(os.path.join(data_root, video, 'instances-appearances.json'))) self.mask2indices[video] = { os.path.basename(m): i for i, m in enumerate(masks) } self.indices += [(video, idx) for idx in range(len(frames) - self.N + 1)] self.validindices[video] = [int(instance_id) for instance_id, class_name in json.load(open(os.path.join(data_root, video, 'object_instances_to_classes.json'))).items() if class_name not in invalid_classes] # if os.path.exists(os.path.join(data_root, video, 'object_instances_to_classes.json')) else None print(f'Found {len(self.indices)} frames, and {len(self.frames)} videos, with min length {min([len(self.frames[video]) for video in self.frames])} and max length {max([len(self.frames[video]) for video in self.frames])} and {sum([(len(ids) if ids is not None else 0) for ids in self.validindices.values()])} valid instances') self._log_path = "./ase_dataset_resample.log" def __len__(self): return len(self.indices) * self.dataset_scale def __getitem__(self, idx): idx = idx % len(self.indices) video, idx = self.indices[idx] ## 1. Randomly shuffle frames choices = np.delete(np.arange(len(self.frames[video]) - self.N + 1), idx) sampled_indices = [idx] + np.random.choice(choices, size = len(choices), replace = False).tolist() ## 2. Find unique instance IDs in the first frame unique_ids = None while unique_ids is None or len(unique_ids) == 0: if unique_ids is not None: sampled_indices.pop(0) if len(sampled_indices) < self.N: return self[np.random.randint(len(self))] unique_ids = np.unique(np.array(Image.open(self.masks[video][sampled_indices[0]])), return_counts = False) unique_ids = unique_ids[(unique_ids != 0) & np.array([class_id in self.validindices[video] for class_id in unique_ids])] # if self.validindices[video] is not None else True first_frame_idx = sampled_indices[0] assert len(unique_ids) > 0 ## 3. Load the resize funcs of the first frame feat_len = torch.load(self.must3r_feats[video][first_frame_idx], map_location = 'cpu')[-1].shape[-2] must3r_size = original_must3r_size = (224 if feat_len == 196 else 512) is_continuous = (np.random.rand() < self.continuous_prob) or original_must3r_size not in self.valid_must3r_sizes if is_continuous: must3r_size = np.random.choice(self.valid_must3r_sizes).item() _, [resize_func] = load_images([self.frames[video][first_frame_idx]], size = must3r_size, patch_size = 16, verbose = self.verbose) assert len(resize_func.transforms) == 2, f'Expected 2 transforms, got {len(resize_func.transforms)}' assert must3r_size != original_must3r_size or resize_func.transforms[1].size[0] * resize_func.transforms[1].size[1] == feat_len * 256, f'Expected {resize_func.transforms[1].size[0]}x{resize_func.transforms[1].size[1]} to be {feat_len * 256}, got {feat_len}' for instance_id in np.random.permutation(unique_ids).tolist() + [None]: if instance_id is None: return self[np.random.randint(len(self))] if (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * 0.2) > (resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][first_frame_idx]))) == instance_id)).sum() > (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * 0.01): break if is_continuous: sampled_indices = np.arange(first_frame_idx, min(len(self.frames[video]), first_frame_idx + self.N)).tolist() # sampled_indices += np.random.choice(first_frame_idx, size = first_frame_idx, replace = False).tolist() sampled_indices = sampled_indices[:self.N] assert len(sampled_indices) == self.N and sampled_indices[0] == first_frame_idx, f'Expected {self.N} sampled indices and first index {first_frame_idx}, got {len(sampled_indices)} with first index {sampled_indices[0]}' else: sampled_indices = np.arange(first_frame_idx, len(self.frames[video])).tolist()[:2] sampled_indices = sorted(sampled_indices, key = lambda sample_idx: resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx]))) == instance_id).sum(), reverse = True) ## prioritize frames with larger masks first_frame_idx = sampled_indices[0] views, original_instances, original_imgs, filelist, extrinsics, depths, point_maps, fov_ratios = [], [], [], [], [], [], [], {} pre_sampled_len = len(sampled_indices) if len(sampled_indices) < self.N: instance_appearance_candidates = set([self.mask2indices[video][p] for p in self.appearances[video][str(instance_id)]]) - set(sampled_indices) sampled_indices += np.random.permutation(list(instance_appearance_candidates)).tolist() sampled_indices += np.random.permutation(list(set(np.arange(len(self.frames[video])).tolist()) - set(instance_appearance_candidates) - set(sampled_indices))).tolist() trajectory = read_trajectory_file(os.path.join(self.data_root, video, 'trajectory.csv')) while len(views) < self.N and len(sampled_indices) >= self.N: sample_idx = sampled_indices[len(views)] [view], [resize_func] = load_images([self.frames[video][sample_idx]], size = must3r_size, patch_size = 16, verbose = self.verbose) instance_map = resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx])) == instance_id)) if len(views) >= pre_sampled_len and not (instance_map.shape[-1] * instance_map.shape[-2] * 0.005 < instance_map.sum() < instance_map.shape[-1] * instance_map.shape[-2] * 0.25): sampled_indices.pop(len(views)) continue extrinsic = trajectory['Ts_world_from_device'][sample_idx] @ self.ase_pinhole.get_transform_device_camera().to_matrix() depth = calibration.distort_by_calibration( np.array(Image.open(self.frames[video][sample_idx].replace('undistorted', 'depth').replace('vignette', 'depth').replace('.jpg', '.png'))), self.ase_pinhole, self.ase_device, InterpolationMethod.NEAREST_NEIGHBOR ).astype(np.float32) / 1000.0 point_map = resize_func.transforms[0](torch.rot90(torch.from_numpy(depth_to_world_pointmap(depth, extrinsic, self.K).astype(np.float32)).permute(2, 0, 1), k = -1, dims = (1, 2))) assert point_map.shape[-2] == instance_map.shape[-2], f"Expected height {instance_map.shape[-2]}, got {point_map.shape[-2]}" fov_ratio = None if len(views) < pre_sampled_len or instance_map.sum().item() == 0 or \ (fov_ratio := (in_fov_ratio(point_map[:, instance_map].permute(1, 0), extrinsics[0], K = self.K, W = self.ase_height, H = self.ase_width, ## for rot -90 W_crop = abs(int(self.ase_height) - original_instances[0].shape[-2]) // 2, H_crop = abs(int(self.ase_width) - original_instances[0].shape[-1]) // 2)[0])) > 0.25: views.append(view) original_instances.append(instance_map) original_imgs.append(resize_func.transforms[0](TF.to_tensor(Image.open(self.frames[video][sample_idx])))) filelist.append(self.frames[video][sample_idx]) extrinsics.append(extrinsic) depths.append(resize_func.transforms[0](torch.rot90(torch.from_numpy(depth), k = -1, dims = (0, 1)))) point_maps.append(point_map) fov_ratios[self.frames[video][sample_idx]] = fov_ratio if fov_ratio is not None else -1 else: sampled_indices.pop(len(views)) continue sampled_indices = sampled_indices[:len(views)] if len(sampled_indices) < self.N: open(self._log_path, "a").write(f"[short_span] {video}: span={len(sampled_indices)} < N={self.N}\n") return self[np.random.randint(len(self))] assert len(sampled_indices) == self.N and sampled_indices[0] == first_frame_idx, f'Expected {self.N} sampled indices and first index {first_frame_idx}, got {len(sampled_indices)} with first index {sampled_indices[0]}' if not is_continuous or (np.random.rand() < 0.8 and must3r_size == original_must3r_size): assert original_must3r_size == must3r_size, f'If not continuous, must3r size should not change, got {must3r_size} and {original_must3r_size}' must3r_feats_filelist = [self.must3r_feats[video][idx] for idx in sampled_indices] must3r_feats = [torch.load(must3r_filepath, map_location = 'cpu') for must3r_filepath in must3r_feats_filelist] must3r_feats_head = torch.cat([f[-1] for f in must3r_feats], dim = 0) must3r_feats = [f[:-1] for f in must3r_feats] must3r_feats = [torch.cat(f, dim = 0) for f in zip(*must3r_feats)] must3r_feats = [ rearrange(f, 'b (h w) c -> b c h w', h = views[0]['true_shape'][0] // 16, w = views[0]['true_shape'][1] // 16) for f in must3r_feats ] else: assert is_continuous, f'If must3r size changed, should be continuous sampling, got {must3r_size} and {original_must3r_size}' must3r_feats = None must3r_feats_head = None original_instances = torch.stack(original_instances).squeeze()[:, None] instances = self.instance_transform(original_instances) assert instances[0].sum() > 0 and instances.ndim == 4, f'{instances.shape=}, {instances[0].sum()=}' original_imgs = torch.stack(original_imgs) imgs = self.image_transform(original_imgs) # if is_continuous: # permutation = torch.arange(len(instances)) # else: # permutation = torch.argsort(instances.squeeze().sum(dim = (1, 2)), descending = True) permutation = torch.arange(len(instances)) permutation[pre_sampled_len:] = torch.randperm(len(instances) - pre_sampled_len) + pre_sampled_len return { 'original_images': original_imgs[permutation], 'images': imgs[permutation], 'original_masks': original_instances[permutation], 'masks': instances[permutation], 'filelist': [filelist[idx] for idx in permutation], 'must3r_views': [views[idx] for idx in permutation], 'must3r_size': must3r_size, 'video': video, 'instance_id': instance_id, 'dataset': 'scannetpp', 'valid_masks': torch.ones_like(instances), 'intrinsics': torch.from_numpy(self.K).unsqueeze(0).repeat(self.N, 1, 1)[permutation], 'extrinsics': torch.from_numpy(np.stack(extrinsics, axis = 0))[permutation], 'depths': torch.from_numpy(np.stack(depths, axis = 0))[permutation], 'point_maps': torch.from_numpy(np.stack(point_maps, axis = 0))[permutation], 'fov_ratios': fov_ratios, 'is_continuous': is_continuous } | ( { 'must3r_feats': [f[permutation] for f in must3r_feats], 'must3r_feats_head': must3r_feats_head[permutation], 'must3r_feats_filelist': [must3r_feats_filelist[idx] for idx in permutation], } if must3r_feats is not None else {} ) def pose_from_qwxyz_txyz(elems): qw, qx, qy, qz, tx, ty, tz = map(float, elems) pose = np.eye(4) pose[:3, :3] = Rotation.from_quat((qx, qy, qz, qw)).as_matrix() pose[:3, 3] = (tx, ty, tz) return np.linalg.inv(pose) # returns cam2world def depth_to_world_pointmap(depth, c2w, K, depth_type = 'range'): """ depth: (H,W) depth in meters, camera-Z c2w: (4,4) camera-to-world transform K: (3,3) camera intrinsics Returns: (H,W,3) world xyz (NaN for invalid depth) """ Kinv = np.linalg.inv(K) H_, W_ = depth.shape ys, xs = np.meshgrid(np.arange(H_), np.arange(W_), indexing='ij') ones = np.ones_like(xs, dtype=np.float64) pix = np.stack([xs, ys, ones], axis=-1).reshape(-1, 3).T # (3,N) rays_cam = Kinv @ pix # (3,N) z = depth.reshape(-1) # (N,) if depth_type == 'range': rays_cam = rays_cam / np.linalg.norm(rays_cam, axis = 0, keepdims = True) # (3,N) elif depth_type == 'z-buf': pass else: raise ValueError(f'Unknown depth_type {depth_type}') xyz_cam = rays_cam * z # scale each ray by depth xyz_cam_h = np.vstack([xyz_cam, np.ones_like(z)]) # (4,N) xyz_w_h = c2w @ xyz_cam_h # (4,N) xyz_w = xyz_w_h[:3].T.reshape(H_, W_, 3) mask = (depth <= 0) | ~np.isfinite(depth) xyz_w[mask] = np.nan return xyz_w def in_fov_ratio(points, c2w, K, H, W, H_crop, W_crop): """ points: (N,3) world coords, torch tensor c2w: (4,4) camera-to-world, torch tensor K: (3,3) intrinsics, torch tensor H,W: image size """ # device = points.device K = K # .to(device) # world -> camera w2c = np.linalg.inv(c2w) # .to(device) Pc = (points @ w2c[:3, :3].T) + w2c[:3, 3] X, Y, Z = Pc[:,0], Pc[:,1], Pc[:,2] # projection u = K[0, 0] * (X / Z) + K[0, 2] v = K[1, 1] * (Y / Z) + K[1, 2] mask = (Z > 0) & (u >= W_crop) & (u < W - W_crop) & (v >= H_crop) & (v < H - H_crop) return mask.float().mean(), mask class ScanNetPPV2Dataset(Dataset): def __init__( self, data_root: str, must3r_data_root: str = None, img_mean = (0.485, 0.456, 0.406), img_std = (0.229, 0.224, 0.225), N: int = 8, image_size: int = 1024, verbose = False, dataset_scale = 1, continuous_prob = 0, instance_classes_file = '/metadata/semantic_benchmark/top100_instance.txt', split_file: str = '/splits/nvs_sem_train.txt', excluding_scenes = ["09d6e808b4", "0f69aefe3d", "1b379f1114", "1cbb105c6a", "2c7c10379b", "46638cfd0f", "4f341f3af0", "6ef2ac745a", "898a7dfd0c", "aa852f7871", "eea4ad9c04", 'd27235711b'], ## horizontal / vertical flip issues valid_must3r_sizes = [224, 512] ): self.verbose = verbose self.data_root = data_root self.must3r_data_root = must3r_data_root if must3r_data_root is not None else data_root self.dataset_scale = dataset_scale self.excluding_scenes = excluding_scenes self.instance_classes = open(instance_classes_file).read().splitlines() self.valid_scene_names = open(split_file).read().splitlines() self.continuous_prob = continuous_prob self.N = N self.image_transform = T.Compose([ T.Resize((image_size, image_size), interpolation = T.InterpolationMode.NEAREST_EXACT), T.Normalize(mean = img_mean, std = img_std) ]) self.instance_transform = T.Compose([ T.Resize((image_size, image_size), interpolation = T.InterpolationMode.NEAREST_EXACT), ]) self.valid_must3r_sizes = valid_must3r_sizes self.videos = os.listdir(os.path.join(data_root)) self.frames = {} self.masks = {} self.must3r_feats = {} self.appearances = {} self.id2label_name = {} self.intrinsics = {} self.extrinsics = {} self.indices = [] self._log_path = "./scannetppv2_dataset_resample.log" for video in tqdm(self.videos, desc = 'Loading ScanNet++V2 videos'): if video not in self.valid_scene_names or video in self.excluding_scenes: if self.verbose: print(f"skip {video} as not in split or excluded") continue if not os.path.isdir(os.path.join(data_root, video)): print(f"skip {video} as not a directory") continue if video in ['46638cfd0f']: if self.verbose: print(f"skip {video} as broken") continue masks = sorted(glob(os.path.join(self.data_root, video, 'iphone', 'render_instance', '*.png'))) if len(masks) == 0: if self.verbose: print(f"skip {video} as no masks found") continue frames = [m.replace('render_instance', 'rgb').replace('.png', '.jpg') for m in masks] must3r_feats = [m.replace(self.data_root, self.must3r_data_root).replace('iphone/render_instance', 'must3r-features').replace('.png', '.pt') for m in masks] if not all([os.path.exists(p) for p in must3r_feats[:1]]): if self.verbose: print(f"skip {video} as not all must3r features or frames exist") continue # assert all([os.path.exists(p) for p in frames]), f'Not all frames exist in {video}' self.frames[video] = frames self.masks[video] = masks self.must3r_feats[video] = must3r_feats self.appearances[video] = json.loads(open(os.path.join(data_root, video, 'scans/instance-appearances.json')).read()) self.intrinsics[video] = self.load_intrinsics(os.path.join(data_root, video, 'iphone', 'colmap', 'cameras.txt')) assert len(self.intrinsics[video]) == 1, f'Expected 1 camera, got {len(self.intrinsics[video])} in {video}' self.extrinsics[video] = os.path.join(data_root, video, 'iphone', 'colmap', 'images.txt') assert all([f_name == os.path.basename(m) for f_name, m in zip(self.appearances[video]['framenames'], self.masks[video])]), f'Frame names in appearances do not match masks in {video}' self.id2label_name[video] = json.loads(open(os.path.join(data_root, video, 'scans/instance_id2label_name.json')).read()) self.indices += [(video, idx) for idx in range(len(frames) - self.N + 1)] print(f'Found {len(self.indices)} frames, and {len(self.frames)} videos, with min length {min([len(self.frames[video]) for video in self.frames])} and max length {max([len(self.frames[video]) for video in self.frames])}') def load_intrinsics(self, path): with open(path, 'r') as f: raw = f.read().splitlines()[3:] # skip header intrinsics = {} for camera in tqdm(raw, position = 1, leave = False): camera = camera.split(' ') intrinsics[int(camera[0])] = [camera[1]] + [float(cam) for cam in camera[2:]] return intrinsics def __len__(self): return len(self.indices) * self.dataset_scale def __getitem__(self, idx): idx = idx % len(self.indices) video, idx = self.indices[idx] if len(glob(os.path.join(self.data_root, video, 'iphone/depth/*.png'))) == 0: return self[np.random.randint(len(self))] ## 1. Randomly shuffle frames choices = np.delete(np.arange(len(self.frames[video]) - self.N + 1), idx) sampled_indices = [idx] + np.random.choice(choices, size = len(choices), replace = False).tolist() ## 2. Find unique instance IDs in the first frame unique_ids = None while unique_ids is None or len(unique_ids) == 0: if unique_ids is not None: sampled_indices.pop(0) if len(sampled_indices) == 0: return self[np.random.randint(len(self))] unique_ids, _ = np.unique(np.array(Image.open(self.masks[video][sampled_indices[0]])), return_counts = True) unique_ids = unique_ids[np.array([class_id not in [0, 65535] and self.id2label_name[video][str(class_id)] in self.instance_classes and all([s not in self.id2label_name[video][str(class_id)].lower() for s in ['wall', 'floor', 'ceiling', 'window', 'curtain', 'blind', 'table']]) for class_id in unique_ids])] first_frame_idx = sampled_indices[0] assert len(unique_ids) > 0 ## 3. Load the resize funcs of the first frame feat_len = torch.load(self.must3r_feats[video][first_frame_idx], map_location = 'cpu')[-1].shape[-2] must3r_size = original_must3r_size = (224 if feat_len == 196 else 512) is_continuous = (np.random.rand() < self.continuous_prob) or original_must3r_size not in self.valid_must3r_sizes if is_continuous: must3r_size = np.random.choice(self.valid_must3r_sizes).item() _, [resize_func] = load_images([self.frames[video][first_frame_idx]], size = must3r_size, patch_size = 16, verbose = self.verbose) assert len(resize_func.transforms) == 2, f'Expected 2 transforms, got {len(resize_func.transforms)}' # assert resize_func.transforms[0].size[0] > resize_func.transforms[1].size[0], f'Expected first transform to be larger than second, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}' # assert resize_func.transforms[0].size[1] / resize_func.transforms[1].size[1] == resize_func.transforms[0].size[0] / resize_func.transforms[1].size[0], f'Expected aspect ratio to be preserved, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}' assert must3r_size != original_must3r_size or resize_func.transforms[1].size[0] * resize_func.transforms[1].size[1] == feat_len * 256, f'Expected {resize_func.transforms[1].size[0]}x{resize_func.transforms[1].size[1]} to be {feat_len * 256}, got {feat_len}' for instance_id in np.random.permutation(unique_ids).tolist() + [None]: if instance_id is None: return self[np.random.randint(len(self))] if (resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][first_frame_idx]))) == instance_id)).sum() > (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * 0.01): break if is_continuous: sampled_indices = np.arange(first_frame_idx, len(self.frames[video])).tolist() # sampled_indices += np.random.permutation(list(set(np.arange(len(self.frames[video])).tolist()) - set(self.appearances[video][str(instance_id)]) - set(sampled_indices))).tolist() sampled_indices = sampled_indices[:self.N] assert len(sampled_indices) == self.N and sampled_indices[0] == first_frame_idx, f'Expected {self.N} sampled indices and first index {first_frame_idx}, got {len(sampled_indices)} with first index {sampled_indices[0]}' else: sampled_indices = np.arange(first_frame_idx, len(self.frames[video])).tolist()[:2] sampled_indices = sorted(sampled_indices, key = lambda sample_idx: resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx]))) == instance_id).sum(), reverse = True) ## prioritize frames with larger masks first_frame_idx = sampled_indices[0] raw_poses = { raw.split()[-1].split('iphone/')[-1].split('video/')[-1]: raw.split()[1:-1] for raw in open(self.extrinsics[video], 'r').read().splitlines() if (not raw.startswith('#')) and len(raw.split()) > 0 } views, original_instances, original_imgs, filelist, extrinsics, raw_intrinsics, intrinsics, depths, point_maps, fov_ratios = [], [], [], [], [], [], [], [], [], {} pre_sampled_len = len(sampled_indices) if len(sampled_indices) < self.N: sampled_indices = sampled_indices + np.random.permutation(list(set(self.appearances[video][self.id2label_name[video][str(instance_id)]]) - set(sampled_indices))).tolist() + \ np.random.permutation(list(set(np.arange(len(self.frames[video])).tolist()) - set(self.appearances[video][self.id2label_name[video][str(instance_id)]]) - set(sampled_indices))).tolist() while len(views) < self.N and len(sampled_indices) >= self.N: sample_idx = sampled_indices[len(views)] [view], [resize_func] = load_images([self.frames[video][sample_idx]], size = must3r_size, patch_size = 16, verbose = self.verbose) instance_map = resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx])) == instance_id)) if len(views) >= pre_sampled_len and (0 < instance_map.sum() < instance_map.shape[-1] * instance_map.shape[-2] * 0.01): sampled_indices.pop(len(views)) continue f_name = os.path.basename(self.frames[video][sample_idx]) extrinsic = pose_from_qwxyz_txyz(raw_poses[f_name][:-1]) raw_intrinsic = self.intrinsics[video][int(raw_poses[f_name][-1])] intrinsic = np.array([[raw_intrinsic[3], 0, raw_intrinsic[5]], [0, raw_intrinsic[4], raw_intrinsic[6]], [0, 0, 1 ]], dtype = np.float32) depth = np.array(Image.open(self.frames[video][sample_idx].replace('rgb', 'depth').replace('.jpg', '.png')).resize((int(raw_intrinsic[1]), int(raw_intrinsic[2]))), dtype = np.float32) / 1000.0 point_map = resize_func.transforms[0](torch.from_numpy(depth_to_world_pointmap(depth, extrinsic, intrinsic).astype(np.float32)).permute(2, 0, 1)) assert point_map.shape[-2] == instance_map.shape[-2] == int(raw_intrinsic[2]), f'Expected height {int(raw_intrinsic[2])}, got {point_map.shape[-2]} and {instance_map.shape[-2]}' fov_ratio = None if len(views) < pre_sampled_len or instance_map.sum().item() == 0 or \ (fov_ratio := (in_fov_ratio(point_map[:, instance_map].permute(1, 0), extrinsics[0], K = intrinsics[0], H = int(raw_intrinsics[0][2]), W = int(raw_intrinsics[0][1]), H_crop = abs(int(raw_intrinsics[0][2]) - original_instances[0].shape[-2]) // 2, W_crop = abs(int(raw_intrinsics[0][1]) - original_instances[0].shape[-1]) // 2)[0])) > 0.25: views.append(view) original_instances.append(instance_map) original_imgs.append(resize_func.transforms[0](TF.to_tensor(Image.open(self.frames[video][sample_idx])))) filelist.append(self.frames[video][sample_idx]) extrinsics.append(extrinsic) raw_intrinsics.append(raw_intrinsic) intrinsics.append(intrinsic) depths.append(resize_func.transforms[0](torch.from_numpy(depth))) point_maps.append(point_map) fov_ratios[self.frames[video][sample_idx]] = fov_ratio if fov_ratio is not None else -1 else: sampled_indices.pop(len(views)) continue sampled_indices = sampled_indices[:len(views)] if len(sampled_indices) < self.N: open(self._log_path, "a").write(f"[short_span] {video}: span={len(sampled_indices)} < N={self.N}\n") return self[np.random.randint(len(self))] assert len(sampled_indices) == self.N and sampled_indices[0] == first_frame_idx, f'Expected {self.N} sampled indices and first index {first_frame_idx}, got {len(sampled_indices)} with first index {sampled_indices[0]}' if not is_continuous or (np.random.rand() < 0.8 and must3r_size == original_must3r_size): assert original_must3r_size == must3r_size, f'If not continuous, must3r size should not change, got {must3r_size} and {original_must3r_size}' must3r_feats_filelist = [self.must3r_feats[video][idx] for idx in sampled_indices] must3r_feats = [torch.load(must3r_filepath, map_location = 'cpu') for must3r_filepath in must3r_feats_filelist] must3r_feats_head = torch.cat([f[-1] for f in must3r_feats], dim = 0) must3r_feats = [f[:-1] for f in must3r_feats] must3r_feats = [torch.cat(f, dim = 0) for f in zip(*must3r_feats)] must3r_feats = [ rearrange(f, 'b (h w) c -> b c h w', h = views[0]['true_shape'][0] // 16, w = views[0]['true_shape'][1] // 16) for f in must3r_feats ] else: assert is_continuous, f'If must3r size changed, should be continuous sampling, got {must3r_size} and {original_must3r_size}' must3r_feats = None must3r_feats_head = None original_instances = torch.stack(original_instances).squeeze()[:, None] instances = self.instance_transform(original_instances) assert instances[0].sum() > 0 and instances.ndim == 4, f'{instances.shape=}, {instances[0].sum()=}' # assert instances[1:].sum() == 0, f"Only first frame should have the instance, got {instances.sum()=}" original_imgs = torch.stack(original_imgs) imgs = self.image_transform(original_imgs) # if is_continuous: # permutation = torch.arange(len(instances)) # else: # permutation = torch.argsort(instances.squeeze().sum(dim = (1, 2)), descending = True) permutation = torch.arange(len(instances)) permutation[pre_sampled_len:] = torch.randperm(len(instances) - pre_sampled_len) + pre_sampled_len return { 'original_images': original_imgs[permutation], 'images': imgs[permutation], 'original_masks': original_instances[permutation], 'masks': instances[permutation], 'filelist': [filelist[idx] for idx in permutation], 'must3r_views': [views[idx] for idx in permutation], 'must3r_size': must3r_size, 'video': video, 'instance_id': instance_id, 'dataset': 'scannetpp', 'valid_masks': torch.ones_like(instances), 'intrinsics': torch.from_numpy(np.stack(intrinsics, axis = 0))[permutation], 'extrinsics': torch.from_numpy(np.stack(extrinsics, axis = 0))[permutation], 'depths': torch.from_numpy(np.stack(depths, axis = 0))[permutation], 'point_maps': torch.from_numpy(np.stack(point_maps, axis = 0))[permutation], 'fov_ratios': fov_ratios, 'is_continuous': is_continuous, } | ( { 'must3r_feats': [f[permutation] for f in must3r_feats], 'must3r_feats_head': must3r_feats_head[permutation], 'must3r_feats_filelist': [must3r_feats_filelist[idx] for idx in permutation], } if must3r_feats is not None else {} )