import json
import os
import random
from glob import glob
from typing import List, Dict, Optional, Tuple

import cv2
import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from einops import rearrange
from PIL import Image
from pycocotools import mask as mask_utils
from scipy.spatial.transform import Rotation
from torch.utils.data import Dataset
from tqdm import tqdm

from must3r.tools.image import get_resize_function

SAV_ANNOT_RATE = 4  # SA-V: annotations at 6 fps, video at 24 fps
def load_images(folder_content, size, patch_size=16, verbose=True):
    """Load images (paths or PIL images), resize them for MUSt3R, and normalize to [-1, 1]."""
    imgs = []
    transform = T.Compose([T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    resize_funcs = []
    for content in folder_content:
        if isinstance(content, str):
            if verbose:
                print(f'Loading image from {content} ', end='')
            rgb_image = Image.open(content).convert('RGB')
        elif isinstance(content, Image.Image):
            rgb_image = content
        else:
            raise ValueError(f'Unknown content type: {type(content)}')
        rgb_image.load()
        W, H = rgb_image.size
        resize_func, _, to_orig = get_resize_function(size, patch_size, H, W)
        resize_funcs.append(resize_func)
        rgb_tensor = resize_func(transform(rgb_image))
        imgs.append(dict(img=rgb_tensor, true_shape=np.int32([rgb_tensor.shape[-2], rgb_tensor.shape[-1]])))
        if verbose:
            print(f'with resolution {W}x{H} --> {rgb_tensor.shape[-1]}x{rgb_tensor.shape[-2]}')
    return imgs, resize_funcs
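
# Minimal usage sketch for load_images (assumptions: must3r is installed and 512
# is a valid resize bucket for it; the synthetic image stands in for a real frame).
# Runs only when this file is executed directly.
if __name__ == "__main__":
    dummy = Image.fromarray(np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8))
    demo_views, demo_funcs = load_images([dummy], size=512, patch_size=16, verbose=True)
    print(demo_views[0]['img'].shape, demo_views[0]['true_shape'])  # e.g. torch.Size([3, 384, 512]) [384 512]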
def _decode_rle(rle: Dict, h: int, w: int) -> np.ndarray:
    """Decode a COCO-style RLE dict into an (h, w) boolean mask; empty/missing RLEs give all-False."""
    if not rle or "counts" not in rle:
        return np.zeros((h, w), dtype=bool)
    counts = rle["counts"]
    if isinstance(counts, str):
        counts = counts.encode("utf-8")
    m = mask_utils.decode({"size": [h, w], "counts": counts})
    return np.asarray(m).squeeze() > 0
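
# Round-trip sanity check for _decode_rle (a sketch; pycocotools expects
# Fortran-ordered uint8 masks when encoding, and SA-V stores counts as strings).
if __name__ == "__main__":
    toy = np.zeros((8, 8), dtype=np.uint8)
    toy[2:5, 3:6] = 1
    toy_rle = mask_utils.encode(np.asfortranarray(toy))
    toy_rle["counts"] = toy_rle["counts"].decode("utf-8")  # mimic JSON-stored string counts
    assert (_decode_rle(toy_rle, 8, 8) == toy.astype(bool)).all()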
def _read_frame_rgb(cap: cv2.VideoCapture, idx: int, fallback_hw: Optional[Tuple[int, int]] = None) -> np.ndarray:
    """Seek to frame `idx` and return it as RGB; fall back to a black frame if decoding fails."""
    ok = cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
    if not ok:
        raise RuntimeError(f"cv2.VideoCapture.set({idx}) failed")
    ok, bgr = cap.read()
    if not ok or bgr is None:
        if fallback_hw is None:
            raise RuntimeError(f"cv2.VideoCapture.read() failed at frame {idx}")
        return np.zeros((*fallback_hw, 3), dtype=np.uint8)
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
class SAVTrainDataset(Dataset):
    """
    SA-V train dataset (mp4 + {video_id}_{manual|auto}.json).
    Scans JSONs with the pattern root/*/*.json (non-recursive).
    __getitem__: (1) load the annotation JSON, (2) sample a window of annotated
    frames around a random center, (3) pick an object whose first-frame mask is
    large enough, (4) subsample the window to exactly N frames, and (5) read the
    corresponding video frames and decode the chosen object's masks.
    """
    def __init__(
        self,
        data_root: str,
        mask_type: Optional[str] = None,  # None | "manual" | "auto"
        img_mean=(0.485, 0.456, 0.406),
        img_std=(0.229, 0.224, 0.225),
        N: int = 8,
        image_size: int = 1024,
        verbose: bool = False,
        max_stride: int = 1,  # window half-width multiplier and max subsampling stride
        dataset_scale: int = 32,
        area_thresh: float = 0.01,  # area-ratio threshold at the original HxW
        valid_must3r_sizes=[224, 512],
    ):
        assert mask_type in (None, "manual", "auto")
        assert N >= 1
        self.verbose = verbose
        self.data_root = data_root
        self.dataset_scale = int(dataset_scale)
        self.N = int(N)
        self.mask_type = mask_type
        self.area_thresh = float(area_thresh)
        self.max_stride = int(max_stride)
        self.valid_must3r_sizes = valid_must3r_sizes
        self.image_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.NEAREST_EXACT),
            T.Normalize(mean=img_mean, std=img_std),
        ])
        self.instance_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.NEAREST_EXACT),
        ])
        # --- collect video/JSON pairs (non-recursive) ---
        json_paths = glob(os.path.join(data_root, "*", "*.json"))
        self.items: List[Tuple[str, str]] = []  # (vpath, jpath)
        for jpath in tqdm(json_paths, desc="scanning jsons"):
            base = os.path.splitext(os.path.basename(jpath))[0]
            # filter by mask_type if specified
            if self.mask_type is not None and not base.endswith(f"_{self.mask_type}"):
                continue
            if base.endswith("_manual"):
                vid = base[:-len("_manual")]
            elif base.endswith("_auto"):
                vid = base[:-len("_auto")]
            else:
                # strictly require a mask-type suffix
                continue
            vpath = os.path.join(os.path.dirname(jpath), f"{vid}.mp4")
            if os.path.isfile(vpath):
                self.items.append((vpath, jpath))
        print(f"Collected {len(self.items)} video-json pairs")
        self._log_path = "./sav_dataset_resample.log"
    def __len__(self):
        return self.dataset_scale * len(self.items)

    def _resample(self):
        return self[random.randrange(len(self))]

    def _log(self, msg: str):
        try:
            with open(self._log_path, "a") as f:
                f.write(msg.rstrip() + "\n")
        except Exception:
            pass
    def __getitem__(self, idx: int):
        vpath, jpath = self.items[idx % len(self.items)]
        # 1) load the annotation json
        with open(jpath, "r") as f:
            meta = json.load(f)
        masklet: List[List[Dict]] = meta.get("masklet", [])
        if not isinstance(masklet, list) or len(masklet) < self.N:
            self._log(f"[short_json] {jpath}: len(masklet)={len(masklet)} < N={self.N}")
            return self._resample()
        H, W = int(meta["video_height"]), int(meta["video_width"])
        # 2) sample a random center frame and build the annotated-frame window
        #    [center - N*max_stride, center + N*max_stride), clipped to the video
        center = random.randrange(len(masklet))
        left = max(0, center - self.N * self.max_stride)
        right = min(len(masklet), center + self.N * self.max_stride)
        sample_indices = list(range(left, right))
        if len(sample_indices) < self.N:
            self._log(f"[short_span] {jpath}: span={len(sample_indices)} < N={self.N}")
            return self._resample()
        # 3) make sure some object in the first frame clears the area threshold
        obj_order = None
        while True:
            if len(sample_indices) < self.N:
                self._log(f"[exhausted_span] {jpath}: remaining span < N; resample")
                return self._resample()
            f0 = sample_indices[0]
            rles = masklet[f0] if isinstance(masklet[f0], list) else []
            if len(rles) == 0:
                # no objects at this frame, pop and continue
                sample_indices.pop(0)
                continue
            obj_order = list(range(len(rles)))
            random.shuffle(obj_order)
            has_valid_id = False
            for oid in obj_order:
                m = _decode_rle(rles[oid], H, W)
                area = int(m.sum())
                if area <= 0:
                    continue
                ratio = area / float(H * W + 1e-6)
                if ratio >= self.area_thresh:
                    has_valid_id = True
                    break
            if has_valid_id:
                break
            else:
                # tried all object indices, none passed; pop the first frame and continue
                sample_indices.pop(0)
        # 4) subsample sample_indices to exactly N frames (stride capped at max_stride)
        sample_indices = sample_indices[::min(len(sample_indices) // self.N, self.max_stride)][:self.N]
        assert len(sample_indices) == self.N
        # 5) as in the MOSE dataset: read frames, select the instance at the anchor frame
        cap = cv2.VideoCapture(vpath)
        frames_rgb = []
        frame_indices_24 = []
        for f_annot in sample_indices:
            f24 = int(f_annot * SAV_ANNOT_RATE)  # annotation index -> 24 fps frame index
            frames_rgb.append(_read_frame_rgb(cap, f24, fallback_hw=(H, W)))
            frame_indices_24.append(f24)
        cap.release()
        original_imgs_pil = [Image.fromarray(fr) for fr in frames_rgb]
        # must3r parity fields
        must3r_size = np.random.choice(self.valid_must3r_sizes).item()
        views, resize_funcs = load_images(original_imgs_pil, size=must3r_size, patch_size=16, verbose=self.verbose)
        original_instances = []
        original_imgs = []
        for frame_idx, (resize_func, sample_idx) in enumerate(zip(resize_funcs, sample_indices)):
            assert len(resize_func.transforms) == 2, f'Expected 2 transforms, got {len(resize_func.transforms)}'
            # assert resize_func.transforms[0].size[0] > resize_func.transforms[1].size[0], f'Expected first transform to be larger than second, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}'
            # assert resize_func.transforms[0].size[1] / resize_func.transforms[1].size[1] == resize_func.transforms[0].size[0] / resize_func.transforms[1].size[0], f'Expected aspect ratio to be preserved, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}'
            if frame_idx == 0:
                # pick the first shuffled object whose resized anchor mask clears the area threshold
                for instance_id in obj_order + [None]:
                    if instance_id is None:
                        return self._resample()
                    if (resize_func.transforms[0](torch.from_numpy(_decode_rle(masklet[sample_idx][instance_id], H, W))).sum() > (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * self.area_thresh)):
                        break
            original_instances.append(resize_func.transforms[0](torch.from_numpy(_decode_rle(masklet[sample_idx][instance_id], H, W))))
            original_imgs.append(resize_func.transforms[0](TF.to_tensor(original_imgs_pil[frame_idx])))
        original_instances = torch.stack(original_instances).squeeze()[:, None]
        instances = self.instance_transform(original_instances)
        assert instances[0].sum() > 0 and instances.ndim == 4, f'{instances.shape=}, {instances[0].sum()=}'
        original_imgs = torch.stack(original_imgs)
        imgs = self.image_transform(original_imgs)
        return {
            "original_images": original_imgs,           # [N,3,H,W]
            "images": imgs,                             # [N,3,S,S]
            "original_masks": original_instances,       # [N,1,H,W]
            "masks": instances,                         # [N,1,S,S]
            "filelist": sample_indices,                 # annotation-frame indices (mp4 input has no per-frame files)
            "must3r_views": views,
            "video": os.path.splitext(os.path.basename(vpath))[0],
            "instance_id": int(instance_id),
            "dataset": "sav",
            "valid_masks": torch.ones_like(instances),  # [N,1,S,S]
            "must3r_size": must3r_size,
        }
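
# Usage sketch for SAVTrainDataset (kept as comments since it needs SA-V data on
# disk; the data_root below is a hypothetical placeholder for a shard holding
# mp4 + {video_id}_manual.json pairs):
#   sav = SAVTrainDataset(data_root="/path/to/sa-v", mask_type="manual", N=8)
#   sample = sav[0]
#   print(sample["images"].shape, sample["masks"].shape, sample["instance_id"])
# All N masks track the single instance chosen at the anchor (first) frame.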
class MOSEDataset(Dataset):
    def __init__(
        self,
        data_root: str,
        img_mean=(0.485, 0.456, 0.406),
        img_std=(0.229, 0.224, 0.225),
        N: int = 8,
        image_size: int = 1024,
        verbose=False,
        max_stride=2,
        dataset_scale=1,
        valid_must3r_sizes=[224, 512],
    ):
        self.verbose = verbose
        self.data_root = data_root
        self.dataset_scale = dataset_scale
        self.N = N
        self.max_stride = max_stride
        self.image_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.NEAREST_EXACT),
            T.Normalize(mean=img_mean, std=img_std),
        ])
        self.instance_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.NEAREST_EXACT),
        ])
        self.valid_must3r_sizes = valid_must3r_sizes
        self.videos = os.listdir(os.path.join(data_root, 'JPEGImages'))
        self.frames = {}
        self.masks = {}
        self.indices = []
        for video in tqdm(self.videos):
            if not os.path.isdir(os.path.join(data_root, 'JPEGImages', video)):
                continue
            frames = sorted(glob(os.path.join(data_root, 'JPEGImages', video, '*.jpg')), key=lambda x: int(os.path.basename(x).split('.')[0]))
            masks = sorted(glob(os.path.join(data_root, 'Annotations', video, '*.png')), key=lambda x: int(os.path.basename(x).split('.')[0]))
            if len(frames) < self.N:
                if self.verbose:
                    print(f"skip video {video}: not enough frames")
                continue
            assert len(frames) == len(masks) and len(frames) >= self.N, f'{len(frames)=}, {len(masks)=} in {video}'
            self.frames[video] = frames
            self.masks[video] = masks
            self.indices += [(video, idx) for idx in range(len(frames))]
        print(f'Found {len(self.indices)} frames and {len(self.frames)} videos, with min length {min(len(self.frames[video]) for video in self.frames)} and max length {max(len(self.frames[video]) for video in self.frames)}')
    def __len__(self):
        return len(self.indices) * self.dataset_scale

    def __getitem__(self, idx):
        idx = idx % len(self.indices)
        video, idx = self.indices[idx]
        # window: up to N frames of past context plus a forward span of N*max_stride frames
        sampled_indices = np.arange(max(0, idx - self.N), idx).tolist() + np.arange(idx, min(len(self.frames[video]), idx + self.N * self.max_stride)).tolist()
        # find instance ids covering at least 1% of the first frame's pixels
        unique_ids = None
        while unique_ids is None or len(unique_ids) == 0:
            if unique_ids is not None:
                sampled_indices.pop(0)
            if len(sampled_indices) < self.N:
                return self[np.random.randint(len(self))]
            unique_ids, counts = np.unique(np.array(Image.open(self.masks[video][sampled_indices[0]])), return_counts=True)
            unique_ids = unique_ids[(unique_ids != 0) & (counts > counts.sum() * 0.01)]
        # subsample to exactly N frames
        sampled_indices = sampled_indices[::len(sampled_indices) // self.N][:self.N]
        assert len(unique_ids) > 0 and len(sampled_indices) == self.N
        filelist = [self.frames[video][idx] for idx in sampled_indices]
        must3r_size = np.random.choice(self.valid_must3r_sizes).item()
        views, resize_funcs = load_images(filelist, size=must3r_size, patch_size=16, verbose=self.verbose)
        original_instances = []
        original_imgs = []
        for frame_idx, (resize_func, sample_idx) in enumerate(zip(resize_funcs, sampled_indices)):
            assert len(resize_func.transforms) == 2, f'Expected 2 transforms, got {len(resize_func.transforms)}'
            # assert resize_func.transforms[0].size[0] > resize_func.transforms[1].size[0], f'Expected first transform to be larger than second, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}'
            # assert resize_func.transforms[0].size[1] / resize_func.transforms[1].size[1] == resize_func.transforms[0].size[0] / resize_func.transforms[1].size[0], f'Expected aspect ratio to be preserved, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}'
            if frame_idx == 0:
                # pick the first shuffled instance whose resized anchor mask covers > 1% of the pixels
                for instance_id in np.random.permutation(unique_ids).tolist() + [None]:
                    if instance_id is None:
                        return self[np.random.randint(len(self))]
                    if (resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx]))) == instance_id)).sum() > (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * 0.01):
                        break
            original_instances.append(resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx]))) == instance_id))
            original_imgs.append(resize_func.transforms[0](TF.to_tensor(Image.open(self.frames[video][sample_idx]))))
        original_instances = torch.stack(original_instances).squeeze()[:, None]
        instances = self.instance_transform(original_instances)
        assert instances[0].sum() > 0 and instances.ndim == 4, f'{instances.shape=}, {instances[0].sum()=}'
        original_imgs = torch.stack(original_imgs)
        imgs = self.image_transform(original_imgs)
        return {
            'original_images': original_imgs,
            'images': imgs,
            'original_masks': original_instances,
            'masks': instances,
            'filelist': filelist,
            'must3r_views': views,
            'video': video,
            'instance_id': instance_id,
            'dataset': 'mose',
            'valid_masks': torch.ones_like(instances),
            'must3r_size': must3r_size,
        }
def read_trajectory_file(filepath):
    """Read a ground-truth trajectory CSV into positions, rotations, 4x4 transforms, and timestamps."""

    def _transform_from_Rt(R, t):
        M = np.identity(4)
        M[:3, :3] = R
        M[:3, 3] = t
        return M

    def _read_trajectory_line(line):
        # columns: timestamp at index 1, translation at 3:6, quaternion (xyzw) at 6:10
        line = line.rstrip().split(",")
        pose = {}
        pose["timestamp"] = int(line[1])
        translation = np.array([float(p) for p in line[3:6]])
        quat_xyzw = np.array([float(o) for o in line[6:10]])
        rot_matrix = np.array(Rotation.from_quat(quat_xyzw).as_matrix())
        pose["position"] = translation
        pose["rotation"] = rot_matrix
        pose["transform"] = _transform_from_Rt(rot_matrix, translation)
        return pose

    assert os.path.exists(filepath), f"Could not find trajectory file: {filepath}"
    positions, rotations, transforms, timestamps = [], [], [], []
    with open(filepath, "r") as f:
        _ = f.readline()  # header
        for line in f.readlines():
            pose = _read_trajectory_line(line)
            positions.append(pose["position"])
            rotations.append(pose["rotation"])
            transforms.append(pose["transform"])
            timestamps.append(pose["timestamp"])
    positions = np.stack(positions)
    rotations = np.stack(rotations)
    transforms = np.stack(transforms)
    timestamps = np.array(timestamps)
    return {
        "ts": positions,
        "Rs": rotations,
        "Ts_world_from_device": transforms,
        "timestamps": timestamps,
    }
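
# Self-contained sketch of the CSV layout the parser above assumes (timestamp at
# column 1, tx ty tz at 3:6, qx qy qz qw at 6:10); the row content is made up.
if __name__ == "__main__":
    import tempfile
    demo_row = "uid,1000,0,1.0,2.0,3.0,0.0,0.0,0.0,1.0"
    with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as tmp:
        tmp.write("header\n" + demo_row + "\n")
    demo_traj = read_trajectory_file(tmp.name)
    assert demo_traj["Ts_world_from_device"].shape == (1, 4, 4)
    print(demo_traj["ts"], demo_traj["timestamps"])  # [[1. 2. 3.]] [1000]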
# projectaria_tools is required only by the ASE dataset below
from projectaria_tools.core import calibration
from projectaria_tools.core.image import InterpolationMethod
class ASEDataset(Dataset):
    def __init__(
        self,
        data_root: str,
        img_mean=(0.485, 0.456, 0.406),
        img_std=(0.229, 0.224, 0.225),
        N: int = 8,
        image_size: int = 1024,
        verbose=False,
        dataset_scale=1,
        continuous_prob=0,
        invalid_classes=['ceiling', 'wall', 'empty_space', 'background', 'floor', 'window'],
        valid_must3r_sizes=[224, 512],
    ):
        self.verbose = verbose
        self.data_root = data_root
        self.dataset_scale = dataset_scale
        self.continuous_prob = continuous_prob
        self.N = N
        self.image_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.NEAREST_EXACT),
            T.Normalize(mean=img_mean, std=img_std),
        ])
        self.instance_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.NEAREST_EXACT),
        ])
        self.valid_must3r_sizes = valid_must3r_sizes
        from projectaria_tools.projects import ase
        self.ase_device = ase.get_ase_rgb_calibration()
        self.ase_width, self.ase_height = self.ase_device.get_image_size()
        assert self.ase_width == self.ase_height, f"Expected square images, got {self.ase_width}x{self.ase_height}"
        # pinhole model used to undistort the fisheye RGB camera
        self.ase_pinhole = calibration.get_linear_camera_calibration(
            self.ase_width, self.ase_height, 320, "camera-rgb", self.ase_device.get_transform_device_camera()
        )
        self.fx, self.fy = self.ase_pinhole.get_focal_lengths()
        self.cx, self.cy = self.ase_pinhole.get_principal_point()
        self.K = np.array([[self.fx, 0, self.cx],
                           [0, self.fy, self.cy],
                           [0, 0, 1]], dtype=np.float32)
        self.videos = os.listdir(data_root)
        self.frames = {}
        self.masks = {}
        self.must3r_feats = {}
        self.appearances = {}
        self.mask2indices = {}
        self.validindices = {}
        self.indices = []
        for video in tqdm(self.videos, desc='Loading ASE videos'):
            if not os.path.isdir(os.path.join(data_root, video)):
                print(f"skip {video}: not a directory")
                continue
            frames = sorted(glob(os.path.join(data_root, video, 'undistorted', '*.jpg')))
            masks = sorted(glob(os.path.join(data_root, video, 'undistorted-instances', '*.png')))
            must3r_feats = sorted(glob(os.path.join(data_root, video, 'must3r-features', '*.pt')))
            if not (len(must3r_feats) == len(frames) == len(masks)):
                if self.verbose:
                    print(f"skip {video}: {len(must3r_feats)=}, {len(frames)=}, {len(masks)=}")
                continue
            assert all(os.path.splitext(os.path.basename(feat))[0] == os.path.splitext(os.path.basename(frame))[0] for feat, frame in zip(must3r_feats, frames)), f'MUSt3R features and frames do not match in {video}'
            if len(frames) < self.N:
                if self.verbose:
                    print(f"skip video {video}: not enough frames")
                continue
            self.frames[video] = frames
            self.masks[video] = masks
            self.must3r_feats[video] = must3r_feats
            with open(os.path.join(data_root, video, 'instances-appearances.json')) as f:
                self.appearances[video] = json.load(f)
            self.mask2indices[video] = {os.path.basename(m): i for i, m in enumerate(masks)}
            self.indices += [(video, idx) for idx in range(len(frames) - self.N + 1)]
            with open(os.path.join(data_root, video, 'object_instances_to_classes.json')) as f:
                self.validindices[video] = [int(instance_id) for instance_id, class_name in json.load(f).items() if class_name not in invalid_classes]
        print(f'Found {len(self.indices)} frames and {len(self.frames)} videos, with min length {min(len(self.frames[video]) for video in self.frames)} and max length {max(len(self.frames[video]) for video in self.frames)} and {sum(len(ids) for ids in self.validindices.values() if ids is not None)} valid instances')
        self._log_path = "./ase_dataset_resample.log"

    def __len__(self):
        return len(self.indices) * self.dataset_scale
    def __getitem__(self, idx):
        idx = idx % len(self.indices)
        video, idx = self.indices[idx]
        ## 1. Anchor frame first, then the remaining window starts in random order
        choices = np.delete(np.arange(len(self.frames[video]) - self.N + 1), idx)
        sampled_indices = [idx] + np.random.choice(choices, size=len(choices), replace=False).tolist()
        ## 2. Find valid instance IDs in the first frame
        unique_ids = None
        while unique_ids is None or len(unique_ids) == 0:
            if unique_ids is not None:
                sampled_indices.pop(0)
            if len(sampled_indices) < self.N:
                return self[np.random.randint(len(self))]
            unique_ids = np.unique(np.array(Image.open(self.masks[video][sampled_indices[0]])), return_counts=False)
            unique_ids = unique_ids[(unique_ids != 0) & np.array([class_id in self.validindices[video] for class_id in unique_ids])]
        first_frame_idx = sampled_indices[0]
        assert len(unique_ids) > 0
        ## 3. Load the resize funcs of the first frame
        feat_len = torch.load(self.must3r_feats[video][first_frame_idx], map_location='cpu')[-1].shape[-2]
        must3r_size = original_must3r_size = (224 if feat_len == 196 else 512)  # 196 tokens = (224/16)**2
        is_continuous = (np.random.rand() < self.continuous_prob) or original_must3r_size not in self.valid_must3r_sizes
        if is_continuous:
            must3r_size = np.random.choice(self.valid_must3r_sizes).item()
        _, [resize_func] = load_images([self.frames[video][first_frame_idx]], size=must3r_size, patch_size=16, verbose=self.verbose)
        assert len(resize_func.transforms) == 2, f'Expected 2 transforms, got {len(resize_func.transforms)}'
        assert must3r_size != original_must3r_size or resize_func.transforms[1].size[0] * resize_func.transforms[1].size[1] == feat_len * 256, f'Expected {resize_func.transforms[1].size[0]}x{resize_func.transforms[1].size[1]} to be {feat_len * 256}, got {feat_len}'
        ## 4. Pick an instance whose anchor mask covers between 1% and 20% of the pixels
        for instance_id in np.random.permutation(unique_ids).tolist() + [None]:
            if instance_id is None:
                return self[np.random.randint(len(self))]
            if (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * 0.2) > (resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][first_frame_idx]))) == instance_id)).sum() > (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * 0.01):
                break
        ## 5. Build the frame list: a continuous clip, or the two best anchor candidates
        if is_continuous:
            sampled_indices = np.arange(first_frame_idx, min(len(self.frames[video]), first_frame_idx + self.N)).tolist()
            # sampled_indices += np.random.choice(first_frame_idx, size = first_frame_idx, replace = False).tolist()
            sampled_indices = sampled_indices[:self.N]
            assert len(sampled_indices) == self.N and sampled_indices[0] == first_frame_idx, f'Expected {self.N} sampled indices and first index {first_frame_idx}, got {len(sampled_indices)} with first index {sampled_indices[0]}'
        else:
            sampled_indices = np.arange(first_frame_idx, len(self.frames[video])).tolist()[:2]
            sampled_indices = sorted(sampled_indices, key=lambda sample_idx: resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx]))) == instance_id).sum(), reverse=True)  ## prioritize frames with larger masks
            first_frame_idx = sampled_indices[0]
        views, original_instances, original_imgs, filelist, extrinsics, depths, point_maps, fov_ratios = [], [], [], [], [], [], [], {}
        pre_sampled_len = len(sampled_indices)
        if len(sampled_indices) < self.N:
            # pad with frames where the instance is known to appear, then with the rest
            instance_appearance_candidates = set([self.mask2indices[video][p] for p in self.appearances[video][str(instance_id)]]) - set(sampled_indices)
            sampled_indices += np.random.permutation(list(instance_appearance_candidates)).tolist()
            sampled_indices += np.random.permutation(list(set(np.arange(len(self.frames[video])).tolist()) - set(instance_appearance_candidates) - set(sampled_indices))).tolist()
        trajectory = read_trajectory_file(os.path.join(self.data_root, video, 'trajectory.csv'))
        ## 6. Greedily accept frames where the anchor instance stays in view
        while len(views) < self.N and len(sampled_indices) >= self.N:
            sample_idx = sampled_indices[len(views)]
            [view], [resize_func] = load_images([self.frames[video][sample_idx]], size=must3r_size, patch_size=16, verbose=self.verbose)
            instance_map = resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx])) == instance_id))
            if len(views) >= pre_sampled_len and not (instance_map.shape[-1] * instance_map.shape[-2] * 0.005 < instance_map.sum() < instance_map.shape[-1] * instance_map.shape[-2] * 0.25):
                sampled_indices.pop(len(views))
                continue
            extrinsic = trajectory['Ts_world_from_device'][sample_idx] @ self.ase_pinhole.get_transform_device_camera().to_matrix()
            depth = calibration.distort_by_calibration(
                np.array(Image.open(self.frames[video][sample_idx].replace('undistorted', 'depth').replace('vignette', 'depth').replace('.jpg', '.png'))),
                self.ase_pinhole, self.ase_device, InterpolationMethod.NEAREST_NEIGHBOR,
            ).astype(np.float32) / 1000.0  # mm -> m
            point_map = resize_func.transforms[0](torch.rot90(torch.from_numpy(depth_to_world_pointmap(depth, extrinsic, self.K).astype(np.float32)).permute(2, 0, 1), k=-1, dims=(1, 2)))
            assert point_map.shape[-2] == instance_map.shape[-2], f"Expected height {instance_map.shape[-2]}, got {point_map.shape[-2]}"
            fov_ratio = None
            if len(views) < pre_sampled_len or instance_map.sum().item() == 0 or \
                    (fov_ratio := (in_fov_ratio(point_map[:, instance_map].permute(1, 0), extrinsics[0], K=self.K,
                                                W=self.ase_height, H=self.ase_width,  ## swapped for the -90 deg rotation
                                                W_crop=abs(int(self.ase_height) - original_instances[0].shape[-2]) // 2,
                                                H_crop=abs(int(self.ase_width) - original_instances[0].shape[-1]) // 2)[0])) > 0.25:
                views.append(view)
                original_instances.append(instance_map)
                original_imgs.append(resize_func.transforms[0](TF.to_tensor(Image.open(self.frames[video][sample_idx]))))
                filelist.append(self.frames[video][sample_idx])
                extrinsics.append(extrinsic)
                depths.append(resize_func.transforms[0](torch.rot90(torch.from_numpy(depth), k=-1, dims=(0, 1))))
                point_maps.append(point_map)
                fov_ratios[self.frames[video][sample_idx]] = fov_ratio if fov_ratio is not None else -1
            else:
                sampled_indices.pop(len(views))
                continue
        sampled_indices = sampled_indices[:len(views)]
        if len(sampled_indices) < self.N:
            with open(self._log_path, "a") as f:
                f.write(f"[short_span] {video}: span={len(sampled_indices)} < N={self.N}\n")
            return self[np.random.randint(len(self))]
        assert len(sampled_indices) == self.N and sampled_indices[0] == first_frame_idx, f'Expected {self.N} sampled indices and first index {first_frame_idx}, got {len(sampled_indices)} with first index {sampled_indices[0]}'
        ## 7. Load cached MUSt3R features when the resolution still matches
        if not is_continuous or (np.random.rand() < 0.8 and must3r_size == original_must3r_size):
            assert original_must3r_size == must3r_size, f'If not continuous, must3r size should not change, got {must3r_size} and {original_must3r_size}'
            must3r_feats_filelist = [self.must3r_feats[video][idx] for idx in sampled_indices]
            must3r_feats = [torch.load(must3r_filepath, map_location='cpu') for must3r_filepath in must3r_feats_filelist]
            must3r_feats_head = torch.cat([f[-1] for f in must3r_feats], dim=0)
            must3r_feats = [f[:-1] for f in must3r_feats]
            must3r_feats = [torch.cat(f, dim=0) for f in zip(*must3r_feats)]
            must3r_feats = [
                rearrange(f, 'b (h w) c -> b c h w', h=views[0]['true_shape'][0] // 16, w=views[0]['true_shape'][1] // 16)
                for f in must3r_feats
            ]
        else:
            assert is_continuous, f'If must3r size changed, should be continuous sampling, got {must3r_size} and {original_must3r_size}'
            must3r_feats = None
            must3r_feats_head = None
        original_instances = torch.stack(original_instances).squeeze()[:, None]
        instances = self.instance_transform(original_instances)
        assert instances[0].sum() > 0 and instances.ndim == 4, f'{instances.shape=}, {instances[0].sum()=}'
        original_imgs = torch.stack(original_imgs)
        imgs = self.image_transform(original_imgs)
        # keep the pre-sampled frames in order; shuffle only the padded tail
        permutation = torch.arange(len(instances))
        permutation[pre_sampled_len:] = torch.randperm(len(instances) - pre_sampled_len) + pre_sampled_len
        return {
            'original_images': original_imgs[permutation],
            'images': imgs[permutation],
            'original_masks': original_instances[permutation],
            'masks': instances[permutation],
            'filelist': [filelist[idx] for idx in permutation],
            'must3r_views': [views[idx] for idx in permutation],
            'must3r_size': must3r_size,
            'video': video,
            'instance_id': instance_id,
            'dataset': 'scannetpp',  # kept identical to ScanNetPPV2Dataset so downstream treats both alike
            'valid_masks': torch.ones_like(instances),
            'intrinsics': torch.from_numpy(self.K).unsqueeze(0).repeat(self.N, 1, 1)[permutation],
            'extrinsics': torch.from_numpy(np.stack(extrinsics, axis=0))[permutation],
            'depths': torch.stack(depths)[permutation],
            'point_maps': torch.stack(point_maps)[permutation],
            'fov_ratios': fov_ratios,
            'is_continuous': is_continuous,
        } | (
            {
                'must3r_feats': [f[permutation] for f in must3r_feats],
                'must3r_feats_head': must3r_feats_head[permutation],
                'must3r_feats_filelist': [must3r_feats_filelist[idx] for idx in permutation],
            } if must3r_feats is not None else {}
        )
def pose_from_qwxyz_txyz(elems):
    """Build a world-to-camera matrix from COLMAP's (qw, qx, qy, qz, tx, ty, tz) and return its inverse (cam2world)."""
    qw, qx, qy, qz, tx, ty, tz = map(float, elems)
    pose = np.eye(4)
    pose[:3, :3] = Rotation.from_quat((qx, qy, qz, qw)).as_matrix()  # scipy expects xyzw order
    pose[:3, 3] = (tx, ty, tz)
    return np.linalg.inv(pose)  # cam2world
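
# Sanity check (a sketch): with the identity rotation (qw=1), the cam2world
# translation is just the negated world-to-camera translation.
if __name__ == "__main__":
    demo_c2w = pose_from_qwxyz_txyz([1, 0, 0, 0, 1.0, 2.0, 3.0])
    assert np.allclose(demo_c2w[:3, 3], [-1.0, -2.0, -3.0])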
def depth_to_world_pointmap(depth, c2w, K, depth_type='range'):
    """
    depth: (H,W) depth in meters; 'range' = distance along the ray, 'z-buf' = camera-Z
    c2w: (4,4) camera-to-world transform
    K: (3,3) camera intrinsics
    Returns: (H,W,3) world xyz (NaN for invalid depth)
    """
    Kinv = np.linalg.inv(K)
    H_, W_ = depth.shape
    ys, xs = np.meshgrid(np.arange(H_), np.arange(W_), indexing='ij')
    ones = np.ones_like(xs, dtype=np.float64)
    pix = np.stack([xs, ys, ones], axis=-1).reshape(-1, 3).T  # (3,N)
    rays_cam = Kinv @ pix  # (3,N)
    z = depth.reshape(-1)  # (N,)
    if depth_type == 'range':
        rays_cam = rays_cam / np.linalg.norm(rays_cam, axis=0, keepdims=True)  # unit-length rays
    elif depth_type == 'z-buf':
        pass
    else:
        raise ValueError(f'Unknown depth_type {depth_type}')
    xyz_cam = rays_cam * z  # scale each ray by its depth
    xyz_cam_h = np.vstack([xyz_cam, np.ones_like(z)])  # (4,N)
    xyz_w_h = c2w @ xyz_cam_h  # (4,N)
    xyz_w = xyz_w_h[:3].T.reshape(H_, W_, 3)
    mask = (depth <= 0) | ~np.isfinite(depth)
    xyz_w[mask] = np.nan
    return xyz_w
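
# Quick check (a sketch): with an identity pose and z-buffer depth, the pixel at
# the principal point unprojects onto the optical axis at the given depth.
if __name__ == "__main__":
    demo_K = np.array([[100.0, 0, 2.0], [0, 100.0, 2.0], [0, 0, 1.0]])
    demo_depth = np.full((5, 5), 2.0)
    demo_pts = depth_to_world_pointmap(demo_depth, np.eye(4), demo_K, depth_type='z-buf')
    assert np.allclose(demo_pts[2, 2], [0.0, 0.0, 2.0])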
def in_fov_ratio(points, c2w, K, H, W, H_crop, W_crop):
    """
    points: (N,3) world coords, torch tensor
    c2w: (4,4) camera-to-world (numpy or torch)
    K: (3,3) intrinsics (numpy or torch)
    H,W: image size; H_crop,W_crop: border margins excluded from the valid FoV
    Returns the fraction of points projecting inside the (cropped) image, plus the per-point mask.
    """
    # world -> camera (convert numpy inputs so the matmul below stays in torch)
    w2c = torch.as_tensor(np.linalg.inv(c2w), dtype=points.dtype)
    K = torch.as_tensor(K, dtype=points.dtype)
    Pc = (points @ w2c[:3, :3].T) + w2c[:3, 3]
    X, Y, Z = Pc[:, 0], Pc[:, 1], Pc[:, 2]
    # pinhole projection
    u = K[0, 0] * (X / Z) + K[0, 2]
    v = K[1, 1] * (Y / Z) + K[1, 2]
    mask = (Z > 0) & (u >= W_crop) & (u < W - W_crop) & (v >= H_crop) & (v < H - H_crop)
    return mask.float().mean(), mask
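
# Tiny demo (a sketch): two points in front of an identity camera, one of which
# projects outside a 4x4 image with no border crop.
if __name__ == "__main__":
    demo_pts2 = torch.tensor([[0.0, 0.0, 1.0], [10.0, 0.0, 1.0]])
    demo_ratio, demo_inside = in_fov_ratio(demo_pts2, np.eye(4), np.eye(3), H=4, W=4, H_crop=0, W_crop=0)
    print(demo_ratio.item(), demo_inside.tolist())  # 0.5 [True, False]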
class ScanNetPPV2Dataset(Dataset):
    def __init__(
        self,
        data_root: str,
        must3r_data_root: str = None,
        img_mean=(0.485, 0.456, 0.406),
        img_std=(0.229, 0.224, 0.225),
        N: int = 8,
        image_size: int = 1024,
        verbose=False,
        dataset_scale=1,
        continuous_prob=0,
        instance_classes_file='<your path to scannetppv2>/metadata/semantic_benchmark/top100_instance.txt',
        split_file: str = '<your path to scannetppv2>/splits/nvs_sem_train.txt',
        excluding_scenes=["09d6e808b4", "0f69aefe3d", "1b379f1114", "1cbb105c6a", "2c7c10379b", "46638cfd0f", "4f341f3af0", "6ef2ac745a", "898a7dfd0c", "aa852f7871", "eea4ad9c04", "d27235711b"],  ## horizontal / vertical flip issues
        valid_must3r_sizes=[224, 512],
    ):
        self.verbose = verbose
        self.data_root = data_root
        self.must3r_data_root = must3r_data_root if must3r_data_root is not None else data_root
        self.dataset_scale = dataset_scale
        self.excluding_scenes = excluding_scenes
        self.instance_classes = open(instance_classes_file).read().splitlines()
        self.valid_scene_names = open(split_file).read().splitlines()
        self.continuous_prob = continuous_prob
        self.N = N
        self.image_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.NEAREST_EXACT),
            T.Normalize(mean=img_mean, std=img_std),
        ])
        self.instance_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.NEAREST_EXACT),
        ])
        self.valid_must3r_sizes = valid_must3r_sizes
        self.videos = os.listdir(data_root)
        self.frames = {}
        self.masks = {}
        self.must3r_feats = {}
        self.appearances = {}
        self.id2label_name = {}
        self.intrinsics = {}
        self.extrinsics = {}
        self.indices = []
        self._log_path = "./scannetppv2_dataset_resample.log"
        for video in tqdm(self.videos, desc='Loading ScanNet++V2 videos'):
            if video not in self.valid_scene_names or video in self.excluding_scenes:
                if self.verbose:
                    print(f"skip {video}: not in split or excluded")
                continue
            if not os.path.isdir(os.path.join(data_root, video)):
                print(f"skip {video}: not a directory")
                continue
            if video in ['46638cfd0f']:
                if self.verbose:
                    print(f"skip {video}: broken")
                continue
            masks = sorted(glob(os.path.join(self.data_root, video, 'iphone', 'render_instance', '*.png')))
            if len(masks) == 0:
                if self.verbose:
                    print(f"skip {video}: no masks found")
                continue
            frames = [m.replace('render_instance', 'rgb').replace('.png', '.jpg') for m in masks]
            must3r_feats = [m.replace(self.data_root, self.must3r_data_root).replace('iphone/render_instance', 'must3r-features').replace('.png', '.pt') for m in masks]
            if not all(os.path.exists(p) for p in must3r_feats[:1]):  # only the first file is probed, as a cheap existence check
                if self.verbose:
                    print(f"skip {video}: must3r features missing")
                continue
            # assert all([os.path.exists(p) for p in frames]), f'Not all frames exist in {video}'
            self.frames[video] = frames
            self.masks[video] = masks
            self.must3r_feats[video] = must3r_feats
            with open(os.path.join(data_root, video, 'scans/instance-appearances.json')) as f:
                self.appearances[video] = json.load(f)
            self.intrinsics[video] = self.load_intrinsics(os.path.join(data_root, video, 'iphone', 'colmap', 'cameras.txt'))
            assert len(self.intrinsics[video]) == 1, f'Expected 1 camera, got {len(self.intrinsics[video])} in {video}'
            self.extrinsics[video] = os.path.join(data_root, video, 'iphone', 'colmap', 'images.txt')
            assert all(f_name == os.path.basename(m) for f_name, m in zip(self.appearances[video]['framenames'], self.masks[video])), f'Frame names in appearances do not match masks in {video}'
            with open(os.path.join(data_root, video, 'scans/instance_id2label_name.json')) as f:
                self.id2label_name[video] = json.load(f)
            self.indices += [(video, idx) for idx in range(len(frames) - self.N + 1)]
        print(f'Found {len(self.indices)} frames and {len(self.frames)} videos, with min length {min(len(self.frames[video]) for video in self.frames)} and max length {max(len(self.frames[video]) for video in self.frames)}')
    def load_intrinsics(self, path):
        """Parse a COLMAP cameras.txt into camera_id -> [model_name, width, height, fx, fy, cx, cy, ...]."""
        with open(path, 'r') as f:
            raw = f.read().splitlines()[3:]  # skip the 3-line header
        intrinsics = {}
        for camera in tqdm(raw, position=1, leave=False):
            camera = camera.split(' ')
            intrinsics[int(camera[0])] = [camera[1]] + [float(cam) for cam in camera[2:]]
        return intrinsics
    def __len__(self):
        return len(self.indices) * self.dataset_scale

    def __getitem__(self, idx):
        idx = idx % len(self.indices)
        video, idx = self.indices[idx]
        if len(glob(os.path.join(self.data_root, video, 'iphone/depth/*.png'))) == 0:
            return self[np.random.randint(len(self))]
        ## 1. Anchor frame first, then the remaining window starts in random order
        choices = np.delete(np.arange(len(self.frames[video]) - self.N + 1), idx)
        sampled_indices = [idx] + np.random.choice(choices, size=len(choices), replace=False).tolist()
        ## 2. Find valid instance IDs in the first frame
        unique_ids = None
        while unique_ids is None or len(unique_ids) == 0:
            if unique_ids is not None:
                sampled_indices.pop(0)
            if len(sampled_indices) == 0:
                return self[np.random.randint(len(self))]
            unique_ids, _ = np.unique(np.array(Image.open(self.masks[video][sampled_indices[0]])), return_counts=True)
            unique_ids = unique_ids[np.array([class_id not in [0, 65535] and self.id2label_name[video][str(class_id)] in self.instance_classes and all([s not in self.id2label_name[video][str(class_id)].lower() for s in ['wall', 'floor', 'ceiling', 'window', 'curtain', 'blind', 'table']]) for class_id in unique_ids])]
        first_frame_idx = sampled_indices[0]
        assert len(unique_ids) > 0
        ## 3. Load the resize funcs of the first frame
        feat_len = torch.load(self.must3r_feats[video][first_frame_idx], map_location='cpu')[-1].shape[-2]
        must3r_size = original_must3r_size = (224 if feat_len == 196 else 512)  # 196 tokens = (224/16)**2
        is_continuous = (np.random.rand() < self.continuous_prob) or original_must3r_size not in self.valid_must3r_sizes
        if is_continuous:
            must3r_size = np.random.choice(self.valid_must3r_sizes).item()
        _, [resize_func] = load_images([self.frames[video][first_frame_idx]], size=must3r_size, patch_size=16, verbose=self.verbose)
        assert len(resize_func.transforms) == 2, f'Expected 2 transforms, got {len(resize_func.transforms)}'
        # assert resize_func.transforms[0].size[0] > resize_func.transforms[1].size[0], f'Expected first transform to be larger than second, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}'
        # assert resize_func.transforms[0].size[1] / resize_func.transforms[1].size[1] == resize_func.transforms[0].size[0] / resize_func.transforms[1].size[0], f'Expected aspect ratio to be preserved, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}'
        assert must3r_size != original_must3r_size or resize_func.transforms[1].size[0] * resize_func.transforms[1].size[1] == feat_len * 256, f'Expected {resize_func.transforms[1].size[0]}x{resize_func.transforms[1].size[1]} to be {feat_len * 256}, got {feat_len}'
        ## 4. Pick an instance whose anchor mask covers more than 1% of the pixels
        for instance_id in np.random.permutation(unique_ids).tolist() + [None]:
            if instance_id is None:
                return self[np.random.randint(len(self))]
            if (resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][first_frame_idx]))) == instance_id)).sum() > (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * 0.01):
                break
        ## 5. Build the frame list: a continuous clip, or the two best anchor candidates
        if is_continuous:
            sampled_indices = np.arange(first_frame_idx, len(self.frames[video])).tolist()
            # sampled_indices += np.random.permutation(list(set(np.arange(len(self.frames[video])).tolist()) - set(self.appearances[video][str(instance_id)]) - set(sampled_indices))).tolist()
            sampled_indices = sampled_indices[:self.N]
            assert len(sampled_indices) == self.N and sampled_indices[0] == first_frame_idx, f'Expected {self.N} sampled indices and first index {first_frame_idx}, got {len(sampled_indices)} with first index {sampled_indices[0]}'
        else:
            sampled_indices = np.arange(first_frame_idx, len(self.frames[video])).tolist()[:2]
            sampled_indices = sorted(sampled_indices, key=lambda sample_idx: resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx]))) == instance_id).sum(), reverse=True)  ## prioritize frames with larger masks
            first_frame_idx = sampled_indices[0]
        # COLMAP images.txt: pose lines end with the image name (2D-point lines also land
        # in this dict, but under numeric keys, so lookups by frame name are unaffected)
        raw_poses = {
            raw.split()[-1].split('iphone/')[-1].split('video/')[-1]: raw.split()[1:-1]
            for raw in open(self.extrinsics[video], 'r').read().splitlines() if (not raw.startswith('#')) and len(raw.split()) > 0
        }
        views, original_instances, original_imgs, filelist, extrinsics, raw_intrinsics, intrinsics, depths, point_maps, fov_ratios = [], [], [], [], [], [], [], [], [], {}
        pre_sampled_len = len(sampled_indices)
        if len(sampled_indices) < self.N:
            # pad with frames where the instance is known to appear, then with the rest
            sampled_indices = sampled_indices + np.random.permutation(list(set(self.appearances[video][self.id2label_name[video][str(instance_id)]]) - set(sampled_indices))).tolist() + \
                np.random.permutation(list(set(np.arange(len(self.frames[video])).tolist()) - set(self.appearances[video][self.id2label_name[video][str(instance_id)]]) - set(sampled_indices))).tolist()
        ## 6. Greedily accept frames where the anchor instance stays in view
        while len(views) < self.N and len(sampled_indices) >= self.N:
            sample_idx = sampled_indices[len(views)]
            [view], [resize_func] = load_images([self.frames[video][sample_idx]], size=must3r_size, patch_size=16, verbose=self.verbose)
            instance_map = resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx])) == instance_id))
            if len(views) >= pre_sampled_len and (0 < instance_map.sum() < instance_map.shape[-1] * instance_map.shape[-2] * 0.01):
                sampled_indices.pop(len(views))
                continue
            f_name = os.path.basename(self.frames[video][sample_idx])
            extrinsic = pose_from_qwxyz_txyz(raw_poses[f_name][:-1])
            raw_intrinsic = self.intrinsics[video][int(raw_poses[f_name][-1])]  # last pose field is the camera id
            intrinsic = np.array([[raw_intrinsic[3], 0, raw_intrinsic[5]],
                                  [0, raw_intrinsic[4], raw_intrinsic[6]],
                                  [0, 0, 1]], dtype=np.float32)
            depth = np.array(Image.open(self.frames[video][sample_idx].replace('rgb', 'depth').replace('.jpg', '.png')).resize((int(raw_intrinsic[1]), int(raw_intrinsic[2]))), dtype=np.float32) / 1000.0  # mm -> m
            point_map = resize_func.transforms[0](torch.from_numpy(depth_to_world_pointmap(depth, extrinsic, intrinsic).astype(np.float32)).permute(2, 0, 1))
            assert point_map.shape[-2] == instance_map.shape[-2] == int(raw_intrinsic[2]), f'Expected height {int(raw_intrinsic[2])}, got {point_map.shape[-2]} and {instance_map.shape[-2]}'
            fov_ratio = None
            if len(views) < pre_sampled_len or instance_map.sum().item() == 0 or \
                    (fov_ratio := (in_fov_ratio(point_map[:, instance_map].permute(1, 0), extrinsics[0], K=intrinsics[0],
                                                H=int(raw_intrinsics[0][2]), W=int(raw_intrinsics[0][1]),
                                                H_crop=abs(int(raw_intrinsics[0][2]) - original_instances[0].shape[-2]) // 2,
                                                W_crop=abs(int(raw_intrinsics[0][1]) - original_instances[0].shape[-1]) // 2)[0])) > 0.25:
                views.append(view)
                original_instances.append(instance_map)
                original_imgs.append(resize_func.transforms[0](TF.to_tensor(Image.open(self.frames[video][sample_idx]))))
                filelist.append(self.frames[video][sample_idx])
                extrinsics.append(extrinsic)
                raw_intrinsics.append(raw_intrinsic)
                intrinsics.append(intrinsic)
                depths.append(resize_func.transforms[0](torch.from_numpy(depth)))
                point_maps.append(point_map)
                fov_ratios[self.frames[video][sample_idx]] = fov_ratio if fov_ratio is not None else -1
            else:
                sampled_indices.pop(len(views))
                continue
        sampled_indices = sampled_indices[:len(views)]
        if len(sampled_indices) < self.N:
            with open(self._log_path, "a") as f:
                f.write(f"[short_span] {video}: span={len(sampled_indices)} < N={self.N}\n")
            return self[np.random.randint(len(self))]
        assert len(sampled_indices) == self.N and sampled_indices[0] == first_frame_idx, f'Expected {self.N} sampled indices and first index {first_frame_idx}, got {len(sampled_indices)} with first index {sampled_indices[0]}'
        ## 7. Load cached MUSt3R features when the resolution still matches
        if not is_continuous or (np.random.rand() < 0.8 and must3r_size == original_must3r_size):
            assert original_must3r_size == must3r_size, f'If not continuous, must3r size should not change, got {must3r_size} and {original_must3r_size}'
            must3r_feats_filelist = [self.must3r_feats[video][idx] for idx in sampled_indices]
            must3r_feats = [torch.load(must3r_filepath, map_location='cpu') for must3r_filepath in must3r_feats_filelist]
            must3r_feats_head = torch.cat([f[-1] for f in must3r_feats], dim=0)
            must3r_feats = [f[:-1] for f in must3r_feats]
            must3r_feats = [torch.cat(f, dim=0) for f in zip(*must3r_feats)]
            must3r_feats = [
                rearrange(f, 'b (h w) c -> b c h w', h=views[0]['true_shape'][0] // 16, w=views[0]['true_shape'][1] // 16)
                for f in must3r_feats
            ]
        else:
            assert is_continuous, f'If must3r size changed, should be continuous sampling, got {must3r_size} and {original_must3r_size}'
            must3r_feats = None
            must3r_feats_head = None
        original_instances = torch.stack(original_instances).squeeze()[:, None]
        instances = self.instance_transform(original_instances)
        assert instances[0].sum() > 0 and instances.ndim == 4, f'{instances.shape=}, {instances[0].sum()=}'
        # assert instances[1:].sum() == 0, f"Only first frame should have the instance, got {instances.sum()=}"
        original_imgs = torch.stack(original_imgs)
        imgs = self.image_transform(original_imgs)
        # keep the pre-sampled frames in order; shuffle only the padded tail
        permutation = torch.arange(len(instances))
        permutation[pre_sampled_len:] = torch.randperm(len(instances) - pre_sampled_len) + pre_sampled_len
        return {
            'original_images': original_imgs[permutation],
            'images': imgs[permutation],
            'original_masks': original_instances[permutation],
            'masks': instances[permutation],
            'filelist': [filelist[idx] for idx in permutation],
            'must3r_views': [views[idx] for idx in permutation],
            'must3r_size': must3r_size,
            'video': video,
            'instance_id': instance_id,
            'dataset': 'scannetpp',
            'valid_masks': torch.ones_like(instances),
            'intrinsics': torch.from_numpy(np.stack(intrinsics, axis=0))[permutation],
            'extrinsics': torch.from_numpy(np.stack(extrinsics, axis=0))[permutation],
            'depths': torch.stack(depths)[permutation],
            'point_maps': torch.stack(point_maps)[permutation],
            'fov_ratios': fov_ratios,
            'is_continuous': is_continuous,
        } | (
            {
                'must3r_feats': [f[permutation] for f in must3r_feats],
                'must3r_feats_head': must3r_feats_head[permutation],
                'must3r_feats_filelist': [must3r_feats_filelist[idx] for idx in permutation],
            } if must3r_feats is not None else {}
        )
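
# Smoke-test sketch for the datasets above (kept as comments: the roots are
# hypothetical placeholders, and batch_size=1 is easiest since samples mix
# tensors, strings, and per-sample dicts):
#   from torch.utils.data import DataLoader
#   ds = MOSEDataset(data_root="/path/to/MOSE/train", N=8)
#   loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)
#   batch = next(iter(loader))
#   print({k: (v.shape if torch.is_tensor(v) else type(v)) for k, v in batch.items()})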