diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2b68f5a1c2ce0c929c1f01e22e46fc267ea9a3e5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,57 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Virtual Environment
+venv/
+ENV/
+env/
+.venv
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+.DS_Store
+
+# HuggingFace Space temporary files
+input_images_*/
+*.glb
+*.npz
+flagged/
+
+# Local model cache (models are now loaded from HuggingFace)
+models/
+
+# Logs
+*.log
+logs/
+
+# Test files
+.pytest_cache/
+.coverage
+htmlcov/
+
+# System files
+Thumbs.db
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0f6e3922db8931beb4f76a7fbfe22677a86a8db
--- /dev/null
+++ b/app.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+HuggingFace Space entry point.
+Imports and runs gradio_app_v8 directly.
+"""
+
+import sys
+from pathlib import Path
+
+# Add the scripts directory to the Python path
+scripts_dir = Path(__file__).parent / "scripts"
+sys.path.insert(0, str(scripts_dir))
+
+# Import and run the main application
+if __name__ == "__main__":
+    # Importing gradio_app_v8 automatically launches the demo
+    import gradio_app_v8
+
diff --git a/mapanything/__init__.py b/mapanything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/__pycache__/__init__.cpython-312.pyc b/mapanything/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b4dfaaff3ef0e1b4c1a7b19336ecd940aaa5311
Binary files /dev/null and b/mapanything/__pycache__/__init__.cpython-312.pyc differ
diff --git a/mapanything/datasets/__init__.py b/mapanything/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f03559a7104487f88ee1c9adef8aafbca34a75dd
--- /dev/null
+++ b/mapanything/datasets/__init__.py
@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
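+#
+# Illustrative usage sketch (added for clarity; not part of the upstream module,
+# and the dataset objects and argument values below are assumptions):
+#
+#   from mapanything.datasets import get_test_data_loader, get_train_data_loader
+#
+#   # Fixed-size batches for evaluation
+#   test_loader = get_test_data_loader(test_dataset, batch_size=4, num_workers=8)
+#
+#   # Dynamic batches for training: the per-batch size is roughly
+#   # max_num_of_imgs_per_gpu // (number of views sampled for that batch)
+#   train_loader = get_train_data_loader(train_dataset, max_num_of_imgs_per_gpu=16)
+#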
+ +""" +MapAnything Datasets +""" + +import torch + +from mapanything.datasets.wai.ase import ASEWAI # noqa +from mapanything.datasets.wai.blendedmvs import BlendedMVSWAI # noqa +from mapanything.datasets.wai.dl3dv import DL3DVWAI # noqa +from mapanything.datasets.wai.dynamicreplica import DynamicReplicaWAI # noqa +from mapanything.datasets.wai.eth3d import ETH3DWAI # noqa +from mapanything.datasets.wai.megadepth import MegaDepthWAI # noqa +from mapanything.datasets.wai.mpsd import MPSDWAI # noqa +from mapanything.datasets.wai.mvs_synth import MVSSynthWAI # noqa +from mapanything.datasets.wai.paralleldomain4d import ParallelDomain4DWAI # noqa +from mapanything.datasets.wai.sailvos3d import SAILVOS3DWAI # noqa +from mapanything.datasets.wai.scannetpp import ScanNetPPWAI # noqa +from mapanything.datasets.wai.spring import SpringWAI # noqa +from mapanything.datasets.wai.tav2_wb import TartanAirV2WBWAI # noqa +from mapanything.datasets.wai.unrealstereo4k import UnrealStereo4KWAI # noqa +from mapanything.utils.train_tools import get_rank, get_world_size + + +def get_test_data_loader( + dataset, batch_size, num_workers=8, shuffle=False, drop_last=False, pin_mem=True +): + "Get simple PyTorch dataloader corresponding to the testing dataset" + # PyTorch dataset + if isinstance(dataset, str): + dataset = eval(dataset) + + world_size = get_world_size() + rank = get_rank() + + if torch.distributed.is_initialized(): + sampler = torch.utils.data.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + drop_last=drop_last, + ) + elif shuffle: + sampler = torch.utils.data.RandomSampler(dataset) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_mem, + drop_last=drop_last, + ) + + return data_loader + + +def get_test_many_ar_data_loader( + dataset, batch_size, num_workers=8, drop_last=False, pin_mem=True +): + "Get PyTorch dataloader corresponding to the testing dataset that supports many aspect ratios" + # PyTorch dataset + if isinstance(dataset, str): + dataset = eval(dataset) + + world_size = get_world_size() + rank = get_rank() + + # Get BatchedMultiFeatureRandomSampler + sampler = dataset.make_sampler( + batch_size, + shuffle=True, + world_size=world_size, + rank=rank, + drop_last=drop_last, + use_dynamic_sampler=False, + ) + + # Init the data laoder + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_mem, + drop_last=drop_last, + ) + + return data_loader + + +class DynamicBatchDatasetWrapper: + """ + Wrapper dataset that handles DynamicBatchedMultiFeatureRandomSampler output. + + The dynamic sampler returns batches (lists of tuples) instead of individual samples. + This wrapper ensures that the underlying dataset's __getitem__ method gets called + with individual tuples as expected. + """ + + def __init__(self, dataset): + self.dataset = dataset + + def __getitem__(self, batch_indices): + """ + Handle batch of indices from DynamicBatchedMultiFeatureRandomSampler. + + Args: + batch_indices: List of tuples like [(sample_idx, feat_idx_1, feat_idx_2, ...), ...] 
+ + Returns: + List of samples from the underlying dataset + """ + if isinstance(batch_indices, (list, tuple)) and len(batch_indices) > 0: + # If it's a batch (list of tuples), process each item + if isinstance(batch_indices[0], (list, tuple)): + return [self.dataset[idx] for idx in batch_indices] + else: + # Single tuple, call dataset directly + return self.dataset[batch_indices] + else: + # Fallback for single index + return self.dataset[batch_indices] + + def __len__(self): + return len(self.dataset) + + def __getattr__(self, name): + # Delegate all other attributes to the wrapped dataset + return getattr(self.dataset, name) + + +def get_train_data_loader( + dataset, + max_num_of_imgs_per_gpu, + num_workers=8, + shuffle=True, + drop_last=True, + pin_mem=True, +): + "Dynamic PyTorch dataloader corresponding to the training dataset" + # PyTorch dataset + if isinstance(dataset, str): + dataset = eval(dataset) + + world_size = get_world_size() + rank = get_rank() + + # Get DynamicBatchedMultiFeatureRandomSampler + batch_sampler = dataset.make_sampler( + shuffle=shuffle, + world_size=world_size, + rank=rank, + drop_last=drop_last, + max_num_of_images_per_gpu=max_num_of_imgs_per_gpu, + use_dynamic_sampler=True, + ) + + # Wrap the dataset to handle batch format from dynamic sampler + wrapped_dataset = DynamicBatchDatasetWrapper(dataset) + + # Init the dynamic data loader + data_loader = torch.utils.data.DataLoader( + wrapped_dataset, + batch_sampler=batch_sampler, + num_workers=num_workers, + pin_memory=pin_mem, + ) + + return data_loader diff --git a/mapanything/datasets/base/__init__.py b/mapanything/datasets/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/datasets/base/base_dataset.py b/mapanything/datasets/base/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5b54e302d3d7beef0e5acc62846c3c6785fc542e --- /dev/null +++ b/mapanything/datasets/base/base_dataset.py @@ -0,0 +1,697 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Base class for MapAnything datasets. +""" + +from typing import List, Tuple, Union + +import numpy as np +import PIL +import torch +import torchvision.transforms as tvf +from scipy.spatial.transform import Rotation + +from mapanything.datasets.base.easy_dataset import EasyDataset +from mapanything.utils.cropping import ( + bbox_from_intrinsics_in_out, + camera_matrix_of_crop, + crop_image_and_other_optional_info, + rescale_image_and_other_optional_info, +) +from mapanything.utils.geometry import ( + depthmap_to_camera_coordinates, + get_absolute_pointmaps_and_rays_info, +) +from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT + + +class BaseDataset(EasyDataset): + """ + Define all basic options. 
+
+    Usage:
+        class MyDataset(BaseDataset):
+            def _get_views(self, idx):
+                views = []
+                views.append(dict(img=, ...))
+                return views
+    """
+
+    def __init__(
+        self,
+        num_views: int,
+        variable_num_views: bool = False,
+        split: str = None,
+        covisibility_thres: float = None,
+        resolution: Union[int, Tuple[int, int], List[Tuple[int, int]]] = None,
+        principal_point_centered: bool = False,
+        transform: str = None,
+        data_norm_type: str = None,
+        aug_crop: int = 0,
+        seed: int = None,
+        max_num_retries: int = 5,
+    ):
+        """
+        PyTorch dataset for multi-view images sampled from scenes, where the images form a single connected component.
+
+        Args:
+            num_views (int): Number of views.
+            variable_num_views (bool): If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2.
+                                       On by default for N-view train dataloader (hydra config).
+            split (str): 'train', 'val', 'test', etc.
+            covisibility_thres (float): Covisibility (%) threshold to determine if another image is a neighbor or not
+            resolution (int or tuple or list of tuples): Resolution of the images
+            principal_point_centered (bool): If True, the principal point is centered in the image.
+            transform (str): Transform to apply to the images. Options:
+                             - 'colorjitter+grayscale+gaublur':
+                                tvf.Compose([
+                                    tvf.RandomApply([tvf.ColorJitter(0.3, 0.4, 0.2, 0.1)], p=0.75),
+                                    tvf.RandomGrayscale(p=0.05),
+                                    tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05),
+                                ]) after ImgNorm
+                             - 'colorjitter': tvf.ColorJitter(0.5, 0.5, 0.5, 0.1) after ImgNorm
+                             - 'imgnorm': ImgNorm only
+            data_norm_type (str): Image normalization type.
+                                  For options, see UniCeption image normalization dict.
+            aug_crop (int): Augment crop. If int greater than 0, indicates the number of pixels to increase in target resolution.
+            seed (int): Seed for the random number generator.
+            max_num_retries (int): Maximum number of retries for loading a different sample from the dataset, if provided idx fails.
+        """
+        self.num_views = num_views
+        self.variable_num_views = variable_num_views
+        self.num_views_min = 2
+        self.split = split
+        self.covisibility_thres = covisibility_thres
+        self._set_resolutions(resolution)
+        self.principal_point_centered = principal_point_centered
+
+        # Update the number of views if necessary and make it a list if variable_num_views is True
+        if self.variable_num_views and self.num_views > self.num_views_min:
+            self.num_views = list(range(self.num_views_min, self.num_views + 1))
+
+        # Initialize the image normalization type
+        if data_norm_type in IMAGE_NORMALIZATION_DICT.keys():
+            self.data_norm_type = data_norm_type
+            image_norm = IMAGE_NORMALIZATION_DICT[data_norm_type]
+            ImgNorm = tvf.Compose(
+                [
+                    tvf.ToTensor(),
+                    tvf.Normalize(mean=image_norm.mean, std=image_norm.std),
+                ]
+            )
+        elif data_norm_type == "identity":
+            self.data_norm_type = data_norm_type
+            ImgNorm = tvf.Compose([tvf.ToTensor()])
+        else:
+            raise ValueError(
+                f"Unknown data_norm_type: {data_norm_type}. 
Available options: identity or {list(IMAGE_NORMALIZATION_DICT.keys())}" + ) + + # Initialize torchvision transforms + if transform == "imgnorm": + self.transform = ImgNorm + elif transform == "colorjitter": + self.transform = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) + elif transform == "colorjitter+grayscale+gaublur": + self.transform = tvf.Compose( + [ + tvf.RandomApply([tvf.ColorJitter(0.3, 0.4, 0.2, 0.1)], p=0.75), + tvf.RandomGrayscale(p=0.05), + tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05), + ImgNorm, + ] + ) + else: + raise ValueError( + 'Unknown transform. Available options: "imgnorm", "colorjitter", "colorjitter+grayscale+gaublur"' + ) + + # Initialize the augmentation parameters + self.aug_crop = aug_crop + + # Initialize the seed for the random number generator + self.seed = seed + self._seed_offset = 0 + + # Initialize the maximum number of retries for loading a different sample from the dataset, if the first idx fails + self.max_num_retries = max_num_retries + + # Initialize the dataset type flags + self.is_metric_scale = False # by default a dataset is not metric scale, subclasses can overwrite this + self.is_synthetic = False # by default a dataset is not synthetic, subclasses can overwrite this + + def _load_data(self): + self.scenes = [] + self.num_of_scenes = len(self.scenes) + + def __len__(self): + "Length of the dataset is determined by the number of scenes in the dataset split" + return self.num_of_scenes + + def get_stats(self): + "Get the number of scenes in the dataset split" + return f"{self.num_of_scenes} scenes" + + def __repr__(self): + resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]" + return ( + f"""{type(self).__name__}({self.get_stats()}, + {self.num_views=} + {self.split=}, + {self.seed=}, + resolutions={resolutions_str}, + {self.transform=})""".replace("self.", "") + .replace("\n", "") + .replace(" ", "") + ) + + def _get_views(self, idx, num_views_to_sample, resolution): + raise NotImplementedError() + + def _set_seed_offset(self, idx): + """ + Set the seed offset. This is directly added to self.seed when setting the random seed. + """ + self._seed_offset = idx + + def _set_resolutions(self, resolutions): + assert resolutions is not None, "undefined resolution" + + if isinstance(resolutions, int): + resolutions = [resolutions] + elif isinstance(resolutions, tuple): + resolutions = [resolutions] + elif isinstance(resolutions, list): + assert all(isinstance(res, tuple) for res in resolutions), ( + f"Bad type for {resolutions=}, should be int or tuple of ints or list of tuples of ints" + ) + else: + raise ValueError( + f"Bad type for {resolutions=}, should be int or tuple of ints or list of tuples of ints" + ) + + self._resolutions = [] + for resolution in resolutions: + if isinstance(resolution, int): + width = height = resolution + else: + width, height = resolution + assert isinstance(width, int), ( + f"Bad type for {width=} {type(width)=}, should be int" + ) + assert isinstance(height, int), ( + f"Bad type for {height=} {type(height)=}, should be int" + ) + self._resolutions.append((width, height)) + + def _crop_resize_if_necessary( + self, + image, + resolution, + depthmap, + intrinsics, + additional_quantities=None, + ): + """ + Process an image by downsampling and cropping as needed to match the target resolution. + + This method performs the following operations: + 1. Converts the image to PIL.Image if necessary + 2. 
Crops the image centered on the principal point if requested + 3. Downsamples the image using high-quality Lanczos filtering + 4. Performs final cropping to match the target resolution + + Args: + image (numpy.ndarray or PIL.Image.Image): Input image to be processed + resolution (tuple): Target resolution as (width, height) + depthmap (numpy.ndarray): Depth map corresponding to the image + intrinsics (numpy.ndarray): Camera intrinsics matrix (3x3) + additional_quantities (dict, optional): Additional image-related data to be processed + alongside the main image with nearest interpolation. Defaults to None. + + Returns: + tuple: Processed image, depthmap, and updated intrinsics matrix. + If additional_quantities is provided, it returns those as well. + """ + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # Cropping centered on the principal point if necessary + if self.principal_point_centered: + W, H = image.size + cx, cy = intrinsics[:2, 2].round().astype(int) + if cx < 0 or cx >= W or cy < 0 or cy >= H: + # Skip centered cropping if principal point is outside image bounds + pass + else: + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + left, top = cx - min_margin_x, cy - min_margin_y + right, bottom = cx + min_margin_x, cy + min_margin_y + crop_bbox = (left, top, right, bottom) + # Only perform the centered crop if the crop_bbox is larger than the target resolution + crop_width = right - left + crop_height = bottom - top + if crop_width > resolution[0] and crop_height > resolution[1]: + image, depthmap, intrinsics, additional_quantities = ( + crop_image_and_other_optional_info( + image=image, + crop_bbox=crop_bbox, + depthmap=depthmap, + camera_intrinsics=intrinsics, + additional_quantities=additional_quantities, + ) + ) + + # Get the target resolution for re-scaling + target_rescale_resolution = np.array(resolution) + if self.aug_crop > 1: + target_rescale_resolution += self._rng.integers(0, self.aug_crop) + + # High-quality Lanczos down-scaling if necessary + image, depthmap, intrinsics, additional_quantities = ( + rescale_image_and_other_optional_info( + image=image, + output_resolution=target_rescale_resolution, + depthmap=depthmap, + camera_intrinsics=intrinsics, + additional_quantities_to_be_resized_with_nearest=additional_quantities, + ) + ) + + # Actual cropping (if necessary) + new_intrinsics = camera_matrix_of_crop( + input_camera_matrix=intrinsics, + input_resolution=image.size, + output_resolution=resolution, + offset_factor=0.5, + ) + crop_bbox = bbox_from_intrinsics_in_out( + input_camera_matrix=intrinsics, + output_camera_matrix=new_intrinsics, + output_resolution=resolution, + ) + image, depthmap, new_intrinsics, additional_quantities = ( + crop_image_and_other_optional_info( + image=image, + crop_bbox=crop_bbox, + depthmap=depthmap, + camera_intrinsics=intrinsics, + additional_quantities=additional_quantities, + ) + ) + + # Return the output + if additional_quantities is not None: + return image, depthmap, new_intrinsics, additional_quantities + else: + return image, depthmap, new_intrinsics + + def _random_walk_sampling( + self, + scene_pairwise_covisibility, + num_of_samples, + max_retries=4, + use_bidirectional_covis=True, + ): + """ + Randomly samples S indices from an N x N covisibility matrix by forming adjacency edges such that the resulting subgraph (given by the indices) is connected. + If the current node has no new unvisited neighbors, backtracking occurs. 
+ Retries with different starting indices if the desired number of samples is not reached, excluding previously visited components. + + Args: + scene_pairwise_covisibility : np.ndarray (mmap) + N x N covisibility matrix for the scene, where N is the number of views in the scene. + num_of_samples : int + The desired number of nodes to sample (num_of_samples < N). + max_retries : int + The maximum number of retries with different starting indices. + use_bidirectional_covis : bool + Whether to compute bidirectional covisibility by averaging row and column values. + If False, uses only row access (faster for large memory-mapped arrays). + Defaults to True. + + Returns: + np.ndarray + An array of sampled indices forming a connected subgraph. + """ + excluded_nodes = set() + best_walk = [] # To keep track of the best walk found + for _ in range(max_retries): + visited = set() + walk = [] # List to store the random walk sampling order + stack = [] # Stack for backtracking + + # Choose a random starting index that is not in the excluded set + all_nodes = set(range(len(scene_pairwise_covisibility))) + available_nodes = list(all_nodes - excluded_nodes) + if not available_nodes: + break # No more nodes to try + start = self._rng.choice(available_nodes) + walk.append(start) + visited.add(start) + stack.append(start) + + # Continue until we have sampled S indices or all expandable nodes are exhausted + while len(walk) < num_of_samples and stack: + current = stack[-1] + # Get the pairwise covisibility for the current node + if use_bidirectional_covis: + # Use bidirectional covisibility (slower for large memory-mapped arrays) + pairwise_covisibility = ( + scene_pairwise_covisibility[current, :] + + scene_pairwise_covisibility[:, current].T + ) / 2 + else: + # Use only row access (faster for large memory-mapped arrays) + pairwise_covisibility = scene_pairwise_covisibility[current, :] + # Normalize the covisibility using self covisibility + pairwise_covisibility = pairwise_covisibility / ( + pairwise_covisibility[current] + 1e-8 + ) + # Assign overlap score of zero to self-pairs + pairwise_covisibility[current] = 0 + # Threshold the covisibility to get adjacency list for the current node + adjacency_list_for_current = ( + pairwise_covisibility > self.covisibility_thres + ).astype(int) + adjacency_list_for_current = np.flatnonzero(adjacency_list_for_current) + # Get all unvisited neighbors + candidates = [ + idx for idx in adjacency_list_for_current if idx not in visited + ] # Remove visited nodes + if candidates: + # Randomly select one of the unvisited overlapping neighbors + next_node = self._rng.choice(candidates) + walk.append(next_node) + visited.add(next_node) + stack.append(next_node) + else: + # If no unvisited neighbor is available, backtrack + stack.pop() + + # Update the best walk if the current walk is larger + if len(walk) > len(best_walk): + best_walk = walk + + # If we have enough samples, return the result + if len(walk) >= num_of_samples: + return np.array(walk) + + # Add all visited nodes to the excluded set + excluded_nodes.update(visited) + + # If all retries are exhausted and we still don't have enough samples, return the best walk found + return np.array(best_walk) + + def _sample_view_indices( + self, + num_views_to_sample, + num_views_in_scene, + scene_pairwise_covisibility, + use_bidirectional_covis=True, + ): + """ + Sample view indices from a scene based on the adjacency list and the number of views to sample. + + Args: + num_views_to_sample (int): Number of views to sample. 
+ num_views_in_scene (int): Total number of views available in the scene. + scene_pairwise_covisibility (np.ndarray): N x N covisibility matrix for the scene, where N is the number of views in the scene. + use_bidirectional_covis (bool): Whether to compute bidirectional covisibility by averaging row and column values. + If False, uses only row access (faster for large memory-mapped arrays). + + Returns: + numpy.ndarray: Array of sampled view indices. + """ + if num_views_to_sample == num_views_in_scene: + # Select all views in the scene + view_indices = self._rng.permutation(num_views_in_scene) + elif num_views_to_sample > num_views_in_scene: + # Select all views in the scene and repeat them to get the desired number of views + view_indices = self._rng.choice( + num_views_in_scene, size=num_views_to_sample, replace=True + ) + else: + # Select a subset of single component connected views in the scene using random walk sampling + view_indices = self._random_walk_sampling( + scene_pairwise_covisibility, + num_views_to_sample, + use_bidirectional_covis=use_bidirectional_covis, + ) + # If the required num of views can't be obtained even with 4 retries, repeat existing indices to get the desired number of views + if len(view_indices) < num_views_to_sample: + view_indices = self._rng.choice( + view_indices, size=num_views_to_sample, replace=True + ) + + return view_indices + + def _getitem_fn(self, idx): + if isinstance(idx, tuple): + # The idx is a tuple if specifying the aspect-ratio or/and the number of views + if isinstance(self.num_views, int): + idx, ar_idx = idx + else: + idx, ar_idx, num_views_to_sample_idx = idx + else: + assert len(self._resolutions) == 1 + assert isinstance(self.num_views, int) + ar_idx = 0 + + # Setup the rng + if self.seed: # reseed for each _getitem_fn + # Leads to deterministic sampling where repeating self.seed and self._seed_offset yields the same multi-view set again + # Scenes will be repeated if size of dataset is artificially increased using "N @" or "N *" + # When scenes are repeated, self._seed_offset is increased to ensure new multi-view sets + # This is useful for evaluation if the number of dataset scenes is < N, yet we want unique multi-view sets each iter + self._rng = np.random.default_rng(seed=self.seed + self._seed_offset + idx) + elif not hasattr(self, "_rng"): + seed = torch.initial_seed() # this is different for each dataloader process + self._rng = np.random.default_rng(seed=seed) + + # Get the views for the given index and check that the number of views is correct + resolution = self._resolutions[ar_idx] + if isinstance(self.num_views, int): + num_views_to_sample = self.num_views + else: + num_views_to_sample = self.num_views[num_views_to_sample_idx] + views = self._get_views(idx, num_views_to_sample, resolution) + if isinstance(self.num_views, int): + assert len(views) == self.num_views + else: + assert len(views) in self.num_views + + for v, view in enumerate(views): + # Store the index and other metadata + view["idx"] = (idx, ar_idx, v) + view["is_metric_scale"] = self.is_metric_scale + view["is_synthetic"] = self.is_synthetic + + # Check the depth, intrinsics, and pose data (also other data if present) + assert "camera_intrinsics" in view + assert "camera_pose" in view + assert np.isfinite(view["camera_pose"]).all(), ( + f"NaN or infinite values in camera pose for view {view_name(view)}" + ) + assert np.isfinite(view["depthmap"]).all(), ( + f"NaN or infinite values in depthmap for view {view_name(view)}" + ) + assert "valid_mask" not in 
view + assert "pts3d" not in view, ( + f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" + ) + if "prior_depth_z" in view: + assert np.isfinite(view["prior_depth_z"]).all(), ( + f"NaN or infinite values in prior_depth_z for view {view_name(view)}" + ) + if "non_ambiguous_mask" in view: + assert np.isfinite(view["non_ambiguous_mask"]).all(), ( + f"NaN or infinite values in non_ambiguous_mask for view {view_name(view)}" + ) + + # Encode the image + width, height = view["img"].size + view["true_shape"] = np.int32((height, width)) + view["img"] = self.transform(view["img"]) + view["data_norm_type"] = self.data_norm_type + + # Compute the pointmaps, raymap and depth along ray + ( + pts3d, + valid_mask, + ray_origins_world, + ray_directions_world, + depth_along_ray, + ray_directions_cam, + pts3d_cam, + ) = get_absolute_pointmaps_and_rays_info(**view) + view["pts3d"] = pts3d + view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1) + view["depth_along_ray"] = depth_along_ray + view["ray_directions_cam"] = ray_directions_cam + view["pts3d_cam"] = pts3d_cam + + # Compute the prior depth along ray if present + if "prior_depth_z" in view: + prior_pts3d, _ = depthmap_to_camera_coordinates( + view["prior_depth_z"], view["camera_intrinsics"] + ) + view["prior_depth_along_ray"] = np.linalg.norm(prior_pts3d, axis=-1) + view["prior_depth_along_ray"] = view["prior_depth_along_ray"][..., None] + del view["prior_depth_z"] + + # Convert ambiguous mask dtype to match valid mask dtype + if "non_ambiguous_mask" in view: + view["non_ambiguous_mask"] = view["non_ambiguous_mask"].astype( + view["valid_mask"].dtype + ) + else: + ambiguous_mask = view["depthmap"] < 0 + view["non_ambiguous_mask"] = ~ambiguous_mask + view["non_ambiguous_mask"] = view["non_ambiguous_mask"].astype( + view["valid_mask"].dtype + ) + + # Check all datatypes + for key, val in view.items(): + res, err_msg = is_good_type(val) + assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" + + # Check shapes + assert view["depthmap"].shape == view["img"].shape[1:] + assert view["depthmap"].shape == view["pts3d"].shape[:2] + assert view["depthmap"].shape == view["valid_mask"].shape + assert view["depthmap"].shape == view["depth_along_ray"].shape[:2] + assert view["depthmap"].shape == view["ray_directions_cam"].shape[:2] + assert view["depthmap"].shape == view["pts3d_cam"].shape[:2] + if "prior_depth_along_ray" in view: + assert view["depthmap"].shape == view["prior_depth_along_ray"].shape[:2] + if "non_ambiguous_mask" in view: + assert view["depthmap"].shape == view["non_ambiguous_mask"].shape + + # Expand the last dimension of the depthmap + view["depthmap"] = view["depthmap"][..., None] + + # Append RNG state to the views, this allows to check whether the RNG is in the same state each time + view["rng"] = int.from_bytes(self._rng.bytes(4), "big") + + # Compute and store the quaternions and translation for the camera poses + # Notation is (x, y, z, w) for quaternions + # This also ensures that the camera poses have a positive determinant (right-handed coordinate system) + view["camera_pose_quats"] = ( + Rotation.from_matrix(view["camera_pose"][:3, :3]) + .as_quat() + .astype(view["camera_pose"].dtype) + ) + view["camera_pose_trans"] = view["camera_pose"][:3, 3].astype( + view["camera_pose"].dtype + ) + + # Check the pointmaps, rays, depth along ray, and camera pose quaternions and translation to ensure they are finite + assert np.isfinite(view["pts3d"]).all(), 
( + f"NaN in pts3d for view {view_name(view)}" + ) + assert np.isfinite(view["valid_mask"]).all(), ( + f"NaN in valid_mask for view {view_name(view)}" + ) + assert np.isfinite(view["depth_along_ray"]).all(), ( + f"NaN in depth_along_ray for view {view_name(view)}" + ) + assert np.isfinite(view["ray_directions_cam"]).all(), ( + f"NaN in ray_directions_cam for view {view_name(view)}" + ) + assert np.isfinite(view["pts3d_cam"]).all(), ( + f"NaN in pts3d_cam for view {view_name(view)}" + ) + assert np.isfinite(view["camera_pose_quats"]).all(), ( + f"NaN in camera_pose_quats for view {view_name(view)}" + ) + assert np.isfinite(view["camera_pose_trans"]).all(), ( + f"NaN in camera_pose_trans for view {view_name(view)}" + ) + if "prior_depth_along_ray" in view: + assert np.isfinite(view["prior_depth_along_ray"]).all(), ( + f"NaN in prior_depth_along_ray for view {view_name(view)}" + ) + + return views + + def __getitem__(self, idx): + if self.max_num_retries == 0: + return self._getitem_fn(idx) + + num_retries = 0 + while num_retries <= self.max_num_retries: + try: + return self._getitem_fn(idx) + except Exception as e: + scene_idx = idx[0] if isinstance(idx, tuple) else idx + print( + f"Error in {type(self).__name__}.__getitem__ for scene_idx={scene_idx}: {e}" + ) + + if num_retries >= self.max_num_retries: + print( + f"Max retries ({self.max_num_retries}) reached, raising the exception" + ) + raise e + + # Retry with a different scene index + num_retries += 1 + if isinstance(idx, tuple): + # The scene index is the first element of the tuple + idx_list = list(idx) + idx_list[0] = np.random.randint(0, len(self)) + idx = tuple(idx_list) + else: + # The scene index is idx + idx = np.random.randint(0, len(self)) + scene_idx = idx[0] if isinstance(idx, tuple) else idx + print( + f"Retrying with scene_idx={scene_idx} ({num_retries} of {self.max_num_retries})" + ) + + +def is_good_type(v): + """ + Check if a value has an acceptable data type for processing in the dataset. + + Args: + v: The value to check. + + Returns: + tuple: A tuple containing: + - bool: True if the type is acceptable, False otherwise. + - str or None: Error message if the type is not acceptable, None otherwise. + """ + if isinstance(v, (str, int, tuple)): + return True, None + if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8): + return False, f"bad {v.dtype=}" + return True, None + + +def view_name(view, batch_index=None): + """ + Generate a string identifier for a view based on its dataset, label, and instance. + + Args: + view (dict): Dictionary containing view information with 'dataset', 'label', and 'instance' keys. + batch_index (int, optional): Index to select from batched data. Defaults to None. + + Returns: + str: A formatted string in the form "dataset/label/instance". + """ + + def sel(x): + return x[batch_index] if batch_index not in (None, slice(None)) else x + + db = sel(view["dataset"]) + label = sel(view["label"]) + instance = sel(view["instance"]) + return f"{db}/{label}/{instance}" diff --git a/mapanything/datasets/base/batched_sampler.py b/mapanything/datasets/base/batched_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..15c77f8ec39ba1382048a9bffa94e8b2ed919ff5 --- /dev/null +++ b/mapanything/datasets/base/batched_sampler.py @@ -0,0 +1,431 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +""" +Utilities for random sampling under a single or multiple constraints + +References: DUSt3R +""" + +import numpy as np +import torch + + +def round_by(total, multiple, up=False): + """ + Round a number to the nearest multiple of another number. + + Args: + total (int): The number to round + multiple (int): The multiple to round to + up (bool, optional): Whether to round up. Defaults to False. + + Returns: + int: The rounded number + """ + if up: + total = total + multiple - 1 + return (total // multiple) * multiple + + +class BatchedRandomSampler: + """ + Random sampling under a constraint: each sample in the batch has the same feature, + which is chosen randomly from a known pool of 'features' for each batch. + + For instance, the 'feature' could be the image aspect-ratio. + + The index returned is a tuple (sample_idx, feat_idx). + This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. + """ + + def __init__( + self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True + ): + """ + Args: + dataset: Dataset to sample from + batch_size: Number of samples per batch + pool_size: Integer representing the size of feature pool + world_size: Number of distributed processes + rank: Rank of the current process + drop_last: Whether to drop the last incomplete batch + """ + self.batch_size = batch_size + self.pool_size = pool_size + + self.len_dataset = N = len(dataset) + self.total_size = round_by(N, batch_size * world_size) if drop_last else N + assert world_size == 1 or drop_last, ( + "must drop the last batch in distributed mode" + ) + + # Distributed sampler + self.world_size = world_size + self.rank = rank + self.epoch = None + + def __len__(self): + """ + Get the length of the sampler. + + Returns: + int: The number of samples in the sampler for the current process + """ + return self.total_size // self.world_size + + def set_epoch(self, epoch): + """ + Set the epoch for this sampler. + + This should be called before each epoch to ensure proper shuffling of the data. + + Args: + epoch (int): The current epoch number + """ + self.epoch = epoch + + def __iter__(self): + """ + Iterator over the indices. + + This method generates random indices for each batch, ensuring that all samples + within a batch have the same feature index for the given feature pool. 
+ + Yields: + tuple: A tuple containing (sample_idx, feat_idx) + """ + # Prepare RNG + if self.epoch is None: + assert self.world_size == 1 and self.rank == 0, ( + "use set_epoch() if distributed mode is used" + ) + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.epoch + 777 + rng = np.random.default_rng(seed=seed) + + # Random indices (will restart from 0 if not drop_last) + sample_idxs = np.arange(self.total_size) + rng.shuffle(sample_idxs) + + # Random feat_idxs (same across each batch) + n_batches = (self.total_size + self.batch_size - 1) // self.batch_size + feat_idxs = rng.integers(self.pool_size, size=n_batches) + feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size)) + feat_idxs = feat_idxs.ravel()[: self.total_size] + + # Put them together + idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2) + + # Distributed sampler: we select a subset of batches + # Make sure the slice for each node is aligned with batch_size + size_per_proc = self.batch_size * ( + (self.total_size + self.world_size * self.batch_size - 1) + // (self.world_size * self.batch_size) + ) + idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc] + + yield from (tuple(idx) for idx in idxs) + + +class BatchedMultiFeatureRandomSampler: + """ + Random sampling under multiple constraints: each sample in the batch has the same features, + which are chosen randomly from known pools of 'features' for each batch. + + For instance, the 'features' could be the image aspect-ratio and scene type. + + The index returned is a tuple (sample_idx, feat_idx_1, feat_idx_2, ...). + This sampler ensures that each series of `batch_size` indices has the same feature indices. + """ + + def __init__( + self, dataset, batch_size, pool_sizes, world_size=1, rank=0, drop_last=True + ): + """ + Args: + dataset: Dataset to sample from + batch_size: Number of samples per batch + pool_sizes: List of integers representing the size of each feature pool + world_size: Number of distributed processes + rank: Rank of the current process + drop_last: Whether to drop the last incomplete batch + """ + self.batch_size = batch_size + self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes] + + self.len_dataset = N = len(dataset) + self.total_size = round_by(N, batch_size * world_size) if drop_last else N + assert world_size == 1 or drop_last, ( + "must drop the last batch in distributed mode" + ) + + # Distributed sampler + self.world_size = world_size + self.rank = rank + self.epoch = None + + def __len__(self): + """ + Get the length of the sampler. + + Returns: + int: The number of samples in the sampler for the current process + """ + return self.total_size // self.world_size + + def set_epoch(self, epoch): + """ + Set the epoch for this sampler. + + This should be called before each epoch to ensure proper shuffling of the data. + + Args: + epoch (int): The current epoch number + """ + self.epoch = epoch + + def __iter__(self): + """ + Iterator over the indices. + + This method generates random indices for each batch, ensuring that all samples + within a batch have the same feature indices for multiple features. + + Yields: + tuple: A tuple containing (sample_idx, feat_idx_1, feat_idx_2, ...) 
+ """ + # Prepare RNG + if self.epoch is None: + assert self.world_size == 1 and self.rank == 0, ( + "use set_epoch() if distributed mode is used" + ) + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.epoch + 777 + rng = np.random.default_rng(seed=seed) + + # Random indices (will restart from 0 if not drop_last) + sample_idxs = np.arange(self.total_size) + rng.shuffle(sample_idxs) + + # Random feat_idxs (same across each batch) + n_batches = (self.total_size + self.batch_size - 1) // self.batch_size + + # Generate feature indices for each feature pool + all_feat_idxs = [] + for pool_size in self.pool_sizes: + feat_idxs = rng.integers(pool_size, size=n_batches) + feat_idxs = np.broadcast_to( + feat_idxs[:, None], (n_batches, self.batch_size) + ) + feat_idxs = feat_idxs.ravel()[: self.total_size] + all_feat_idxs.append(feat_idxs) + + # Put them together + idxs = np.column_stack( + [sample_idxs] + all_feat_idxs + ) # shape = (total_size, 1 + len(pool_sizes)) + + # Distributed sampler: we select a subset of batches + # Make sure the slice for each node is aligned with batch_size + size_per_proc = self.batch_size * ( + (self.total_size + self.world_size * self.batch_size - 1) + // (self.world_size * self.batch_size) + ) + idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc] + + yield from (tuple(idx) for idx in idxs) + + +class DynamicBatchedMultiFeatureRandomSampler: + """ + Random sampling under multiple constraints with dynamic batch size: + each sample in the batch has the same features, which are chosen randomly + from known pools of 'features' for each batch. + + The batch size is dynamically determined based on a specified feature index, + using a direct mapping from feature values to batch sizes. + + For instance, if one of the features is the number of images in a multi-view set, + you can specify different batch sizes for different numbers of images to optimize + GPU memory usage. This is achieved by using the feature_to_batch_size_map parameter + to directly specify what batch size to use for each feature value. + + The returned index is a list of tuples [(sample_idx, feat_idx_1, feat_idx_2, ...), ...]. + """ + + def __init__( + self, + dataset, + pool_sizes, + scaling_feature_idx=0, + feature_to_batch_size_map=None, + world_size=1, + rank=0, + drop_last=True, + ): + """ + Args: + dataset: Dataset to sample from + pool_sizes: List of integers representing the size of each feature pool + scaling_feature_idx: Index of the feature to use for determining batch size (0-based index into pool_sizes) + feature_to_batch_size_map: Optional function or dict that maps feature values directly to batch sizes. + For example, if the feature represents number of views, this maps number of views + to appropriate batch size that can fit in GPU memory. + If None, uses a default batch size of 1 for all feature values. 
+ world_size: Number of distributed processes + rank: Rank of the current process + drop_last: Whether to drop the last incomplete batch + """ + self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes] + self.scaling_feature_idx = scaling_feature_idx + + # Ensure scaling_feature_idx is valid + if scaling_feature_idx < 0 or scaling_feature_idx >= len(self.pool_sizes): + raise ValueError( + f"scaling_feature_idx must be between 0 and {len(self.pool_sizes) - 1}" + ) + + # Set up mapping from feature values to batch sizes + self.feature_to_batch_size_map = feature_to_batch_size_map + if self.feature_to_batch_size_map is None: + # Default: batch size of 1 for all feature values + self.feature_to_batch_size_map = { + i: 1 for i in range(self.pool_sizes[scaling_feature_idx]) + } + + self.len_dataset = N = len(dataset) + + # We don't know the exact batch size yet, so we use a large number for total_size + # This will be adjusted during iteration + self.total_size = N + + # Distributed sampler + self.world_size = world_size + self.rank = rank + self.epoch = None + self.drop_last = drop_last + + def __len__(self): + """ + Get the approximate length of the sampler. + + Since batch size varies, this is an estimate based on the largest batch size + in the mapping, which provides a lower bound on the number of batches. + + Returns: + int: The estimated minimum number of samples in the sampler for the current process + """ + # Find the largest batch size in the mapping + if callable(self.feature_to_batch_size_map): + # If it's a function, sample some values to find the maximum + batch_sizes = [ + self.feature_to_batch_size_map(i) + for i in range(self.pool_sizes[self.scaling_feature_idx]) + ] + max_batch_size = max(batch_sizes) + else: + # If it's a dict or similar, find the maximum directly + max_batch_size = max(self.feature_to_batch_size_map.values()) + + # Ensure minimum batch size of 1 + max_batch_size = max(1, max_batch_size) + + # Estimate total batches using the largest batch size + # This gives a lower bound on the number of batches + total_batches = self.total_size // max_batch_size + if not self.drop_last and self.total_size % max_batch_size > 0: + total_batches += 1 + + # Distribute among processes + return total_batches // self.world_size + + def set_epoch(self, epoch): + """ + Set the epoch for this sampler. + + This should be called before each epoch to ensure proper shuffling of the data. + + Args: + epoch (int): The current epoch number + """ + self.epoch = epoch + + def __iter__(self): + """ + Iterator over the indices with dynamic batch sizes. + + This method generates random indices for each batch, ensuring that all samples + within a batch have the same feature indices for multiple features. + The batch size is determined directly from the feature_to_batch_size_map. + + The iterator enforces the length returned by __len__() by stopping after + exactly that many batches have been yielded for this process. + + Yields: + list of tuples: A batch of tuples, each containing (sample_idx, feat_idx_1, feat_idx_2, ...) 
+ """ + # Prepare RNG + if self.epoch is None: + assert self.world_size == 1 and self.rank == 0, ( + "use set_epoch() if distributed mode is used" + ) + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.epoch + 777 + rng = np.random.default_rng(seed=seed) + + # Random indices for the entire dataset + sample_idxs = np.arange(self.total_size) + rng.shuffle(sample_idxs) + + # Get the target number of batches for this process (enforce strict length) + target_batches_for_process = len(self) + batches_yielded_for_process = 0 + + # Process indices in batches with dynamic sizing + idx = 0 + batch_idx = 0 # Track batch index for even distribution + while idx < len(sample_idxs) and ( + batches_yielded_for_process < target_batches_for_process + ): + # Randomly select feature indices for this batch + feat_idxs = [rng.integers(pool_size) for pool_size in self.pool_sizes] + + # Get the scaling feature value + scaling_feat = feat_idxs[self.scaling_feature_idx] + + # Get the batch size directly from the mapping + if callable(self.feature_to_batch_size_map): + batch_size = self.feature_to_batch_size_map(scaling_feat) + else: + batch_size = self.feature_to_batch_size_map.get(scaling_feat, 1) + + # Ensure minimum batch size of 1 + batch_size = max(1, batch_size) + + # Ensure we don't go beyond available samples + remaining = len(sample_idxs) - idx + if remaining < batch_size: + if self.drop_last: + break + batch_size = remaining + + # Create batch with consistent feature indices + batch = [] + for i in range(batch_size): + if idx + i < len(sample_idxs): + sample_idx = sample_idxs[idx + i] + batch.append(tuple([sample_idx] + feat_idxs)) + + # Distribute batches among processes in round-robin fashion + if len(batch) > 0 and (batch_idx % self.world_size == self.rank): + yield batch + batches_yielded_for_process += 1 + + batch_idx += 1 # Increment batch index + idx += batch_size diff --git a/mapanything/datasets/base/easy_dataset.py b/mapanything/datasets/base/easy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..55a9268e53f8a6aa456dbfe73d8ed213e639bc3a --- /dev/null +++ b/mapanything/datasets/base/easy_dataset.py @@ -0,0 +1,478 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Base dataset class that enables easy resizing and combining + +References: DUSt3R +""" + +import numpy as np + +from mapanything.datasets.base.batched_sampler import ( + BatchedMultiFeatureRandomSampler, + DynamicBatchedMultiFeatureRandomSampler, +) + + +class EasyDataset: + """ + Dataset that can be easily resized and combined. + + Examples: + --------- + 2 * dataset ==> Duplicate each element 2x + + 10 @ dataset ==> Set the size to 10 (random sampling, duplicates if necessary) + + Dataset1 + Dataset2 ==> Concatenate datasets + """ + + def __add__(self, other): + """ + Concatenate this dataset with another dataset. + + Args: + other (EasyDataset): Another dataset to concatenate with this one + + Returns: + CatDataset: A new dataset that is the concatenation of this dataset and the other + """ + return CatDataset([self, other]) + + def __rmul__(self, factor): + """ + Multiply the dataset by a factor, duplicating each element. 
+ + Args: + factor (int): Number of times to duplicate each element + + Returns: + MulDataset: A new dataset with each element duplicated 'factor' times + """ + return MulDataset(factor, self) + + def __rmatmul__(self, factor): + """ + Resize the dataset to a specific size using random sampling. + + Args: + factor (int): The new size of the dataset + + Returns: + ResizedDataset: A new dataset with the specified size + """ + return ResizedDataset(factor, self) + + def set_epoch(self, epoch): + """ + Set the current epoch for all constituent datasets. + + Args: + epoch (int): The current epoch number + """ + pass # nothing to do by default + + def make_sampler( + self, + batch_size=None, + shuffle=True, + world_size=1, + rank=0, + drop_last=True, + max_num_of_images_per_gpu=None, + use_dynamic_sampler=True, + ): + """ + Create a sampler for this dataset. + + Args: + batch_size (int, optional): Number of samples per batch (used for non-dynamic sampler). Defaults to None. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True. + world_size (int, optional): Number of distributed processes. Defaults to 1. + rank (int, optional): Rank of the current process. Defaults to 0. + drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True. + max_num_of_images_per_gpu (int, optional): Maximum number of images per GPU for dynamic batching. Defaults to None. + use_dynamic_sampler (bool, optional): Whether to use the dynamic sampler. Defaults to True. + + Returns: + DynamicBatchedMultiFeatureRandomSampler or BatchedMultiFeatureRandomSampler: A sampler for this dataset + + Raises: + NotImplementedError: If shuffle is False + ValueError: If num_views has an invalid type or required parameters are missing + """ + if not (shuffle): + raise NotImplementedError() # cannot deal yet + + if isinstance(self.num_views, int): + num_of_aspect_ratios = len(self._resolutions) + feature_pool_sizes = [num_of_aspect_ratios] + scaling_feature_idx = 0 # Use aspect ratio as scaling feature + elif isinstance(self.num_views, list): + num_of_aspect_ratios = len(self._resolutions) + num_of_num_views = len(self.num_views) + feature_pool_sizes = [num_of_aspect_ratios, num_of_num_views] + scaling_feature_idx = 1 # Use num_views as scaling feature + else: + raise ValueError( + f"Bad type for {self.num_views=}, should be int or list of ints" + ) + + if use_dynamic_sampler: + if max_num_of_images_per_gpu is None: + raise ValueError( + "max_num_of_images_per_gpu must be provided when using dynamic sampler" + ) + + # Create feature-to-batch-size mapping + if isinstance(self.num_views, list): + # Map num_views_idx to batch size: max(1, max_num_of_images_per_gpu // (num_views_idx + dataset.num_views_min)) + feature_to_batch_size_map = {} + for num_views_idx, num_views in enumerate(self.num_views): + batch_size_for_multi_view_sets = max( + 1, max_num_of_images_per_gpu // num_views + ) + feature_to_batch_size_map[num_views_idx] = ( + batch_size_for_multi_view_sets + ) + else: + # For fixed num_views, use a simple mapping + feature_to_batch_size_map = { + 0: max(1, max_num_of_images_per_gpu // self.num_views) + } + + return DynamicBatchedMultiFeatureRandomSampler( + self, + pool_sizes=feature_pool_sizes, + scaling_feature_idx=scaling_feature_idx, + feature_to_batch_size_map=feature_to_batch_size_map, + world_size=world_size, + rank=rank, + drop_last=drop_last, + ) + else: + if batch_size is None: + raise ValueError( + "batch_size must be provided when not using dynamic sampler" + ) + + return 
BatchedMultiFeatureRandomSampler( + self, + batch_size, + feature_pool_sizes, + world_size=world_size, + rank=rank, + drop_last=drop_last, + ) + + +class MulDataset(EasyDataset): + """Artificially augmenting the size of a dataset.""" + + multiplicator: int + + def __init__(self, multiplicator, dataset): + """ + Initialize a dataset that artificially augments the size of another dataset. + + Args: + multiplicator (int): Factor by which to multiply the dataset size + dataset (EasyDataset): The dataset to augment + """ + assert isinstance(multiplicator, int) and multiplicator > 0 + self.multiplicator = multiplicator + self.dataset = dataset + + def __len__(self): + """ + Get the length of the dataset. + + Returns: + int: The number of samples in the dataset + """ + return self.multiplicator * len(self.dataset) + + def __repr__(self): + """ + Get a string representation of the dataset. + + Returns: + str: String representation showing the multiplication factor and the original dataset + """ + return f"{self.multiplicator}*{repr(self.dataset)}" + + def __getitem__(self, idx): + """ + Get an item from the dataset. + + Args: + idx: Index or tuple of indices to retrieve + + Returns: + The item at the specified index from the original dataset + """ + if isinstance(idx, tuple): + other = idx[1:] + idx = idx[0] + new_idx = (idx // self.multiplicator, *other) + return self.dataset[new_idx] + else: + return self.dataset[idx // self.multiplicator] + + @property + def _resolutions(self): + """ + Get the resolutions of the dataset. + + Returns: + The resolutions from the original dataset + """ + return self.dataset._resolutions + + @property + def num_views(self): + """ + Get the number of views used for the dataset. + + Returns: + int or list: The number of views parameter from the original dataset + """ + return self.dataset.num_views + + +class ResizedDataset(EasyDataset): + """Artificially changing the size of a dataset.""" + + new_size: int + + def __init__(self, new_size, dataset): + """ + Initialize a dataset with an artificially changed size. + + Args: + new_size (int): The new size of the dataset + dataset (EasyDataset): The original dataset + """ + assert isinstance(new_size, int) and new_size > 0 + self.new_size = new_size + self.dataset = dataset + + def __len__(self): + """ + Get the length of the dataset. + + Returns: + int: The new size of the dataset + """ + return self.new_size + + def __repr__(self): + """ + Get a string representation of the dataset. + + Returns: + str: String representation showing the new size and the original dataset + """ + size_str = str(self.new_size) + for i in range((len(size_str) - 1) // 3): + sep = -4 * i - 3 + size_str = size_str[:sep] + "_" + size_str[sep:] + return f"{size_str} @ {repr(self.dataset)}" + + def set_epoch(self, epoch): + """ + Set the current epoch and generate a new random mapping of indices. + + This method must be called before using __getitem__. 
+ + Args: + epoch (int): The current epoch number + """ + # This random shuffle only depends on the epoch + rng = np.random.default_rng(seed=epoch + 777) + + # Shuffle all indices + perm = rng.permutation(len(self.dataset)) + + # Calculate how many repetitions we need + num_repetitions = 1 + (len(self) - 1) // len(self.dataset) + + # Rotary extension until target size is met + shuffled_idxs = np.concatenate([perm] * num_repetitions) + self._idxs_mapping = shuffled_idxs[: self.new_size] + + # Generate the seed offset for each repetition + # This is needed to ensure we see unique samples when we repeat a scene + seed_offset_per_repetition = [ + np.full(len(self.dataset), i) for i in range(num_repetitions) + ] + seed_offset_idxs = np.concatenate(seed_offset_per_repetition) + self._idxs_seed_offset = seed_offset_idxs[: self.new_size] + + assert len(self._idxs_mapping) == self.new_size + assert len(self._idxs_seed_offset) == self.new_size + + def __getitem__(self, idx): + """ + Get an item from the dataset. + + Args: + idx: Index or tuple of indices to retrieve + + Returns: + The item at the mapped index from the original dataset + + Raises: + AssertionError: If set_epoch has not been called + """ + assert hasattr(self, "_idxs_mapping"), ( + "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()" + ) + if isinstance(idx, tuple): + other = idx[1:] + idx = idx[0] + self.dataset._set_seed_offset(self._idxs_seed_offset[idx]) + new_idx = (self._idxs_mapping[idx], *other) + return self.dataset[new_idx] + else: + self.dataset._set_seed_offset(self._idxs_seed_offset[idx]) + return self.dataset[self._idxs_mapping[idx]] + + @property + def _resolutions(self): + """ + Get the resolutions of the dataset. + + Returns: + The resolutions from the original dataset + """ + return self.dataset._resolutions + + @property + def num_views(self): + """ + Get the number of views used for the dataset. + + Returns: + int or list: The number of views parameter from the original dataset + """ + return self.dataset.num_views + + +class CatDataset(EasyDataset): + """Concatenation of several datasets""" + + def __init__(self, datasets): + """ + Initialize a dataset that is a concatenation of several datasets. + + Args: + datasets (list): List of EasyDataset instances to concatenate + """ + for dataset in datasets: + assert isinstance(dataset, EasyDataset) + self.datasets = datasets + self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets]) + + def __len__(self): + """ + Get the length of the concatenated dataset. + + Returns: + int: Total number of samples across all datasets + """ + return self._cum_sizes[-1] + + def __repr__(self): + """ + Get a string representation of the concatenated dataset. + + Returns: + str: String representation showing all concatenated datasets joined by '+' + """ + # Remove uselessly long transform + return " + ".join( + repr(dataset).replace( + ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))", + "", + ) + for dataset in self.datasets + ) + + def set_epoch(self, epoch): + """ + Set the current epoch for all constituent datasets. + + Args: + epoch (int): The current epoch number + """ + for dataset in self.datasets: + dataset.set_epoch(epoch) + + def __getitem__(self, idx): + """ + Get an item from the concatenated dataset. 
+ + Args: + idx: Index or tuple of indices to retrieve + + Returns: + The item at the specified index from the appropriate constituent dataset + + Raises: + IndexError: If the index is out of range + """ + other = None + if isinstance(idx, tuple): + other = idx[1:] + idx = idx[0] + + if not (0 <= idx < len(self)): + raise IndexError() + + db_idx = np.searchsorted(self._cum_sizes, idx, "right") + dataset = self.datasets[db_idx] + new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0) + + if other is not None: + new_idx = (new_idx, *other) + return dataset[new_idx] + + @property + def _resolutions(self): + """ + Get the resolutions of the dataset. + + Returns: + The resolutions from the first dataset (all datasets must have the same resolutions) + + Raises: + AssertionError: If datasets have different resolutions + """ + resolutions = self.datasets[0]._resolutions + for dataset in self.datasets[1:]: + assert tuple(dataset._resolutions) == tuple(resolutions), ( + "All datasets must have the same resolutions" + ) + return resolutions + + @property + def num_views(self): + """ + Get the number of views used for the dataset. + + Returns: + int or list: The number of views parameter from the first dataset + + Raises: + AssertionError: If datasets have different num_views + """ + num_views = self.datasets[0].num_views + for dataset in self.datasets[1:]: + assert dataset.num_views == num_views, ( + "All datasets must have the same num_views and variable_num_views parameters" + ) + return num_views diff --git a/mapanything/datasets/utils/__init__.py b/mapanything/datasets/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/datasets/utils/data_splits.py b/mapanything/datasets/utils/data_splits.py new file mode 100644 index 0000000000000000000000000000000000000000..86983f4ff49396f686e60daac192e829d000de23 --- /dev/null +++ b/mapanything/datasets/utils/data_splits.py @@ -0,0 +1,1734 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Modules containing dataset split information +""" + + +class BlendedMVSSplits: + """ + This class contains the information about the BlendedMVS dataset splits. 
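+
+ Scene names are 24-character hex strings: the selection logic in __init__ below treats the leading 8 hex characters as seqh and the trailing 16 as seql, and assigns roughly 90% of scenes to train and 10% to val via seql % 10.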
+ """ + + def __init__(self): + """ + The splits are generated using the following logic: + # Get all seqls and seqhs using self.blendedmvs_info.all_sequences + all_sequences = self.blendedmvs_info.all_sequences + all_seqls = [int(seq[8:], 16) for seq in all_sequences] + all_seqhs = [int(seq[:8], 16) for seq in all_sequences] + # Split the seqls (& corresponding seqhs) using the DUSt3R train/val split logic + if split is None: + selection = slice(None) + elif split in ["train", "overfit"]: + # select 90% of all scenes + selection = [(seql % 10) > 0 for seql in all_seqls] + elif split == "val": + # select 10% of all scenes + selection = [(seql % 10) == 0 for seql in all_seqls] + else: + raise ValueError(f"Unknown split {split}, must be None, train, val or overfit") + # Filter sequences based on the selection + selected_seqls = [seql for seql, sel in zip(all_seqls, selection) if sel] + selected_seqhs = [seqh for seqh, sel in zip(all_seqhs, selection) if sel] + # Put them back into sequence names f"{seqh:08x}{seql:016x}" + sequence_names = [f"{seqh:08x}{seql:016x}" for seqh, seql in zip(selected_seqhs, selected_seqls)] + # Remove invalid sequence names which don't exist in self.blendedmvs_info.sequences + valid_sequences = set(self.blendedmvs_info.sequences) + valid_sequence_names = [name for name in sequence_names if name in valid_sequences] + """ + # All the 502 sequences in the dataset (totals to 115k images) + self.all_scenes = [ + "000000000000000000000000", + "00000000000000000000000a", + "00000000000000000000000b", + "00000000000000000000000c", + "00000000000000000000000d", + "00000000000000000000000e", + "00000000000000000000000f", + "000000000000000000000001", + "00000000000000000000001a", + "00000000000000000000001b", + "00000000000000000000001d", + "000000000000000000000002", + "000000000000000000000003", + "000000000000000000000004", + "000000000000000000000005", + "5a2a95f032a1c655cfe3de62", + "5a2af22b32a1c655cfe46013", + "5a2ba6de32a1c655cfe51b79", + "5a3b9731e24cd76dad1a5f1b", + "5a3ca9cb270f0e3f14d0eddb", + "5a3cb4e4270f0e3f14d12f43", + "5a03e732454a8a7ec672776c", + "5a3f4aba5889373fbbc5d3b5", + "5a4a38dad38c8a075495b5d2", + "5a5a1e48d62c7a12d5d00e47", + "5a6b1c418d100c2f8fdc4411", + "5a6feeb54a7fbc3f874f9db7", + "5a7cb1d6fe5c0d6fb53e64fb", + "5a7d3db14989e929563eb153", + "5a8aa0fab18050187cbe060e", + "5a9e5df65baeef72b4a021cd", + "5a48ba95c7dab83a7d7b44ed", + "5a48c4e9c7dab83a7d7b5cc7", + "5a48d4b2c7dab83a7d7b9851", + "5a69c47d0d5d0a7f3b2e9752", + "5a77b46b318efe6c6736e68a", + "5a355c271b63f53d5970f362", + "5a489fb1c7dab83a7d7b1070", + "5a533e8034d7582116e34209", + "5a562fc7425d0f5186314725", + "5a572fd9fc597b0478a81d14", + "5a588a8193ac3d233f77fbca", + "5a618c72784780334bc1972d", + "5a752d42acc41e2423f17674", + "5a969eea91dfc339a9a3ad2c", + "5a8315f624b8e938486e0bd8", + "5a57542f333d180827dfc132", + "5a0271884e62597cdee0d0eb", + "5a6400933d809f1d8200af15", + "5a6464143d809f1d8208c43c", + "5a563183425d0f5186314855", + "5aa0f9d7a9efce63548c69a1", + "5aa0f478a9efce63548c1cb4", + "5aa7db90bfdd572271e95246", + "5aa235f64a17b335eeaf9609", + "5aa515e613d42d091d29d300", + "5aa1196ea9efce63548ed649", + "5aaadd4cbc13235570d178a7", + "5ab6af12ac4291329b1072ab", + "5ab7e00aac4291329b15864d", + "5ab8b8e029f5351f7f2ccf59", + "5ab74bf2ac4291329b11e879", + "5ab85f1dac4291329b17cb50", + "5ab8713ba3799a1d138bd69a", + "5abc2506b53b042ead637d86", + "5acc7459a7853c4b5ebbef59", + "5acf8ca0f3d8a750097e4b15", + "5adc6bd52430a05ecb2ffb85", + "5ae2e9c5fe405c5076abc6b2", + "5af02e904c8216544b4ab5a2", 
+ "5af28cea59bc705737003253", + "5af545d0559359053d25dcf5", + "5afacb69ab00705d0cefdd5b", + "5b2c67b5e0878c381608b8d8", + "5b3b2b9e8d46a939f933fdc0", + "5b3b353d8d46a939f93524b9", + "5b6e716d67b396324c2d77cb", + "5b6eff8b67b396324c5b2672", + "5b7a3890fc8fcf6781e2593a", + "5b21e18c58e2823a67a10dd8", + "5b60fa0c764f146feef84df0", + "5b69cc0cb44b61786eb959bf", + "5b78e57afc8fcf6781d0c3ba", + "5b192eb2170cf166458ff886", + "5b558a928bbfb62204e77ba2", + "5b864d850d072a699b32f4ae", + "5b908d3dc6ab78485f3d24a9", + "5b950c71608de421b1e7318f", + "5b4933abf2b5f44e95de482a", + "5b08286b2775267d5b0634ba", + "5b37189a35304b6f75e7583e", + "5b271079e0878c3816dacca4", + "5b22269758e2823a67a3bd03", + "5b62647143840965efc0dbde", + "5ba19a8a360c7c30c1c169df", + "5ba75d79d76ffa2c86cf2f05", + "5bb7a08aea1cfa39f1a947ab", + "5bb8a49aea1cfa39f1aa7f75", + "5bbb6eb2ea1cfa39f1af7e0c", + "5bc5f0e896b66a2cd8f9bd36", + "5bccd6beca24970bce448134", + "5bce7ac9ca24970bce4934b6", + "5bcf979a6d5f586b95c258cd", + "5bd43b4ba6b28b1ee86b92dd", + "5be3a5fb8cfdd56947f6b67c", + "5be3ae47f44e235bdbbc9771", + "5be4ab93870d330ff2dce134", + "5be47bf9b18881428d8fbc1d", + "5be883a4f98cee15019d5b83", + "5bea87f4abd34c35e1860ab5", + "5beb6e66abd34c35e18e66b9", + "5bf3a82cd439231948877aed", + "5bf7d63575c26f32dbf7413b", + "5bf17c0fd439231948355385", + "5bf26cbbd43923194854b270", + "5bf03590d4392319481971dc", + "5bf18642c50e6f7f8bdbd492", + "5bf21799d43923194842c001", + "5bfc9d5aec61ca1dd69132a2", + "5bfd0f32ec61ca1dd69dc77b", + "5bfe5ae0fe0ea555e6a969ca", + "5bff3c5cfe0ea555e6bcbf3a", + "5c0d13b795da9479e12e2ee9", + "5c1af2e2bee9a723c963d019", + "5c1b1500bee9a723c96c3e78", + "5c1dbf200843bc542d8ef8c4", + "5c1f33f1d33e1f2e4aa6dda4", + "5c2b3ed5e611832e8aed46bf", + "5c20ca3a0843bc542d94e3e2", + "5c062d84a96e33018ff6f0a6", + "5c189f2326173c3a09ed7ef3", + "5c1892f726173c3a09ea9aeb", + "5c34300a73a8df509add216d", + "5c34529873a8df509ae57b58", + "000000000000000000000006", + "000000000000000000000007", + "000000000000000000000008", + "000000000000000000000009", + "000000000000000000000010", + "000000000000000000000011", + "000000000000000000000012", + "000000000000000000000015", + "000000000000000000000016", + "000000000000000000000017", + "000000000000000000000018", + "000000000000000000000019", + "56d73ba74bd29b8c35abade2", + "56f34064e296120e10484dc4", + "57a4a7bb6b9272286e26dc18", + "57f8d9bbe73f6760f10e916a", + "58a0a2f33d0b4542479a11b1", + "58a0dd1a3d0b4542479a28f3", + "58a1a7914a4d262a170b1101", + "58a1bc804a4d262a170b2f01", + "58a1d9d14a4d262a170b58fe", + "58a01dea38486e3c98475871", + "58a1f5d74a4d262a170b65fc", + "58a2a09e156b87103d3d668c", + "58a2d9c3156b87103d3da90f", + "58a3ccb0156b87103d3e4332", + "58a3f2f8156b87103d3e5838", + "58a3f6c0156b87103d3e5971", + "58a3fc95156b87103d3e5d9b", + "58a07ce53d0b45424799fdde", + "58a07f233d0b45424799ffe7", + "58a44df2156b87103d3ee239", + "58a164f73d0b4542479a7a8e", + "58a0365e38486e3c984783eb", + "58a439cf156b87103d3ec885", + "58a464aa156b87103d3eec04", + "58a4452f156b87103d3ed55b", + "58a160983d0b4542479a7347", + "58a186444a4d262a170ae3ae", + "58a285424a4d262a170baf3e", + "58a41819156b87103d3e92a5", + "58a44463156b87103d3ed45e", + "58a47552156b87103d3f00a4", + "58c4bb4f4a69c55606122be4", + "58c6451e4a69c556061894f1", + "58ca7014affdfd07c70a95ce", + "58cf4771d0f5fb221defe6da", + "58d36897f387231e6c929903", + "58eaf1513353456af3a1682a", + "58f7f7299f5b5647873cb110", + "58f73e7c9f5b56478738929f", + "59a8f851597729752c31e7e0", + "59a452bf9b460239aa5d1c72", + "59a9619a825418241fb88191", + 
"59acd2f4b891807f439c8992", + "59bf97fe7e7b31545da34439", + "59c1c3e2fd6e3d4ead9f1013", + "59d2657f82ca7774b1ec081d", + "59da1fb88a126011d0394ae9", + "59e75a2ca9e91f2c5526005d", + "59e864b2a9e91f2c5529325f", + "59ecfd02e225f6492d20fcc9", + "59f37f74b45be2233001ba18", + "59f70ab1e5c5d366af29bf3e", + "59f87d0bfa6280566fb38c9a", + "59f363a8b45be22330016cad", + "564a27b26d07883f460d8ab0", + "565fb1dead14d4154dae2b94", + "567a0fb0a825d2fb79ac9a20", + "569b92eb826bcba945ca002b", + "576fefa017ce5a16397e87fd", + "584a7333fe3cb463906c9fe6", + "584aa8e9fe3cb463906cc7d0", + "584ad76bfe3cb463906ce6dc", + "584af003fe3cb463906d0e9b", + "584b9a747072670e72bfc49d", + "584b671f7072670e72bfaaf8", + "584b81747072670e72bfbbfd", + "584ba35f7072670e72bfca4d", + "584ba5977072670e72bfcc2d", + "584bc53c7072670e72bfe85f", + "584bc3997072670e72bfe58d", + "584bc4407072670e72bfe665", + "584bd5587072670e72bffe39", + "584bdadf7072670e72c0005c", + "584be5ed7072670e72c007b3", + "584c9ad27072670e72c060c5", + "584c9cc67072670e72c063a1", + "584c58b77072670e72c03990", + "584cea557072670e72c07fb4", + "584d19d47072670e72c0c6c0", + "584dfe467072670e72c1665a", + "584e875c7072670e72c1ec94", + "584e05667072670e72c17167", + "584f94e87072670e72c2d3f7", + "584fdffd7072670e72c32dc7", + "584fe07f7072670e72c32e59", + "585a2a71b338a62ad50138dc", + "585a206ab338a62ad501298f", + "585a217cb338a62ad5012b38", + "585b34afb338a62ad501e836", + "585bb25fc49c8507c3ce7812", + "585bbe55c49c8507c3ce81cd", + "585d6c8a2a57cc11d4920a1e", + "585e54c72a57cc11d492f71a", + "585e34302a57cc11d492be30", + "585ee0632a57cc11d4933608", + "585f9661712e2761468dabca", + "585ffe9a712e2761468df643", + "586a37ec9d1b5e34c28184fc", + "586a515a9d1b5e34c281b431", + "586a94939d1b5e34c2823b5d", + "586abc689d1b5e34c2826360", + "586b0e219d1b5e34c2828862", + "586b3db89d1b5e34c282cd52", + "586b4c459d1b5e34c282e66d", + "586b7d7d9d1b5e34c283359e", + "586b8f149d1b5e34c283497c", + "586b8f629d1b5e34c28349d6", + "586c4c4d9d1b5e34c28391a1", + "586c5b5b9d1b5e34c2839a5b", + "586c9fdf9d1b5e34c283b657", + "586c48329d1b5e34c2838e80", + "586caab99d1b5e34c283c213", + "586cd0779d1b5e34c28403a7", + "586d6d249d1b5e34c284b80e", + "586d8a029d1b5e34c284c948", + "586d55af9d1b5e34c284a999", + "586d07869d1b5e34c2842e5b", + "586d27489d1b5e34c28453af", + "586df9849d1b5e34c28506de", + "586e279c9d1b5e34c2852180", + "587bc5ec2366dd5d06e262c1", + "587c1abf2366dd5d06e28901", + "587c03f12366dd5d06e27722", + "587c19da2366dd5d06e2877b", + "587c31b92366dd5d06e2a9dc", + "587c87d02366dd5d06e2f989", + "587c97a52366dd5d06e30a96", + "587c45192366dd5d06e2c0eb", + "587cec702366dd5d06e37862", + "587cef0a2366dd5d06e379e3", + "587db5872366dd5d06e3e0af", + "587e2b1d2366dd5d06e41af0", + "587e2ea62366dd5d06e41f2e", + "587e5cb52366dd5d06e4486e", + "587eb1822366dd5d06e45f29", + "587f365d2366dd5d06e4906e", + "588a9c5fec4d5a1c088ec350", + "588a34cfec4d5a1c088ea8d1", + "588ab5bdec4d5a1c088ed60f", + "588aff9d90414422fbe7885a", + "588b20d290414422fbe79f40", + "588c08d590414422fbe8200b", + "588c203d90414422fbe8319e", + "588c989a90414422fbe86d96", + "588ca09d90414422fbe871a1", + "588cce2190414422fbe88520", + "588cd5ef90414422fbe8875c", + "588cf0ad90414422fbe8a20f", + "588e0d8c90414422fbe8f8b2", + "588e01c490414422fbe8ee2a", + "588e35e690414422fbe90a53", + "588f017e90414422fbe9b74b", + "588f095190414422fbe9c1ee", + "589aca717dc3d323d55671c4", + "589af2c97dc3d323d55691e8", + "589b49ea7dc3d323d556d9b4", + "589b04287dc3d323d556a185", + "589bf6a57dc3d323d55743ab", + "589c3c497dc3d323d5578468", + "589c3c577dc3d323d5578480", + 
"589c300f7dc3d323d5577926", + "589c24527dc3d323d5577126", + "589c35457dc3d323d5577d8d", + "589ca6a6b896147a1b73aff7", + "589d1e1fb896147a1b73ee5b", + "589d5c58b896147a1b742256", + "589d95538fa2cf375df3317b", + "589df0ffb504a864ad63521a", + "589ea316b504a864ad639a2b", + "589ec97cb504a864ad63adc3", + "589f214338486e3c9846f123", + "589fdfe738486e3c984736cf", + "590c2d70336bb52a190be886", + "590f91851225725be9e25d4e", + "591a467a6109e14d4f09b776", + "591cf3033162411cf9047f37", + "591ea44850991c70dc99a207", + "599aa591d5b41f366fed0d58", + "5643df56138263b51db1b5f3", + "5644bdac138263b51db9f669", + "5692a4c2adafac1f14201821", + "5850d4f97072670e72c425d6", + "5854c405804be105852330fe", + "5855a4fc804be1058523bd75", + "5856ac15804be105852419d8", + "5856ae8b804be10585241bae", + "5856b460804be10585242059", + "5857aa5ab338a62ad5ff4dbe", + "5857acf8b338a62ad5ff5107", + "5858db6cb338a62ad500103b", + "5858dbcab338a62ad5001081", + "5859d84fb338a62ad500e5cf", + "5861d8ea712e2761468f3cb3", + "5863edf8712e27614690cce0", + "5864a935712e2761469111b4", + "5864b076712e27614691197e", + "5864da88712e276146913d8b", + "5865f4a8712e27614691e39b", + "5867a434833dfe3f7b88edaf", + "5868cd15833dfe3f7b89bfa3", + "5880b3692366dd5d06e5d534", + "5880e3422366dd5d06e5ff8e", + "5880f0ef2366dd5d06e6166e", + "5881d2bfb6844814c136a119", + "5881f11d8ce2c2754d0714c3", + "5881fee18ce2c2754d0723f8", + "5882cda2b116682b4adebd25", + "5882d58fb116682b4adec7db", + "5884c256932ba84fbed70bf5", + "5884cc13932ba84fbed71ec4", + "5885bc5296fa095e0671a7f0", + "5886d14cb791366d617a362c", + "5888becfc02346100f4b0b21", + "5888e408c02346100f4b1a29", + "5889da66ec4d5a1c088e5187", + "5889e344ec4d5a1c088e59be", + "5889e754ec4d5a1c088e60ba", + "5890c16b90414422fbeb0262", + "5891d8ae9a8c0314c5cd30ab", + "5891d0479a8c0314c5cd2abd", + "5891ecf19a8c0314c5cd490a", + "5892c0cd9a8c0314c5cdc977", + "5894ab309a8c0314c5cee57d", + "5895a6a89a8c0314c5cfca7c", + "5895b8c29a8c0314c5cfd051", + "5895d38f9a8c0314c5cfe50c", + "5895f2329a8c0314c5d00117", + "5896bb989a8c0314c5d086b6", + "5896ebf39a8c0314c5d0a8c4", + "5898b1bac9dccc22987b7f74", + "5898b6ffc9dccc22987b8a03", + "5898b31cc9dccc22987b82ec", + "5898bbaac9dccc22987b8eba", + "5899cfa6b76d7a3780a4cb64", + "5899e5dcb76d7a3780a4ecc1", + "5947b62af1b45630bd0c2a02", + "57102be2877e1421026358af", + "57153d4031bb9900425bde85", + "57177cd7fb8d93461afc4527", + "58497cdf97b73e0b090c4273", + "58500b007072670e72c35588", + "58510bf97072670e72c46ddf", + "58522bd56789802282f2ecb3", + "58524a2e0e7012308944bcf3", + "58524a080e7012308944bcbf", + "58524c1d0e7012308944bfda", + "58524f170e7012308944c200", + "58529a4e0e70123089454c6f", + "58551bdf804be1058523556d", + "58568c9a804be10585240b03", + "58574b35804be105852455fd", + "58577c60b338a62ad5ff1564", + "58592d69b338a62ad5007a74", + "58598db2b338a62ad500bc38", + "58625f42712e2761468fb44c", + "58651bcc712e2761469166dc", + "58660e79712e27614691fe3d", + "58669aad712e27614692834c", + "58669c02712e27614692851a", + "58676c36833dfe3f7b88b7f2", + "58678b2d833dfe3f7b88e244", + "58790c82ce911104a3467c88", + "58800b0b2366dd5d06e5312d", + "58805eac2366dd5d06e56460", + "58806e422366dd5d06e57bb6", + "58831d060db9bf59bf8ab98b", + "58851ebb932ba84fbed7abad", + "58871dc3b791366d617a55ff", + "58873cabb791366d617a65a7", + "58873d44b791366d617a65dd", + "58888b3dc02346100f4af665", + "58897f62c02346100f4b8ee6", + "58933bac9a8c0314c5ce3508", + "58938e6d9a8c0314c5ce726f", + "58951cb49a8c0314c5cf4d5e", + "58970fd09a8c0314c5d0e383", + "58977ef09a8c0314c5d17b26", + "59056e6760bb961de55f3501", + 
"59071f2e5a6dbd3af4130f98", + "59102c811225725be9e64149", + "59338e76772c3e6384afbb15", + "59350ca084b7f26bf5ce6eb8", + "59397e493a87372f2c9e882b", + "59521e0b9096412211c2aa9d", + "59817e4a1bd4b175e7038d19", + "567884f58d2828b95e3c8eba", + "585559d9804be10585238ddf", + "585834cdb338a62ad5ffab4d", + "586082d8712e2761468e2877", + "586133c2712e2761468ecfe3", + "586281d2712e2761468fcaa2", + "586316e5712e276146903c4d", + "586326ad712e276146904571", + "586375c9712e276146907429", + "586389c9712e276146908da6", + "586496fa712e2761469108e7", + "586669c6712e27614692597a", + "586913a49d1b5e34c2808b02", + "586922da9d1b5e34c2809ff3", + "588185d8dfb7a15588a114a3", + "588305ed0db9bf59bf8a8c80", + "588315c60db9bf59bf8aa928", + "588332ee0db9bf59bf8ae9c3", + "588457b8932ba84fbed69942", + "588519d5932ba84fbed7a04a", + "588824d1b791366d617adeef", + "588857f6c02346100f4ac09f", + "589145ef90414422fbeb2e08", + "589433fa9a8c0314c5ce9656", + "589765d39a8c0314c5d16b12", + "5851165f7072670e72c4860d", + "5859341ab338a62ad500848d", + "5862388b712e2761468f84aa", + "5863915b712e276146909135", + "5866445b712e27614692383e", + "5866500d712e2761469240fd", + "5867785a833dfe3f7b88c764", + "5867969c833dfe3f7b88e8bc", + "5868040c833dfe3f7b8934f7", + "5880675a2366dd5d06e570ca", + "5882372c8ce2c2754d076af0", + "5883535e932ba84fbed5ad07", + "5888358cb791366d617af69d", + "5890330d90414422fbeaa0cb", + "5897076e9a8c0314c5d0d31b", + "5940564ec2d9527ab869f7e2", + "5947719bf1b45630bd096665", + "5948194ff1b45630bd0f47e3", + "5950206a41b158666ac50506", + "5983012d1bd4b175e70c985a", + "58586810b338a62ad5ffc20c", + "58592046b338a62ad5006b33", + "58592854b338a62ad500750a", + "58596531b338a62ad500aace", + "58818685dfb7a15588a11626", + "58829563f42b1d3ee3ec835f", + "58894345c02346100f4b51ca", + "585289980e7012308945276a", + "585369770e7012308945c709", + "585373640e7012308945cab9", + "588230658ce2c2754d076728", + "589388059a8c0314c5ce718b", + "595979485ec6a95e86a58c8d", + "5841206219d291325678ca90", + "58563650804be1058523da55", + "58564084804be1058523e116", + "58636467712e27614690661f", + "58647495712e27614690f36d", + "58654563712e276146918643", + "58664251712e276146923738", + "588084032366dd5d06e59e82", + "588159582366dd5d06e66877", + "5890279190414422fbea9734", + "5890523090414422fbeab3f0", + "5890641690414422fbeabbe7", + "585203546789802282f2aaf5", + ] + + # Final sequences to be used after filtering (some of the sequences have incorrect/low quality depth) + # Generally water bodies like lakes have incorrect depth + # Filtered out sequences: + # "5692a4c2adafac1f14201821" # Incorrect Depth + # "5864a935712e2761469111b4" # Noisy Depth and artifacts near horizon + # "59f87d0bfa6280566fb38c9a" # Object-centric, noise with background and sometimes in front of object + # "58a44463156b87103d3ed45e" # Very noisy depth in background + # "5c2b3ed5e611832e8aed46bf" # Depth occluded by artifacts + # "5bf03590d4392319481971dc" # Depth occluded by artifacts + # "00000000000000000000001a" # Largely incomplete depth + # "00000000000000000000000c" # Imprecise depth for buildings + # "000000000000000000000000" # Incorrect depth for planar terrain + self.scenes = [ + "00000000000000000000000a", + "00000000000000000000000b", + "00000000000000000000000d", + "00000000000000000000000e", + "00000000000000000000000f", + "000000000000000000000001", + "00000000000000000000001b", + "00000000000000000000001d", + "000000000000000000000002", + "000000000000000000000003", + "000000000000000000000004", + "000000000000000000000005", + "5a2a95f032a1c655cfe3de62", + 
"5a2af22b32a1c655cfe46013", + "5a2ba6de32a1c655cfe51b79", + "5a3b9731e24cd76dad1a5f1b", + "5a3ca9cb270f0e3f14d0eddb", + "5a3cb4e4270f0e3f14d12f43", + "5a03e732454a8a7ec672776c", + "5a3f4aba5889373fbbc5d3b5", + "5a4a38dad38c8a075495b5d2", + "5a5a1e48d62c7a12d5d00e47", + "5a6b1c418d100c2f8fdc4411", + "5a6feeb54a7fbc3f874f9db7", + "5a7cb1d6fe5c0d6fb53e64fb", + "5a7d3db14989e929563eb153", + "5a8aa0fab18050187cbe060e", + "5a9e5df65baeef72b4a021cd", + "5a48ba95c7dab83a7d7b44ed", + "5a48c4e9c7dab83a7d7b5cc7", + "5a48d4b2c7dab83a7d7b9851", + "5a69c47d0d5d0a7f3b2e9752", + "5a77b46b318efe6c6736e68a", + "5a355c271b63f53d5970f362", + "5a489fb1c7dab83a7d7b1070", + "5a533e8034d7582116e34209", + "5a562fc7425d0f5186314725", + "5a572fd9fc597b0478a81d14", + "5a588a8193ac3d233f77fbca", + "5a618c72784780334bc1972d", + "5a752d42acc41e2423f17674", + "5a969eea91dfc339a9a3ad2c", + "5a8315f624b8e938486e0bd8", + "5a57542f333d180827dfc132", + "5a0271884e62597cdee0d0eb", + "5a6400933d809f1d8200af15", + "5a6464143d809f1d8208c43c", + "5a563183425d0f5186314855", + "5aa0f9d7a9efce63548c69a1", + "5aa0f478a9efce63548c1cb4", + "5aa7db90bfdd572271e95246", + "5aa235f64a17b335eeaf9609", + "5aa515e613d42d091d29d300", + "5aa1196ea9efce63548ed649", + "5aaadd4cbc13235570d178a7", + "5ab6af12ac4291329b1072ab", + "5ab7e00aac4291329b15864d", + "5ab8b8e029f5351f7f2ccf59", + "5ab74bf2ac4291329b11e879", + "5ab85f1dac4291329b17cb50", + "5ab8713ba3799a1d138bd69a", + "5abc2506b53b042ead637d86", + "5acc7459a7853c4b5ebbef59", + "5acf8ca0f3d8a750097e4b15", + "5adc6bd52430a05ecb2ffb85", + "5ae2e9c5fe405c5076abc6b2", + "5af02e904c8216544b4ab5a2", + "5af28cea59bc705737003253", + "5af545d0559359053d25dcf5", + "5afacb69ab00705d0cefdd5b", + "5b2c67b5e0878c381608b8d8", + "5b3b2b9e8d46a939f933fdc0", + "5b3b353d8d46a939f93524b9", + "5b6e716d67b396324c2d77cb", + "5b6eff8b67b396324c5b2672", + "5b7a3890fc8fcf6781e2593a", + "5b21e18c58e2823a67a10dd8", + "5b60fa0c764f146feef84df0", + "5b69cc0cb44b61786eb959bf", + "5b78e57afc8fcf6781d0c3ba", + "5b192eb2170cf166458ff886", + "5b558a928bbfb62204e77ba2", + "5b864d850d072a699b32f4ae", + "5b908d3dc6ab78485f3d24a9", + "5b950c71608de421b1e7318f", + "5b4933abf2b5f44e95de482a", + "5b08286b2775267d5b0634ba", + "5b37189a35304b6f75e7583e", + "5b271079e0878c3816dacca4", + "5b22269758e2823a67a3bd03", + "5b62647143840965efc0dbde", + "5ba19a8a360c7c30c1c169df", + "5ba75d79d76ffa2c86cf2f05", + "5bb7a08aea1cfa39f1a947ab", + "5bb8a49aea1cfa39f1aa7f75", + "5bbb6eb2ea1cfa39f1af7e0c", + "5bc5f0e896b66a2cd8f9bd36", + "5bccd6beca24970bce448134", + "5bce7ac9ca24970bce4934b6", + "5bcf979a6d5f586b95c258cd", + "5bd43b4ba6b28b1ee86b92dd", + "5be3a5fb8cfdd56947f6b67c", + "5be3ae47f44e235bdbbc9771", + "5be4ab93870d330ff2dce134", + "5be47bf9b18881428d8fbc1d", + "5be883a4f98cee15019d5b83", + "5bea87f4abd34c35e1860ab5", + "5beb6e66abd34c35e18e66b9", + "5bf3a82cd439231948877aed", + "5bf7d63575c26f32dbf7413b", + "5bf17c0fd439231948355385", + "5bf26cbbd43923194854b270", + "5bf18642c50e6f7f8bdbd492", + "5bf21799d43923194842c001", + "5bfc9d5aec61ca1dd69132a2", + "5bfd0f32ec61ca1dd69dc77b", + "5bfe5ae0fe0ea555e6a969ca", + "5bff3c5cfe0ea555e6bcbf3a", + "5c0d13b795da9479e12e2ee9", + "5c1af2e2bee9a723c963d019", + "5c1b1500bee9a723c96c3e78", + "5c1dbf200843bc542d8ef8c4", + "5c1f33f1d33e1f2e4aa6dda4", + "5c20ca3a0843bc542d94e3e2", + "5c062d84a96e33018ff6f0a6", + "5c189f2326173c3a09ed7ef3", + "5c1892f726173c3a09ea9aeb", + "5c34300a73a8df509add216d", + "5c34529873a8df509ae57b58", + "000000000000000000000006", + "000000000000000000000007", + 
"000000000000000000000008", + "000000000000000000000009", + "000000000000000000000010", + "000000000000000000000011", + "000000000000000000000012", + "000000000000000000000015", + "000000000000000000000016", + "000000000000000000000017", + "000000000000000000000018", + "000000000000000000000019", + "56d73ba74bd29b8c35abade2", + "56f34064e296120e10484dc4", + "57a4a7bb6b9272286e26dc18", + "57f8d9bbe73f6760f10e916a", + "58a0a2f33d0b4542479a11b1", + "58a0dd1a3d0b4542479a28f3", + "58a1a7914a4d262a170b1101", + "58a1bc804a4d262a170b2f01", + "58a1d9d14a4d262a170b58fe", + "58a01dea38486e3c98475871", + "58a1f5d74a4d262a170b65fc", + "58a2a09e156b87103d3d668c", + "58a2d9c3156b87103d3da90f", + "58a3ccb0156b87103d3e4332", + "58a3f2f8156b87103d3e5838", + "58a3f6c0156b87103d3e5971", + "58a3fc95156b87103d3e5d9b", + "58a07ce53d0b45424799fdde", + "58a07f233d0b45424799ffe7", + "58a44df2156b87103d3ee239", + "58a164f73d0b4542479a7a8e", + "58a0365e38486e3c984783eb", + "58a439cf156b87103d3ec885", + "58a464aa156b87103d3eec04", + "58a4452f156b87103d3ed55b", + "58a160983d0b4542479a7347", + "58a186444a4d262a170ae3ae", + "58a285424a4d262a170baf3e", + "58a41819156b87103d3e92a5", + "58a47552156b87103d3f00a4", + "58c4bb4f4a69c55606122be4", + "58c6451e4a69c556061894f1", + "58ca7014affdfd07c70a95ce", + "58cf4771d0f5fb221defe6da", + "58d36897f387231e6c929903", + "58eaf1513353456af3a1682a", + "58f7f7299f5b5647873cb110", + "58f73e7c9f5b56478738929f", + "59a8f851597729752c31e7e0", + "59a452bf9b460239aa5d1c72", + "59a9619a825418241fb88191", + "59acd2f4b891807f439c8992", + "59bf97fe7e7b31545da34439", + "59c1c3e2fd6e3d4ead9f1013", + "59d2657f82ca7774b1ec081d", + "59da1fb88a126011d0394ae9", + "59e75a2ca9e91f2c5526005d", + "59e864b2a9e91f2c5529325f", + "59ecfd02e225f6492d20fcc9", + "59f37f74b45be2233001ba18", + "59f70ab1e5c5d366af29bf3e", + "59f363a8b45be22330016cad", + "564a27b26d07883f460d8ab0", + "565fb1dead14d4154dae2b94", + "567a0fb0a825d2fb79ac9a20", + "569b92eb826bcba945ca002b", + "576fefa017ce5a16397e87fd", + "584a7333fe3cb463906c9fe6", + "584aa8e9fe3cb463906cc7d0", + "584ad76bfe3cb463906ce6dc", + "584af003fe3cb463906d0e9b", + "584b9a747072670e72bfc49d", + "584b671f7072670e72bfaaf8", + "584b81747072670e72bfbbfd", + "584ba35f7072670e72bfca4d", + "584ba5977072670e72bfcc2d", + "584bc53c7072670e72bfe85f", + "584bc3997072670e72bfe58d", + "584bc4407072670e72bfe665", + "584bd5587072670e72bffe39", + "584bdadf7072670e72c0005c", + "584be5ed7072670e72c007b3", + "584c9ad27072670e72c060c5", + "584c9cc67072670e72c063a1", + "584c58b77072670e72c03990", + "584cea557072670e72c07fb4", + "584d19d47072670e72c0c6c0", + "584dfe467072670e72c1665a", + "584e875c7072670e72c1ec94", + "584e05667072670e72c17167", + "584f94e87072670e72c2d3f7", + "584fdffd7072670e72c32dc7", + "584fe07f7072670e72c32e59", + "585a2a71b338a62ad50138dc", + "585a206ab338a62ad501298f", + "585a217cb338a62ad5012b38", + "585b34afb338a62ad501e836", + "585bb25fc49c8507c3ce7812", + "585bbe55c49c8507c3ce81cd", + "585d6c8a2a57cc11d4920a1e", + "585e54c72a57cc11d492f71a", + "585e34302a57cc11d492be30", + "585ee0632a57cc11d4933608", + "585f9661712e2761468dabca", + "585ffe9a712e2761468df643", + "586a37ec9d1b5e34c28184fc", + "586a515a9d1b5e34c281b431", + "586a94939d1b5e34c2823b5d", + "586abc689d1b5e34c2826360", + "586b0e219d1b5e34c2828862", + "586b3db89d1b5e34c282cd52", + "586b4c459d1b5e34c282e66d", + "586b7d7d9d1b5e34c283359e", + "586b8f149d1b5e34c283497c", + "586b8f629d1b5e34c28349d6", + "586c4c4d9d1b5e34c28391a1", + "586c5b5b9d1b5e34c2839a5b", + "586c9fdf9d1b5e34c283b657", + 
"586c48329d1b5e34c2838e80", + "586caab99d1b5e34c283c213", + "586cd0779d1b5e34c28403a7", + "586d6d249d1b5e34c284b80e", + "586d8a029d1b5e34c284c948", + "586d55af9d1b5e34c284a999", + "586d07869d1b5e34c2842e5b", + "586d27489d1b5e34c28453af", + "586df9849d1b5e34c28506de", + "586e279c9d1b5e34c2852180", + "587bc5ec2366dd5d06e262c1", + "587c1abf2366dd5d06e28901", + "587c03f12366dd5d06e27722", + "587c19da2366dd5d06e2877b", + "587c31b92366dd5d06e2a9dc", + "587c87d02366dd5d06e2f989", + "587c97a52366dd5d06e30a96", + "587c45192366dd5d06e2c0eb", + "587cec702366dd5d06e37862", + "587cef0a2366dd5d06e379e3", + "587db5872366dd5d06e3e0af", + "587e2b1d2366dd5d06e41af0", + "587e2ea62366dd5d06e41f2e", + "587e5cb52366dd5d06e4486e", + "587eb1822366dd5d06e45f29", + "587f365d2366dd5d06e4906e", + "588a9c5fec4d5a1c088ec350", + "588a34cfec4d5a1c088ea8d1", + "588ab5bdec4d5a1c088ed60f", + "588aff9d90414422fbe7885a", + "588b20d290414422fbe79f40", + "588c08d590414422fbe8200b", + "588c203d90414422fbe8319e", + "588c989a90414422fbe86d96", + "588ca09d90414422fbe871a1", + "588cce2190414422fbe88520", + "588cd5ef90414422fbe8875c", + "588cf0ad90414422fbe8a20f", + "588e0d8c90414422fbe8f8b2", + "588e01c490414422fbe8ee2a", + "588e35e690414422fbe90a53", + "588f017e90414422fbe9b74b", + "588f095190414422fbe9c1ee", + "589aca717dc3d323d55671c4", + "589af2c97dc3d323d55691e8", + "589b49ea7dc3d323d556d9b4", + "589b04287dc3d323d556a185", + "589bf6a57dc3d323d55743ab", + "589c3c497dc3d323d5578468", + "589c3c577dc3d323d5578480", + "589c300f7dc3d323d5577926", + "589c24527dc3d323d5577126", + "589c35457dc3d323d5577d8d", + "589ca6a6b896147a1b73aff7", + "589d1e1fb896147a1b73ee5b", + "589d5c58b896147a1b742256", + "589d95538fa2cf375df3317b", + "589df0ffb504a864ad63521a", + "589ea316b504a864ad639a2b", + "589ec97cb504a864ad63adc3", + "589f214338486e3c9846f123", + "589fdfe738486e3c984736cf", + "590c2d70336bb52a190be886", + "590f91851225725be9e25d4e", + "591a467a6109e14d4f09b776", + "591cf3033162411cf9047f37", + "591ea44850991c70dc99a207", + "599aa591d5b41f366fed0d58", + "5643df56138263b51db1b5f3", + "5644bdac138263b51db9f669", + "5850d4f97072670e72c425d6", + "5854c405804be105852330fe", + "5855a4fc804be1058523bd75", + "5856ac15804be105852419d8", + "5856ae8b804be10585241bae", + "5856b460804be10585242059", + "5857aa5ab338a62ad5ff4dbe", + "5857acf8b338a62ad5ff5107", + "5858db6cb338a62ad500103b", + "5858dbcab338a62ad5001081", + "5859d84fb338a62ad500e5cf", + "5861d8ea712e2761468f3cb3", + "5863edf8712e27614690cce0", + "5864b076712e27614691197e", + "5864da88712e276146913d8b", + "5865f4a8712e27614691e39b", + "5867a434833dfe3f7b88edaf", + "5868cd15833dfe3f7b89bfa3", + "5880b3692366dd5d06e5d534", + "5880e3422366dd5d06e5ff8e", + "5880f0ef2366dd5d06e6166e", + "5881d2bfb6844814c136a119", + "5881f11d8ce2c2754d0714c3", + "5881fee18ce2c2754d0723f8", + "5882cda2b116682b4adebd25", + "5882d58fb116682b4adec7db", + "5884c256932ba84fbed70bf5", + "5884cc13932ba84fbed71ec4", + "5885bc5296fa095e0671a7f0", + "5886d14cb791366d617a362c", + "5888becfc02346100f4b0b21", + "5888e408c02346100f4b1a29", + "5889da66ec4d5a1c088e5187", + "5889e344ec4d5a1c088e59be", + "5889e754ec4d5a1c088e60ba", + "5890c16b90414422fbeb0262", + "5891d8ae9a8c0314c5cd30ab", + "5891d0479a8c0314c5cd2abd", + "5891ecf19a8c0314c5cd490a", + "5892c0cd9a8c0314c5cdc977", + "5894ab309a8c0314c5cee57d", + "5895a6a89a8c0314c5cfca7c", + "5895b8c29a8c0314c5cfd051", + "5895d38f9a8c0314c5cfe50c", + "5895f2329a8c0314c5d00117", + "5896bb989a8c0314c5d086b6", + "5896ebf39a8c0314c5d0a8c4", + "5898b1bac9dccc22987b7f74", + 
"5898b6ffc9dccc22987b8a03", + "5898b31cc9dccc22987b82ec", + "5898bbaac9dccc22987b8eba", + "5899cfa6b76d7a3780a4cb64", + "5899e5dcb76d7a3780a4ecc1", + "5947b62af1b45630bd0c2a02", + "57102be2877e1421026358af", + "57153d4031bb9900425bde85", + "57177cd7fb8d93461afc4527", + "58497cdf97b73e0b090c4273", + "58500b007072670e72c35588", + "58510bf97072670e72c46ddf", + "58522bd56789802282f2ecb3", + "58524a2e0e7012308944bcf3", + "58524a080e7012308944bcbf", + "58524c1d0e7012308944bfda", + "58524f170e7012308944c200", + "58529a4e0e70123089454c6f", + "58551bdf804be1058523556d", + "58568c9a804be10585240b03", + "58574b35804be105852455fd", + "58577c60b338a62ad5ff1564", + "58592d69b338a62ad5007a74", + "58598db2b338a62ad500bc38", + "58625f42712e2761468fb44c", + "58651bcc712e2761469166dc", + "58660e79712e27614691fe3d", + "58669aad712e27614692834c", + "58669c02712e27614692851a", + "58676c36833dfe3f7b88b7f2", + "58678b2d833dfe3f7b88e244", + "58790c82ce911104a3467c88", + "58800b0b2366dd5d06e5312d", + "58805eac2366dd5d06e56460", + "58806e422366dd5d06e57bb6", + "58831d060db9bf59bf8ab98b", + "58851ebb932ba84fbed7abad", + "58871dc3b791366d617a55ff", + "58873cabb791366d617a65a7", + "58873d44b791366d617a65dd", + "58888b3dc02346100f4af665", + "58897f62c02346100f4b8ee6", + "58933bac9a8c0314c5ce3508", + "58938e6d9a8c0314c5ce726f", + "58951cb49a8c0314c5cf4d5e", + "58970fd09a8c0314c5d0e383", + "58977ef09a8c0314c5d17b26", + "59056e6760bb961de55f3501", + "59071f2e5a6dbd3af4130f98", + "59102c811225725be9e64149", + "59338e76772c3e6384afbb15", + "59350ca084b7f26bf5ce6eb8", + "59397e493a87372f2c9e882b", + "59521e0b9096412211c2aa9d", + "59817e4a1bd4b175e7038d19", + "567884f58d2828b95e3c8eba", + "585559d9804be10585238ddf", + "585834cdb338a62ad5ffab4d", + "586082d8712e2761468e2877", + "586133c2712e2761468ecfe3", + "586281d2712e2761468fcaa2", + "586316e5712e276146903c4d", + "586326ad712e276146904571", + "586375c9712e276146907429", + "586389c9712e276146908da6", + "586496fa712e2761469108e7", + "586669c6712e27614692597a", + "586913a49d1b5e34c2808b02", + "586922da9d1b5e34c2809ff3", + "588185d8dfb7a15588a114a3", + "588305ed0db9bf59bf8a8c80", + "588315c60db9bf59bf8aa928", + "588332ee0db9bf59bf8ae9c3", + "588457b8932ba84fbed69942", + "588519d5932ba84fbed7a04a", + "588824d1b791366d617adeef", + "588857f6c02346100f4ac09f", + "589145ef90414422fbeb2e08", + "589433fa9a8c0314c5ce9656", + "589765d39a8c0314c5d16b12", + "5851165f7072670e72c4860d", + "5859341ab338a62ad500848d", + "5862388b712e2761468f84aa", + "5863915b712e276146909135", + "5866445b712e27614692383e", + "5866500d712e2761469240fd", + "5867785a833dfe3f7b88c764", + "5867969c833dfe3f7b88e8bc", + "5868040c833dfe3f7b8934f7", + "5880675a2366dd5d06e570ca", + "5882372c8ce2c2754d076af0", + "5883535e932ba84fbed5ad07", + "5888358cb791366d617af69d", + "5890330d90414422fbeaa0cb", + "5897076e9a8c0314c5d0d31b", + "5940564ec2d9527ab869f7e2", + "5947719bf1b45630bd096665", + "5948194ff1b45630bd0f47e3", + "5950206a41b158666ac50506", + "5983012d1bd4b175e70c985a", + "58586810b338a62ad5ffc20c", + "58592046b338a62ad5006b33", + "58592854b338a62ad500750a", + "58596531b338a62ad500aace", + "58818685dfb7a15588a11626", + "58829563f42b1d3ee3ec835f", + "58894345c02346100f4b51ca", + "585289980e7012308945276a", + "585369770e7012308945c709", + "585373640e7012308945cab9", + "588230658ce2c2754d076728", + "589388059a8c0314c5ce718b", + "595979485ec6a95e86a58c8d", + "5841206219d291325678ca90", + "58563650804be1058523da55", + "58564084804be1058523e116", + "58636467712e27614690661f", + "58647495712e27614690f36d", + 
"58654563712e276146918643", + "58664251712e276146923738", + "588084032366dd5d06e59e82", + "588159582366dd5d06e66877", + "5890279190414422fbea9734", + "5890523090414422fbeab3f0", + "5890641690414422fbeabbe7", + "585203546789802282f2aaf5", + ] + + # Train set sequences after filtering + self.train_split_scenes = [ + "00000000000000000000000b", + "00000000000000000000000d", + "00000000000000000000000e", + "00000000000000000000000f", + "000000000000000000000001", + "00000000000000000000001b", + "00000000000000000000001d", + "000000000000000000000002", + "000000000000000000000003", + "000000000000000000000004", + "000000000000000000000005", + "5a2a95f032a1c655cfe3de62", + "5a2af22b32a1c655cfe46013", + "5a2ba6de32a1c655cfe51b79", + "5a3b9731e24cd76dad1a5f1b", + "5a3ca9cb270f0e3f14d0eddb", + "5a3cb4e4270f0e3f14d12f43", + "5a03e732454a8a7ec672776c", + "5a3f4aba5889373fbbc5d3b5", + "5a5a1e48d62c7a12d5d00e47", + "5a6b1c418d100c2f8fdc4411", + "5a6feeb54a7fbc3f874f9db7", + "5a7cb1d6fe5c0d6fb53e64fb", + "5a7d3db14989e929563eb153", + "5a8aa0fab18050187cbe060e", + "5a9e5df65baeef72b4a021cd", + "5a48ba95c7dab83a7d7b44ed", + "5a48c4e9c7dab83a7d7b5cc7", + "5a48d4b2c7dab83a7d7b9851", + "5a69c47d0d5d0a7f3b2e9752", + "5a77b46b318efe6c6736e68a", + "5a355c271b63f53d5970f362", + "5a533e8034d7582116e34209", + "5a562fc7425d0f5186314725", + "5a618c72784780334bc1972d", + "5a752d42acc41e2423f17674", + "5a969eea91dfc339a9a3ad2c", + "5a8315f624b8e938486e0bd8", + "5a57542f333d180827dfc132", + "5a0271884e62597cdee0d0eb", + "5a6400933d809f1d8200af15", + "5a6464143d809f1d8208c43c", + "5a563183425d0f5186314855", + "5aa0f9d7a9efce63548c69a1", + "5aa7db90bfdd572271e95246", + "5aa235f64a17b335eeaf9609", + "5aa515e613d42d091d29d300", + "5aa1196ea9efce63548ed649", + "5aaadd4cbc13235570d178a7", + "5ab6af12ac4291329b1072ab", + "5ab7e00aac4291329b15864d", + "5ab8b8e029f5351f7f2ccf59", + "5ab74bf2ac4291329b11e879", + "5ab85f1dac4291329b17cb50", + "5ab8713ba3799a1d138bd69a", + "5abc2506b53b042ead637d86", + "5acc7459a7853c4b5ebbef59", + "5acf8ca0f3d8a750097e4b15", + "5adc6bd52430a05ecb2ffb85", + "5af02e904c8216544b4ab5a2", + "5af28cea59bc705737003253", + "5af545d0559359053d25dcf5", + "5afacb69ab00705d0cefdd5b", + "5b3b2b9e8d46a939f933fdc0", + "5b3b353d8d46a939f93524b9", + "5b6e716d67b396324c2d77cb", + "5b6eff8b67b396324c5b2672", + "5b7a3890fc8fcf6781e2593a", + "5b60fa0c764f146feef84df0", + "5b69cc0cb44b61786eb959bf", + "5b78e57afc8fcf6781d0c3ba", + "5b192eb2170cf166458ff886", + "5b558a928bbfb62204e77ba2", + "5b908d3dc6ab78485f3d24a9", + "5b950c71608de421b1e7318f", + "5b08286b2775267d5b0634ba", + "5b271079e0878c3816dacca4", + "5b22269758e2823a67a3bd03", + "5b62647143840965efc0dbde", + "5ba19a8a360c7c30c1c169df", + "5ba75d79d76ffa2c86cf2f05", + "5bb7a08aea1cfa39f1a947ab", + "5bb8a49aea1cfa39f1aa7f75", + "5bbb6eb2ea1cfa39f1af7e0c", + "5bce7ac9ca24970bce4934b6", + "5bcf979a6d5f586b95c258cd", + "5bd43b4ba6b28b1ee86b92dd", + "5be3a5fb8cfdd56947f6b67c", + "5be3ae47f44e235bdbbc9771", + "5be4ab93870d330ff2dce134", + "5be47bf9b18881428d8fbc1d", + "5be883a4f98cee15019d5b83", + "5bea87f4abd34c35e1860ab5", + "5beb6e66abd34c35e18e66b9", + "5bf3a82cd439231948877aed", + "5bf7d63575c26f32dbf7413b", + "5bf17c0fd439231948355385", + "5bf21799d43923194842c001", + "5bfd0f32ec61ca1dd69dc77b", + "5bfe5ae0fe0ea555e6a969ca", + "5c0d13b795da9479e12e2ee9", + "5c1af2e2bee9a723c963d019", + "5c1b1500bee9a723c96c3e78", + "5c1dbf200843bc542d8ef8c4", + "5c20ca3a0843bc542d94e3e2", + "5c062d84a96e33018ff6f0a6", + "5c189f2326173c3a09ed7ef3", + "5c1892f726173c3a09ea9aeb", 
+ "5c34300a73a8df509add216d", + "000000000000000000000006", + "000000000000000000000007", + "000000000000000000000008", + "000000000000000000000009", + "000000000000000000000010", + "000000000000000000000011", + "000000000000000000000012", + "000000000000000000000015", + "000000000000000000000016", + "000000000000000000000017", + "000000000000000000000018", + "000000000000000000000019", + "56d73ba74bd29b8c35abade2", + "56f34064e296120e10484dc4", + "57a4a7bb6b9272286e26dc18", + "57f8d9bbe73f6760f10e916a", + "58a0a2f33d0b4542479a11b1", + "58a0dd1a3d0b4542479a28f3", + "58a1a7914a4d262a170b1101", + "58a1bc804a4d262a170b2f01", + "58a1d9d14a4d262a170b58fe", + "58a01dea38486e3c98475871", + "58a1f5d74a4d262a170b65fc", + "58a2a09e156b87103d3d668c", + "58a2d9c3156b87103d3da90f", + "58a3ccb0156b87103d3e4332", + "58a3f2f8156b87103d3e5838", + "58a3f6c0156b87103d3e5971", + "58a3fc95156b87103d3e5d9b", + "58a07ce53d0b45424799fdde", + "58a07f233d0b45424799ffe7", + "58a44df2156b87103d3ee239", + "58a164f73d0b4542479a7a8e", + "58a0365e38486e3c984783eb", + "58a439cf156b87103d3ec885", + "58a464aa156b87103d3eec04", + "58a4452f156b87103d3ed55b", + "58a160983d0b4542479a7347", + "58a285424a4d262a170baf3e", + "58a41819156b87103d3e92a5", + "58a47552156b87103d3f00a4", + "58c4bb4f4a69c55606122be4", + "58c6451e4a69c556061894f1", + "58ca7014affdfd07c70a95ce", + "58cf4771d0f5fb221defe6da", + "58d36897f387231e6c929903", + "58eaf1513353456af3a1682a", + "58f73e7c9f5b56478738929f", + "59a8f851597729752c31e7e0", + "59a452bf9b460239aa5d1c72", + "59a9619a825418241fb88191", + "59bf97fe7e7b31545da34439", + "59c1c3e2fd6e3d4ead9f1013", + "59d2657f82ca7774b1ec081d", + "59da1fb88a126011d0394ae9", + "59e75a2ca9e91f2c5526005d", + "59e864b2a9e91f2c5529325f", + "59ecfd02e225f6492d20fcc9", + "59f37f74b45be2233001ba18", + "59f70ab1e5c5d366af29bf3e", + "59f363a8b45be22330016cad", + "564a27b26d07883f460d8ab0", + "565fb1dead14d4154dae2b94", + "569b92eb826bcba945ca002b", + "576fefa017ce5a16397e87fd", + "584a7333fe3cb463906c9fe6", + "584aa8e9fe3cb463906cc7d0", + "584af003fe3cb463906d0e9b", + "584b9a747072670e72bfc49d", + "584b671f7072670e72bfaaf8", + "584b81747072670e72bfbbfd", + "584ba35f7072670e72bfca4d", + "584ba5977072670e72bfcc2d", + "584bc53c7072670e72bfe85f", + "584bc3997072670e72bfe58d", + "584bc4407072670e72bfe665", + "584bd5587072670e72bffe39", + "584bdadf7072670e72c0005c", + "584be5ed7072670e72c007b3", + "584c9ad27072670e72c060c5", + "584c9cc67072670e72c063a1", + "584cea557072670e72c07fb4", + "584d19d47072670e72c0c6c0", + "584dfe467072670e72c1665a", + "584e875c7072670e72c1ec94", + "584e05667072670e72c17167", + "584f94e87072670e72c2d3f7", + "584fdffd7072670e72c32dc7", + "584fe07f7072670e72c32e59", + "585a2a71b338a62ad50138dc", + "585a206ab338a62ad501298f", + "585a217cb338a62ad5012b38", + "585b34afb338a62ad501e836", + "585bb25fc49c8507c3ce7812", + "585bbe55c49c8507c3ce81cd", + "585d6c8a2a57cc11d4920a1e", + "585e54c72a57cc11d492f71a", + "585e34302a57cc11d492be30", + "585ee0632a57cc11d4933608", + "585f9661712e2761468dabca", + "585ffe9a712e2761468df643", + "586a37ec9d1b5e34c28184fc", + "586a515a9d1b5e34c281b431", + "586a94939d1b5e34c2823b5d", + "586abc689d1b5e34c2826360", + "586b0e219d1b5e34c2828862", + "586b3db89d1b5e34c282cd52", + "586b4c459d1b5e34c282e66d", + "586b7d7d9d1b5e34c283359e", + "586b8f149d1b5e34c283497c", + "586b8f629d1b5e34c28349d6", + "586c4c4d9d1b5e34c28391a1", + "586c5b5b9d1b5e34c2839a5b", + "586c9fdf9d1b5e34c283b657", + "586caab99d1b5e34c283c213", + "586cd0779d1b5e34c28403a7", + "586d6d249d1b5e34c284b80e", + 
"586d8a029d1b5e34c284c948", + "586d55af9d1b5e34c284a999", + "586d07869d1b5e34c2842e5b", + "586d27489d1b5e34c28453af", + "586e279c9d1b5e34c2852180", + "587bc5ec2366dd5d06e262c1", + "587c1abf2366dd5d06e28901", + "587c03f12366dd5d06e27722", + "587c19da2366dd5d06e2877b", + "587c31b92366dd5d06e2a9dc", + "587c87d02366dd5d06e2f989", + "587c97a52366dd5d06e30a96", + "587c45192366dd5d06e2c0eb", + "587cec702366dd5d06e37862", + "587cef0a2366dd5d06e379e3", + "587db5872366dd5d06e3e0af", + "587e2b1d2366dd5d06e41af0", + "587e2ea62366dd5d06e41f2e", + "587e5cb52366dd5d06e4486e", + "587eb1822366dd5d06e45f29", + "587f365d2366dd5d06e4906e", + "588a9c5fec4d5a1c088ec350", + "588a34cfec4d5a1c088ea8d1", + "588ab5bdec4d5a1c088ed60f", + "588aff9d90414422fbe7885a", + "588b20d290414422fbe79f40", + "588c08d590414422fbe8200b", + "588c203d90414422fbe8319e", + "588c989a90414422fbe86d96", + "588ca09d90414422fbe871a1", + "588cce2190414422fbe88520", + "588cd5ef90414422fbe8875c", + "588cf0ad90414422fbe8a20f", + "588e01c490414422fbe8ee2a", + "588e35e690414422fbe90a53", + "588f017e90414422fbe9b74b", + "588f095190414422fbe9c1ee", + "589aca717dc3d323d55671c4", + "589af2c97dc3d323d55691e8", + "589b49ea7dc3d323d556d9b4", + "589b04287dc3d323d556a185", + "589bf6a57dc3d323d55743ab", + "589c3c497dc3d323d5578468", + "589c3c577dc3d323d5578480", + "589c24527dc3d323d5577126", + "589c35457dc3d323d5577d8d", + "589ca6a6b896147a1b73aff7", + "589d1e1fb896147a1b73ee5b", + "589d5c58b896147a1b742256", + "589d95538fa2cf375df3317b", + "589df0ffb504a864ad63521a", + "589ea316b504a864ad639a2b", + "589ec97cb504a864ad63adc3", + "589f214338486e3c9846f123", + "589fdfe738486e3c984736cf", + "590c2d70336bb52a190be886", + "591a467a6109e14d4f09b776", + "591cf3033162411cf9047f37", + "591ea44850991c70dc99a207", + "599aa591d5b41f366fed0d58", + "5643df56138263b51db1b5f3", + "5644bdac138263b51db9f669", + "5850d4f97072670e72c425d6", + "5854c405804be105852330fe", + "5855a4fc804be1058523bd75", + "5856ac15804be105852419d8", + "5856ae8b804be10585241bae", + "5856b460804be10585242059", + "5857aa5ab338a62ad5ff4dbe", + "5857acf8b338a62ad5ff5107", + "5858db6cb338a62ad500103b", + "5858dbcab338a62ad5001081", + "5859d84fb338a62ad500e5cf", + "5861d8ea712e2761468f3cb3", + "5863edf8712e27614690cce0", + "5864b076712e27614691197e", + "5864da88712e276146913d8b", + "5865f4a8712e27614691e39b", + "5867a434833dfe3f7b88edaf", + "5868cd15833dfe3f7b89bfa3", + "5880b3692366dd5d06e5d534", + "5880e3422366dd5d06e5ff8e", + "5880f0ef2366dd5d06e6166e", + "5881d2bfb6844814c136a119", + "5881f11d8ce2c2754d0714c3", + "5881fee18ce2c2754d0723f8", + "5882cda2b116682b4adebd25", + "5882d58fb116682b4adec7db", + "5884c256932ba84fbed70bf5", + "5884cc13932ba84fbed71ec4", + "5885bc5296fa095e0671a7f0", + "5886d14cb791366d617a362c", + "5888becfc02346100f4b0b21", + "5888e408c02346100f4b1a29", + "5889da66ec4d5a1c088e5187", + "5889e754ec4d5a1c088e60ba", + "5890c16b90414422fbeb0262", + "5891d8ae9a8c0314c5cd30ab", + "5891d0479a8c0314c5cd2abd", + "5891ecf19a8c0314c5cd490a", + "5892c0cd9a8c0314c5cdc977", + "5894ab309a8c0314c5cee57d", + "5895a6a89a8c0314c5cfca7c", + "5895b8c29a8c0314c5cfd051", + "5895d38f9a8c0314c5cfe50c", + "5895f2329a8c0314c5d00117", + "5896bb989a8c0314c5d086b6", + "5896ebf39a8c0314c5d0a8c4", + "5898b1bac9dccc22987b7f74", + "5898b6ffc9dccc22987b8a03", + "5898bbaac9dccc22987b8eba", + "5899cfa6b76d7a3780a4cb64", + "5899e5dcb76d7a3780a4ecc1", + "57102be2877e1421026358af", + "57153d4031bb9900425bde85", + "57177cd7fb8d93461afc4527", + "58497cdf97b73e0b090c4273", + "58500b007072670e72c35588", + 
"58510bf97072670e72c46ddf", + "58522bd56789802282f2ecb3", + "58524a2e0e7012308944bcf3", + "58524a080e7012308944bcbf", + "58524c1d0e7012308944bfda", + "58524f170e7012308944c200", + "58529a4e0e70123089454c6f", + "58551bdf804be1058523556d", + "58568c9a804be10585240b03", + "58574b35804be105852455fd", + "58577c60b338a62ad5ff1564", + "58592d69b338a62ad5007a74", + "58625f42712e2761468fb44c", + "58651bcc712e2761469166dc", + "58660e79712e27614691fe3d", + "58669aad712e27614692834c", + "58676c36833dfe3f7b88b7f2", + "58678b2d833dfe3f7b88e244", + "58800b0b2366dd5d06e5312d", + "58805eac2366dd5d06e56460", + "58806e422366dd5d06e57bb6", + "58831d060db9bf59bf8ab98b", + "58851ebb932ba84fbed7abad", + "58871dc3b791366d617a55ff", + "58873cabb791366d617a65a7", + "58873d44b791366d617a65dd", + "58888b3dc02346100f4af665", + "58933bac9a8c0314c5ce3508", + "58938e6d9a8c0314c5ce726f", + "58951cb49a8c0314c5cf4d5e", + "58970fd09a8c0314c5d0e383", + "58977ef09a8c0314c5d17b26", + "59056e6760bb961de55f3501", + "59071f2e5a6dbd3af4130f98", + "59102c811225725be9e64149", + "59338e76772c3e6384afbb15", + "59350ca084b7f26bf5ce6eb8", + "59397e493a87372f2c9e882b", + "59521e0b9096412211c2aa9d", + "59817e4a1bd4b175e7038d19", + "567884f58d2828b95e3c8eba", + "585559d9804be10585238ddf", + "585834cdb338a62ad5ffab4d", + "586082d8712e2761468e2877", + "586133c2712e2761468ecfe3", + "586281d2712e2761468fcaa2", + "586316e5712e276146903c4d", + "586326ad712e276146904571", + "586375c9712e276146907429", + "586389c9712e276146908da6", + "586496fa712e2761469108e7", + "586669c6712e27614692597a", + "586913a49d1b5e34c2808b02", + "586922da9d1b5e34c2809ff3", + "588185d8dfb7a15588a114a3", + "588315c60db9bf59bf8aa928", + "588332ee0db9bf59bf8ae9c3", + "588519d5932ba84fbed7a04a", + "588824d1b791366d617adeef", + "588857f6c02346100f4ac09f", + "589145ef90414422fbeb2e08", + "589433fa9a8c0314c5ce9656", + "589765d39a8c0314c5d16b12", + "5851165f7072670e72c4860d", + "5859341ab338a62ad500848d", + "5863915b712e276146909135", + "5866445b712e27614692383e", + "5866500d712e2761469240fd", + "5867785a833dfe3f7b88c764", + "5867969c833dfe3f7b88e8bc", + "5868040c833dfe3f7b8934f7", + "5882372c8ce2c2754d076af0", + "5883535e932ba84fbed5ad07", + "5888358cb791366d617af69d", + "5890330d90414422fbeaa0cb", + "5897076e9a8c0314c5d0d31b", + "5940564ec2d9527ab869f7e2", + "5947719bf1b45630bd096665", + "5948194ff1b45630bd0f47e3", + "5950206a41b158666ac50506", + "5983012d1bd4b175e70c985a", + "58586810b338a62ad5ffc20c", + "58592046b338a62ad5006b33", + "58592854b338a62ad500750a", + "58596531b338a62ad500aace", + "58818685dfb7a15588a11626", + "58829563f42b1d3ee3ec835f", + "58894345c02346100f4b51ca", + "585289980e7012308945276a", + "585369770e7012308945c709", + "585373640e7012308945cab9", + "588230658ce2c2754d076728", + "589388059a8c0314c5ce718b", + "595979485ec6a95e86a58c8d", + "5841206219d291325678ca90", + "58563650804be1058523da55", + "58564084804be1058523e116", + "58636467712e27614690661f", + "58647495712e27614690f36d", + "58654563712e276146918643", + "58664251712e276146923738", + "588084032366dd5d06e59e82", + "588159582366dd5d06e66877", + "5890279190414422fbea9734", + "5890641690414422fbeabbe7", + "585203546789802282f2aaf5", + ] + + # Validation set sequences after filtering + self.val_split_scenes = [ + "00000000000000000000000a", + "5a4a38dad38c8a075495b5d2", + "5a489fb1c7dab83a7d7b1070", + "5a572fd9fc597b0478a81d14", + "5a588a8193ac3d233f77fbca", + "5aa0f478a9efce63548c1cb4", + "5ae2e9c5fe405c5076abc6b2", + "5b2c67b5e0878c381608b8d8", + "5b21e18c58e2823a67a10dd8", + 
"5b864d850d072a699b32f4ae", + "5b4933abf2b5f44e95de482a", + "5b37189a35304b6f75e7583e", + "5bc5f0e896b66a2cd8f9bd36", + "5bccd6beca24970bce448134", + "5bf26cbbd43923194854b270", + "5bf18642c50e6f7f8bdbd492", + "5bfc9d5aec61ca1dd69132a2", + "5bff3c5cfe0ea555e6bcbf3a", + "5c1f33f1d33e1f2e4aa6dda4", + "5c34529873a8df509ae57b58", + "58a186444a4d262a170ae3ae", + "58f7f7299f5b5647873cb110", + "59acd2f4b891807f439c8992", + "567a0fb0a825d2fb79ac9a20", + "584ad76bfe3cb463906ce6dc", + "584c58b77072670e72c03990", + "586c48329d1b5e34c2838e80", + "586df9849d1b5e34c28506de", + "588e0d8c90414422fbe8f8b2", + "589c300f7dc3d323d5577926", + "590f91851225725be9e25d4e", + "5889e344ec4d5a1c088e59be", + "5898b31cc9dccc22987b82ec", + "5947b62af1b45630bd0c2a02", + "58598db2b338a62ad500bc38", + "58669c02712e27614692851a", + "58790c82ce911104a3467c88", + "58897f62c02346100f4b8ee6", + "588305ed0db9bf59bf8a8c80", + "588457b8932ba84fbed69942", + "5862388b712e2761468f84aa", + "5880675a2366dd5d06e570ca", + "5890523090414422fbeab3f0", + ] + + +class TartanAirV2Splits: + """ + This class contains the information about the splits of the TartanAir V2 dataset. + """ + + def __init__(self): + """ + Splits of environments with unique geometry selected based on TartanVO & UFM splits. + """ + # Apart from the below 2 splits, all other TAv2 scenes are in the train split + # Val split + self.val_split_scenes = ["EndofTheWorld", "HongKong", "WesternDesertTown"] + + # Test split + self.test_split_scenes = [ + "DesertGasStation", + "OldScandinavia", + "PolarSciFi", + "Sewerage", + "Supermarket", + ] + + +class MegaDepthSplits: + """ + This class contains the information about the splits of the MegaDepth dataset. + """ + + def __init__(self): + """ + Validation split is based on scenes used in DUSt3R. + """ + self.val_split_scenes = ["0015_0", "0015_1", "0022_0"] + + +class SpringSplits: + """ + This class contains the information about the splits of the Spring dataset. + """ + + def __init__(self): + self.val_split_scenes = ["0013", "0023", "0037"] + + +class MPSDSplits: + """ + This class contains the information about the splits of the MPSD dataset. + """ + + def __init__(self): + """ + Train & Validation split numpy files containing folder names are generated during preprocessing of MPSD dataset. + Load the numpy files to get the list of scenes in the train & validation split. + A 95% (Train) & 5% (Validation) split is used. + """ + self.train_split_scenes = "load_numpy_file_with_train_scenes" + self.val_split_scenes = "load_numpy_file_with_val_scenes" + + +class ScanNetPPSplits: + """ + This class contains the information about the splits of the ScanNetPP dataset. + """ + + def __init__(self): + """ + Validation & Test split only contains scenes from ScanNet++V2 to prevent data leak with other methods such as DUSt3R during benchmarking. 
+ + Following logic was used to generate the splits: + # Select 80%, 10%, 10% of the scenes for train, val, test respectively from ScanNet++ V2 (~300 scene subset; excluding V1 scenes) + snpp_v2_test_scenes = np.random.choice( + snpp_v2_processed_scenes, size=int(0.1 * len(snpp_v2_processed_scenes)), replace=False + ) + remaining_scenes = [scene for scene in snpp_v2_processed_scenes if scene not in snpp_v2_test_scenes] + snpp_v2_val_scenes = np.random.choice( + remaining_scenes, size=int(0.1 * len(snpp_v2_processed_scenes)), replace=False + ) + snpp_v2_train_scenes = [ + scene for scene in remaining_scenes if scene not in snpp_v2_val_scenes and scene not in snpp_v2_test_scenes + ] + """ + # Validation Scenes + self.val_split_scenes = [ + "1c7a683c92", + "2a1b555966", + "3a43c7b8d2", + "4aef651da7", + "06bc6d1b24", + "7f22d5ef1b", + "7f77abce34", + "8ea517a2fc", + "29c7afafed", + "41eb967018", + "77b40ce601", + "086f09d6e3", + "307e3262f1", + "639f2c4d5a", + "894dbd41f1", + "898a7dfd0c", + "2779f8f9e2", + "151178afd7", + "182932a4f3", + "635852d56e", + "9906136b57", + "af112b8903", + "b0f057c684", + "b37177e6c8", + "b119249da7", + "be8367fcbe", + "c8fc01c453", + "e1fb8626c8", + "e2caaaf5b5", + "fe3fc057a1", + ] + + # Test Scenes + self.test_split_scenes = [ + "0e900bcc5c", + "0eba3981c9", + "1cbb105c6a", + "3c8d535d49", + "5d902f1593", + "6bd39ac392", + "6c14d5fd01", + "7c31a42404", + "9bfbc75700", + "13b4efaf62", + "062e5a23a6", + "95b9971d01", + "246fe09e98", + "637a27d04b", + "725b8f0cba", + "413085a827", + "696317583f", + "a4c043ac48", + "a9e4791c7e", + "b0b004c40f", + "c3bc5e82c5", + "c31ebd4b22", + "cba701332a", + "cc5ea8026c", + "cec8312f4e", + "e3b3b0d0c7", + "e667e09fe6", + "eaa6c90310", + "f9397af4cb", + "fb893ffaf3", + ] + + +class DL3DV10KSplits: + """ + This class contains the information about the splits of the DL3DV-10K dataset. + We use the official benchmark split as the val split. + """ + + def __init__(self): + """ + Validation split is based on DL3DV-Benchmark. + """ + self.val_split_scenes = [ + "load https://huggingface.co/datasets/DL3DV/DL3DV-Benchmark/raw/main/benchmark-meta.csv \ + & https://raw.githubusercontent.com/DL3DV-10K/Dataset/main/cache/DL3DV-valid.csv" + ] + + +class ETH3DSplits: + """ + This class contains the information about the splits of the ETH3D dataset. + """ + + def __init__(self): + """ + All scenes are in the test split. + """ + self.test_split_scenes = "all" diff --git a/mapanything/datasets/wai/__init__.py b/mapanything/datasets/wai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/datasets/wai/ase.py b/mapanything/datasets/wai/ase.py new file mode 100644 index 0000000000000000000000000000000000000000..b55249783129fab35ceda619e9b55639eec18b87 --- /dev/null +++ b/mapanything/datasets/wai/ase.py @@ -0,0 +1,294 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +ASE Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class ASEWAI(BaseDataset): + """ + ASE dataset containing large diversity of synthetic indoor scenes. 
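+
+ Example construction, mirroring the __main__ block at the bottom of this file (the paths are placeholders, not real defaults):
+
+ dataset = ASEWAI(
+ num_views=2,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT="/path/to/ase",
+ dataset_metadata_dir="/path/to/mapanything_dataset_metadata",
+ resolution=(518, 518),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ views = dataset[0]  # list of num_views view dicts (img, depthmap, camera_pose, camera_intrinsics, ...)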
+ """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. + """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"ase_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Resize the data to match 
the desired resolution + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=None, + ) + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="ASE", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/ase", type=str) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = ASEWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 518), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = ASEWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 518), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "ASE_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + 
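# Pinhole camera uses the right-down-forward (RDF) axis convention, matching the world frame logged above +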
camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/blendedmvs.py b/mapanything/datasets/wai/blendedmvs.py new file mode 100644 index 0000000000000000000000000000000000000000..9f133652fe194bb654fc44dd62633d063cd0092e --- /dev/null +++ b/mapanything/datasets/wai/blendedmvs.py @@ -0,0 +1,313 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +BlendedMVS Dataset using WAI format data. +""" + +import os + +import cv2 +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class BlendedMVSWAI(BaseDataset): + """ + BlendedMVS dataset containing object-centric and birds-eye-view scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = False + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"blendedmvs_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth", "pred_mask/moge2"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, 
+ depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="BlendedMVS", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/blendedmvs", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = BlendedMVSWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 392), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = BlendedMVSWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 392), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "BlendedMVS_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=10, replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + 
height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/dl3dv.py b/mapanything/datasets/wai/dl3dv.py new file mode 100644 index 0000000000000000000000000000000000000000..166a79a4b6df0e9b60b25e92323dd537e37c3297 --- /dev/null +++ b/mapanything/datasets/wai/dl3dv.py @@ -0,0 +1,356 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +DL3DV Dataset using WAI format data. +""" + +import os + +import cv2 +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.cropping import ( + rescale_image_and_other_optional_info, + resize_with_nearest_interpolation_to_match_aspect_ratio, +) +from mapanything.utils.wai.core import load_data, load_frame + + +class DL3DVWAI(BaseDataset): + """ + DL3DV dataset containing over 10k in-the-wild and indoor scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + mvs_confidence_filter_thres: float = 0.25, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. + mvs_confidence_filter_thres: Confidence threshold to filter MVS depth. Defaults to 0.25. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self.mvs_confidence_filter_thres = mvs_confidence_filter_thres + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = False + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"dl3dv_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0_mvsa_based" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=[ + "image", + "pred_depth/mvsanywhere", + "pred_mask/moge2", + "depth_confidence/mvsanywhere", + ], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["pred_depth/mvsanywhere"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the dimensions of the original image + img_h, img_w = image.shape[:2] + + # Resize depth to match image aspect ratio while ensuring that depth resolution doesn't increase + depthmap, target_depth_h, target_depth_w = ( + resize_with_nearest_interpolation_to_match_aspect_ratio( + input_data=depthmap, img_h=img_h, img_w=img_w + ) + ) + + # Now resize the image and update intrinsics to match the resized depth + image, _, intrinsics, _ = rescale_image_and_other_optional_info( + 
image=image, + output_resolution=(target_depth_w, target_depth_h), + depthmap=None, + camera_intrinsics=intrinsics, + ) + image = np.array(image) + + # Get the depth confidence map and mask out the MVS depth + confidence_map = view_data["depth_confidence/mvsanywhere"].numpy() + confidence_mask = ( + confidence_map > self.mvs_confidence_filter_thres + ).astype(int) + confidence_mask = cv2.resize( + confidence_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + depthmap = np.where(confidence_mask, depthmap, 0) + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="DL3DV", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/dl3dv", type=str) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = DL3DVWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + mvs_confidence_filter_thres=0.25, + resolution=(518, 294), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = DL3DVWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # mvs_confidence_filter_thres=0.25, + # resolution=(518, 294), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "DL3DV_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert 
len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/dynamicreplica.py b/mapanything/datasets/wai/dynamicreplica.py new file mode 100644 index 0000000000000000000000000000000000000000..148ff2315f63daaa63f10bb6e5a5d501f5159df2 --- /dev/null +++ b/mapanything/datasets/wai/dynamicreplica.py @@ -0,0 +1,297 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Dynamic Replica Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class DynamicReplicaWAI(BaseDataset): + """ + Dynamic Replica dataset containing synthetic scenes with humans and animals. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"dynamicreplica_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = image[:, :, :3] # RGBA to RGB + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Resize the data to match the desired resolution + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=None, + ) + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="DynamicReplica", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = 
argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/dynamicreplica", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = DynamicReplicaWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 294), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = DynamicReplicaWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 294), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "DynamicReplica_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + 
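# Use the same validity mask to pick a matching RGB color for each valid 3D point +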
filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/eth3d.py b/mapanything/datasets/wai/eth3d.py new file mode 100644 index 0000000000000000000000000000000000000000..923b98ab6472aa02833d14496edf2ee24146f442 --- /dev/null +++ b/mapanything/datasets/wai/eth3d.py @@ -0,0 +1,277 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +ETH3D Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class ETH3DWAI(BaseDataset): + """ + ETH3D dataset containing high-quality outdoor and indoor scans of the ETH Zurich campus. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. + """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = "test" + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"eth3d_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N 
views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Resize the data to match the desired resolution + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=None, + ) + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="ETH3D", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/eth3d", type=str) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = ETH3DWAI( + num_views=args.num_of_views, + covisibility_thres=0.025, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 336), + seed=777, + transform="imgnorm", + data_norm_type="dinov2", + ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "ETH3D_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + 
else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/megadepth.py b/mapanything/datasets/wai/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcb9a8f98b0439ec0a6aec311e6cf2cc251acb0 --- /dev/null +++ b/mapanything/datasets/wai/megadepth.py @@ -0,0 +1,314 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +MegaDepth Dataset using WAI format data. +""" + +import os + +import cv2 +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class MegaDepthWAI(BaseDataset): + """ + MegaDepth dataset containing outdoor phototourism and in-the-wild scenes. + Also includes Tanks & Temples scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = False + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"megadepth_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth", "pred_mask/moge2"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, 
+ depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="MegaDepth", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/megadepth", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = MegaDepthWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 336), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = MegaDepthWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 336), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "MegaDepth_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + 
height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/mpsd.py b/mapanything/datasets/wai/mpsd.py new file mode 100644 index 0000000000000000000000000000000000000000..ee6f48c87853684fe7ccabc6b8747fbcdd40cc18 --- /dev/null +++ b/mapanything/datasets/wai/mpsd.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +MPSD Dataset using WAI format data. +""" + +import os + +import cv2 +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class MPSDWAI(BaseDataset): + """ + MPSD dataset containing outdoor planet scale metric reconstructions. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"mpsd_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth", "pred_mask/moge2"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + 
depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="MPSD", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/mpsd", type=str) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = MPSDWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.15, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 392), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = MPSDWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.15, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 392), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "MPSD_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + 
camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/mvs_synth.py b/mapanything/datasets/wai/mvs_synth.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1006e9141541741f4d0f7f7c7f8f54328e33c8 --- /dev/null +++ b/mapanything/datasets/wai/mvs_synth.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +MVS Synth Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class MVSSynthWAI(BaseDataset): + """ + MVS Synth dataset containing large diversity of synthetic in-the-wild scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
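+        All remaining args/kwargs (e.g. num_views, covisibility_thres, resolution, transform, data_norm_type) are forwarded to BaseDataset.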
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"mvs_synth_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non ambiguous mask (zero depth pixels are sky or ambiguous) + non_ambiguous_mask = (depthmap > 0).astype(int) + + # Mask out the outlier depth (horizon depth) + percentile_depth = np.percentile(depthmap, 95) + depthmap[depthmap > percentile_depth] = 0 + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = 
additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="MVSSynth", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/mvs_synth", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = MVSSynthWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 294), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = MVSSynthWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 294), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "MVSSynth_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + 
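+                # Depth is logged under the same pinhole entity as the RGB so the viewer can associate it with the camera intrinsics logged above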
rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/paralleldomain4d.py b/mapanything/datasets/wai/paralleldomain4d.py new file mode 100644 index 0000000000000000000000000000000000000000..f7900ab702508e63d7e89585bdf4b4b179411a57 --- /dev/null +++ b/mapanything/datasets/wai/paralleldomain4d.py @@ -0,0 +1,309 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Parallel Domain 4D Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class ParallelDomain4DWAI(BaseDataset): + """ + Parallel Domain 4D dataset containing large diversity of synthetic AV scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
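+        All remaining args/kwargs (e.g. num_views, covisibility_thres, resolution, transform, data_norm_type) are forwarded to BaseDataset.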
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"paralleldomain4d_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = image[:, :, :3] # RGBA to RGB + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non ambiguous mask (zero depth pixels are sky or ambiguous) + non_ambiguous_mask = (depthmap > 0).astype(int) + + # Mask out the outlier depth (horizon depth) + percentile_depth = np.percentile(depthmap, 95) + depthmap[depthmap > percentile_depth] = 0 + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + 
additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="ParallelDomain4D", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/paralleldomain4d", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = ParallelDomain4DWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 392), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = ParallelDomain4DWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 392), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "ParallelDomain4D_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + 
width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/sailvos3d.py b/mapanything/datasets/wai/sailvos3d.py new file mode 100644 index 0000000000000000000000000000000000000000..e59a2068ee32a44fbdbc9297afebb354f02c2c52 --- /dev/null +++ b/mapanything/datasets/wai/sailvos3d.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +SAIL-VOS 3D Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class SAILVOS3DWAI(BaseDataset): + """ + SAIL-VOS 3D dataset containing large diversity of synthetic in-the-wild cut scenes from GTA. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
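+        All remaining args/kwargs (e.g. num_views, covisibility_thres, resolution, transform, data_norm_type) are forwarded to BaseDataset.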
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"sailvos3d_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non ambiguous mask (zero depth pixels are sky or ambiguous) + non_ambiguous_mask = (depthmap > 0).astype(int) + + # Mask out the outlier depth (horizon depth) + percentile_depth = np.percentile(depthmap, 95) + depthmap[depthmap > percentile_depth] = 0 + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = 
additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="SAILVOS3D", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/sailvos3d", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = SAILVOS3DWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 336), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = SAILVOS3DWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 336), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "SAILVOS3D_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + 
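+                # Depth values above the 95th percentile were zeroed in _get_views, so distant/horizon pixels appear empty in this depth image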
rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/scannetpp.py b/mapanything/datasets/wai/scannetpp.py new file mode 100644 index 0000000000000000000000000000000000000000..1367b461a8456b2d3a73cf2fd4749771e839f5d4 --- /dev/null +++ b/mapanything/datasets/wai/scannetpp.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +ScanNet++V2 Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class ScanNetPPWAI(BaseDataset): + """ + ScanNet++V2 dataset containing large diversity of indoor scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
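+        All remaining args/kwargs (e.g. num_views, covisibility_thres, resolution, transform, data_norm_type) are forwarded to BaseDataset.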
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"scannetppv2_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "rendered_depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["rendered_depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Resize the data to match the desired resolution + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=None, + ) + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="ScanNetPP", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + 
parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/scannetppv2", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = ScanNetPPWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 336), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = ScanNetPPWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 336), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + # dataset = ScanNetPPWAI( + # num_views=args.num_of_views, + # split="test", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 336), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "ScanNetPP_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=10, replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + 
rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/spring.py b/mapanything/datasets/wai/spring.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b67f2a611c6f293f9619d989ce7ef3f7806649 --- /dev/null +++ b/mapanything/datasets/wai/spring.py @@ -0,0 +1,316 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Spring Dataset using WAI format data. +""" + +import os + +import cv2 +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class SpringWAI(BaseDataset): + """ + Spring dataset containing high-quality large-scale in-the-wild scenes with unique animated objects. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
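+        All remaining args/kwargs (e.g. num_views, covisibility_thres, resolution, transform, data_norm_type) are forwarded to BaseDataset.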
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"spring_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) # Assumes only npy file in directory is covisibility map + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth", "skymask", "pred_mask/moge2"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Get the sky mask and mask out GT depth + sky_mask = view_data["skymask"].numpy().astype(int) + depthmap = np.where(sky_mask, 0, depthmap) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + additional_quantities_to_resize = 
[non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="Spring", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/spring", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = SpringWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 294), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = SpringWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 294), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "Spring_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=10, replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + 
translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/tav2_wb.py b/mapanything/datasets/wai/tav2_wb.py new file mode 100644 index 0000000000000000000000000000000000000000..9446139fb9188544f4dcd3480e43a2c8c8906c94 --- /dev/null +++ b/mapanything/datasets/wai/tav2_wb.py @@ -0,0 +1,328 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +TartanAirV2-WB Dataset using WAI format data. +""" + +import os + +import cv2 +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class TartanAirV2WBWAI(BaseDataset): + """ + TartanAirV2-WB dataset containing vastly-sized in-the-wild synthetic scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
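+        All remaining args/kwargs (e.g. num_views, covisibility_thres, resolution, transform, data_norm_type) are forwarded to BaseDataset.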
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"tav2_wb_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth", "pred_mask/moge2"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Mask out the outlier depth caused due to transparent windows in TartanAirV2 + percentile_depth = np.percentile(depthmap, 95) + depthmap[depthmap > percentile_depth] = 0 + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + 
additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="TartanAirV2WB", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/tav2_wb", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = TartanAirV2WBWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 518), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = TartanAirV2WBWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 518), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + # dataset = TartanAirV2WBWAI( + # num_views=args.num_of_views, + # split="test", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 518), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "TartanAirV2WB_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = 
views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/unrealstereo4k.py b/mapanything/datasets/wai/unrealstereo4k.py new file mode 100644 index 0000000000000000000000000000000000000000..f16f9b226d15ebc4273dc824e0c5885c636ea365 --- /dev/null +++ b/mapanything/datasets/wai/unrealstereo4k.py @@ -0,0 +1,309 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +UnrealStereo4K Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class UnrealStereo4KWAI(BaseDataset): + """ + UnrealStereo4K dataset containing synthetic in-the-wild scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"unrealstereo4k_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = image[:, :, :3] # RGBA to RGB + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non ambiguous mask (zero depth pixels are sky or ambiguous) + non_ambiguous_mask = (depthmap > 0).astype(int) + + # Mask out the outlier depth (horizon depth) + percentile_depth = np.percentile(depthmap, 95) + depthmap[depthmap > percentile_depth] = 0 + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + 
additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="UnrealStereo4K", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/unrealstereo4k", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = UnrealStereo4KWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 294), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = UnrealStereo4KWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 294), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "UnrealStereo4K_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + 
camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/models/__init__.py b/mapanything/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..498d19c825297a091df9e067cae63c2de68587ef --- /dev/null +++ b/mapanything/models/__init__.py @@ -0,0 +1,190 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Model Factory for MapAnything +""" + +import importlib.util +import logging +import warnings + +import numpy as np +from omegaconf import DictConfig, OmegaConf + +# Core models that are always available +from mapanything.models.mapanything import ( + MapAnything, + MapAnythingAblations, + ModularDUSt3R, +) + +# Suppress DINOv2 warnings +logging.getLogger("dinov2").setLevel(logging.WARNING) +warnings.filterwarnings("ignore", message="xFormers is available", category=UserWarning) +warnings.filterwarnings( + "ignore", message="xFormers is not available", category=UserWarning +) + + +def resolve_special_float(value): + if value == "inf": + return np.inf + elif value == "-inf": + return -np.inf + else: + raise ValueError(f"Unknown special float value: {value}") + + +def init_model( + model_str: str, model_config: DictConfig, torch_hub_force_reload: bool = False +): + """ + Initialize a model using OmegaConf configuration. + + Args: + model_str (str): Name of the model class to create. + model_config (DictConfig): OmegaConf model configuration. + torch_hub_force_reload (bool): Whether to force reload relevant parts of the model from torch hub. 
+ """ + if not OmegaConf.has_resolver("special_float"): + OmegaConf.register_new_resolver("special_float", resolve_special_float) + model_dict = OmegaConf.to_container(model_config, resolve=True) + model = model_factory( + model_str, torch_hub_force_reload=torch_hub_force_reload, **model_dict + ) + + return model + + +# Define model configurations with import paths +MODEL_CONFIGS = { + # Core models + "mapanything": { + "class": MapAnything, + }, + "mapanything_ablations": { + "class": MapAnythingAblations, + }, + "modular_dust3r": { + "class": ModularDUSt3R, + }, + # External models + "anycalib": { + "module": "mapanything.models.external.anycalib", + "class_name": "AnyCalibWrapper", + }, + "dust3r": { + "module": "mapanything.models.external.dust3r", + "class_name": "DUSt3RBAWrapper", + }, + "mast3r": { + "module": "mapanything.models.external.mast3r", + "class_name": "MASt3RSGAWrapper", + }, + "moge": { + "module": "mapanything.models.external.moge", + "class_name": "MoGeWrapper", + }, + "must3r": { + "module": "mapanything.models.external.must3r", + "class_name": "MUSt3RWrapper", + }, + "pi3": { + "module": "mapanything.models.external.pi3", + "class_name": "Pi3Wrapper", + }, + "pow3r": { + "module": "mapanything.models.external.pow3r", + "class_name": "Pow3RWrapper", + }, + "pow3r_ba": { + "module": "mapanything.models.external.pow3r", + "class_name": "Pow3RBAWrapper", + }, + "vggt": { + "module": "mapanything.models.external.vggt", + "class_name": "VGGTWrapper", + }, + # Add other model classes here +} + + +def check_module_exists(module_path): + """ + Check if a module can be imported without actually importing it. + + Args: + module_path (str): The path to the module to check. + + Returns: + bool: True if the module can be imported, False otherwise. + """ + return importlib.util.find_spec(module_path) is not None + + +def model_factory(model_str: str, **kwargs): + """ + Model factory for MapAnything. + + Args: + model_str (str): Name of the model to create. + **kwargs: Additional keyword arguments to pass to the model constructor. + + Returns: + nn.Module: An instance of the specified model. + """ + if model_str not in MODEL_CONFIGS: + raise ValueError( + f"Unknown model: {model_str}. Valid options are: {', '.join(MODEL_CONFIGS.keys())}" + ) + + model_config = MODEL_CONFIGS[model_str] + + # Handle core models directly + if "class" in model_config: + model_class = model_config["class"] + # Handle external models with dynamic imports + elif "module" in model_config: + module_path = model_config["module"] + class_name = model_config["class_name"] + + # Check if the module can be imported + if not check_module_exists(module_path): + raise ImportError( + f"Model '{model_str}' requires module '{module_path}' which is not installed. " + f"Please install the corresponding submodule or package." + ) + + # Dynamically import the module and get the class + try: + module = importlib.import_module(module_path) + model_class = getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ImportError( + f"Failed to import {class_name} from {module_path}: {str(e)}" + ) + else: + raise ValueError(f"Invalid model configuration for {model_str}") + + print(f"Initializing {model_class} with kwargs: {kwargs}") + if model_str != "org_dust3r": + return model_class(**kwargs) + else: + eval_str = kwargs.get("model_eval_str", None) + return eval(eval_str) + + +def get_available_models() -> list: + """ + Get a list of available models in MapAnything. 
+ + Returns: + list: A list of available model names. + """ + return list(MODEL_CONFIGS.keys()) + + +__all__ = ["model_factory", "get_available_models"] diff --git a/mapanything/models/__pycache__/__init__.cpython-312.pyc b/mapanything/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b74f127c03480cea652fc8e052daf0be46f9738 Binary files /dev/null and b/mapanything/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/mapanything/models/external/README.md b/mapanything/models/external/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3fab6d3b5b2ab74546420402e2b41a9eb33def92 --- /dev/null +++ b/mapanything/models/external/README.md @@ -0,0 +1,5 @@ +# External Model Code for Benchmarking & Re-Training + +This directory contains external model code that we use to train and benchmark external models fairly. These libraries are not part of the core MapAnything codebase and are included only for benchmarking purposes. The code in this directory is licensed under the same license as the source code from which it was derived, unless otherwise specified. + +The open-source Apache 2.0 License of MapAnything does not apply to these libraries. diff --git a/mapanything/models/external/__init__.py b/mapanything/models/external/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/models/external/anycalib/__init__.py b/mapanything/models/external/anycalib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d1507bd374fe968d000ea4534f5d82f0b895478d --- /dev/null +++ b/mapanything/models/external/anycalib/__init__.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Inference wrapper for AnyCalib +""" + +import torch +from anycalib import AnyCalib + +from mapanything.utils.geometry import get_rays_in_camera_frame + + +class AnyCalibWrapper(torch.nn.Module): + def __init__( + self, + name, + model_id="anycalib_pinhole", + **kwargs, + ): + super().__init__() + self.name = name + self.model_id = model_id + + # Initialize the model + self.model = AnyCalib(model_id=self.model_id) + + def forward(self, views): + """ + Forward pass wrapper for AnyCalib. + + Assumption: + - The number of input views is 1. + - The output camera model is pinhole (fx, fy, cx, cy). + This can be relaxed by not hardcoding the cam_id. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Length of the list should be 1. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["identity"] + + Returns: + List[dict]: A list containing the final outputs for the single view. Length of the list will be 1. + """ + # Check that the number of input views is 1 + assert len(views) == 1, "AnyCalib only supports 1 input view."
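+ # Note: AnyCalib performs single-image calibration, hence the single-view restriction above; multiple images are still handled along the batch dimension of views[0]["img"].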
+ + # Get input shape of the images and batch size per view + _, _, height, width = views[0]["img"].shape + + # Check the data norm type + # AnyCalib expects a normalized image but without the DINOv2 mean and std applied ("identity") + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "identity", ( + "AnyCalib expects a normalized image but without the DINOv2 mean and std applied" + ) + + # Run AnyCalib inference + # Corresponding batched output dictionary: + # { + # "intrinsics": List[(D_i,) tensors] for each camera model "i" at the original input resolution, + # "fov_field": (B, N, 2) tensor with the regressed FoV field by the network. N≈320^2 (resolution close to the one seen during training), + # "tangent_coords": alias for "fov_field", + # "rays": (B, N, 3) tensor with the corresponding (via the exponential map) ray directions in the camera frame (x right, y down, z forward), + # "pred_size": (H, W) tuple with the image size used by the network. It can be used e.g. for resizing the FoV/ray fields to the original image size. + # } + # For "pinhole" camera model, the intrinsics are (fx, fy, cx, cy). + model_outputs = self.model.predict(views[0]["img"], cam_id="pinhole") + + # Convert the list of intrinsics to a tensor + intrinsics = [] + for intrinsics_per_sample in model_outputs["intrinsics"]: + pred_fx, pred_fy, pred_cx, pred_cy = intrinsics_per_sample + intrinsics_per_sample = torch.tensor( + [ + [pred_fx, 0, pred_cx], + [0, pred_fy, pred_cy], + [0, 0, 1], + ], + device=views[0]["img"].device, + ) + intrinsics.append(intrinsics_per_sample) + + # Convert the list of intrinsics to a tensor of size (batch_size_per_view, 3, 3) + intrinsics = torch.stack(intrinsics) + + # Get the ray directions + with torch.autocast("cuda", enabled=False): + _, ray_directions = get_rays_in_camera_frame( + intrinsics, height, width, normalize_to_unit_sphere=True + ) + + # Return the output in MapAnything format + res = [{"ray_directions": ray_directions, "intrinsics": intrinsics}] + + return res diff --git a/mapanything/models/external/dinov2/__init__.py b/mapanything/models/external/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd11e735b9c3a2ae2b1475ca163537e274fae76 --- /dev/null +++ b/mapanything/models/external/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/mapanything/models/external/dinov2/hub/__init__.py b/mapanything/models/external/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40afb43678d1db842a67445d79260c338a1c1ab5 --- /dev/null +++ b/mapanything/models/external/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/mapanything/models/external/dinov2/hub/backbones.py b/mapanything/models/external/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..3445f144304801154558f457ab041fd39be67743 --- /dev/null +++ b/mapanything/models/external/dinov2/hub/backbones.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from mapanything.models.external.dinov2.hub.utils import ( + _DINOV2_BASE_URL, + _make_dinov2_model_name, +) + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name( + arch_name, patch_size, num_register_tokens + ) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitb14( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitl14( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitg14( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
+ """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/mapanything/models/external/dinov2/hub/utils.py b/mapanything/models/external/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..faf964ecf7a76fc0186d073725c50d63c602aef6 --- /dev/null +++ b/mapanything/models/external/dinov2/hub/utils.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name( + arch_name: str, patch_size: int, num_register_tokens: int = 0 +) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list( + itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]) + ) + output = F.pad(x, pads) + return output diff --git a/mapanything/models/external/dinov2/layers/__init__.py b/mapanything/models/external/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eca24b7215c4bfa9e333d69369f0f69565ca1a2a --- /dev/null +++ b/mapanything/models/external/dinov2/layers/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
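+ +# Convenience re-exports of the DINOv2 layer primitives (DINO head, MLP, patch embedding, SwiGLU FFNs, transformer block, memory-efficient attention).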
+ +from mapanything.models.external.dinov2.layers.dino_head import DINOHead # noqa +from mapanything.models.external.dinov2.layers.mlp import Mlp # noqa +from mapanything.models.external.dinov2.layers.patch_embed import PatchEmbed # noqa +from mapanything.models.external.dinov2.layers.swiglu_ffn import ( + SwiGLUFFN, # noqa + SwiGLUFFNFused, # noqa +) +from mapanything.models.external.dinov2.layers.block import NestedTensorBlock # noqa +from mapanything.models.external.dinov2.layers.attention import MemEffAttention # noqa diff --git a/mapanything/models/external/dinov2/layers/attention.py b/mapanything/models/external/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b1beccc7a9f00f823957d3f4e6d51040666f0fc9 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/attention.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os + +from torch import nn, Tensor + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/mapanything/models/external/dinov2/layers/block.py b/mapanything/models/external/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..723f6868d3aa9eb84aacf8695071f27b27d49494 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/block.py 
@@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Any, Callable, Dict, List, Tuple + +import torch +from torch import nn, Tensor + +from mapanything.models.external.dinov2.layers.attention import ( + Attention, + MemEffAttention, +) +from mapanything.models.external.dinov2.layers.drop_path import DropPath +from mapanything.models.external.dinov2.layers.layer_scale import LayerScale +from mapanything.models.external.dinov2.layers.mlp import Mlp + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, index_select_cat, scaled_index_add + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + 
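+# The helper below implements stochastic depth at the batch level: the residual branch is evaluated on a random subset of samples and added back scaled by batch_size / subset_size, which matches per-sample drop path in expectation while skipping computation for the dropped samples.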
+def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + else: + x_plus_residual = scaled_index_add( + x, + brange, + residual.to(dtype=x.dtype), + scaling=scaling_vector, + alpha=residual_scale_factor, + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + 1, -1, x_list[0].shape[-1] + ) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) + return outputs + + +class NestedTensorBlock(Block): + 
# Processes either a single Tensor (standard Block.forward path) or a list of Tensors that are packed into one concatenated sequence with a block-diagonal attention mask (requires xFormers). + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma + if isinstance(self.ls1, LayerScale) + else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma + if isinstance(self.ls2, LayerScale) + else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/mapanything/models/external/dinov2/layers/dino_head.py b/mapanything/models/external/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..01b9823b0e5db9950e5f2b8dfc9e3439f7da2bf2 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/dino_head.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree.
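+ +# DINOHead: projection head consisting of an MLP bottleneck, L2 normalization, and a weight-normalized linear prototype layer (weight_g initialized to 1).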
+ +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp( + nlayers, + in_dim, + bottleneck_dim, + hidden_dim=hidden_dim, + use_bn=use_bn, + bias=mlp_bias, + ) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp( + nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True +): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/mapanything/models/external/dinov2/layers/drop_path.py b/mapanything/models/external/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..04cb47af065ec3cfdbc8f59854efaa976ea717d5 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/drop_path.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/mapanything/models/external/dinov2/layers/layer_scale.py b/mapanything/models/external/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..b32da3bd74a028795f5ee628d2636a45b756b074 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/layer_scale.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import nn, Tensor + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/mapanything/models/external/dinov2/layers/mlp.py b/mapanything/models/external/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6d19f53d8562d07d559fe6db93d45b8286420720 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import nn, Tensor + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/mapanything/models/external/dinov2/layers/patch_embed.py b/mapanything/models/external/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..493774d038c9ee7f0f63b05f80561fa61321e2b7 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/patch_embed.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +import torch.nn as nn +from torch import Tensor + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, ( + f"Input image height {H} is not a multiple of patch height {patch_H}" + ) + assert W % patch_W == 0, ( + f"Input image width {W} is not a multiple of patch width: {patch_W}" + ) + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/mapanything/models/external/dinov2/layers/swiglu_ffn.py b/mapanything/models/external/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..c91ef5238dd7b0da9c3bca1ffe9a4f67e81ce224 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
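+ +# SwiGLU feed-forward variants: SwiGLUFFN is a plain PyTorch implementation, while SwiGLUFFNFused uses the fused xFormers SwiGLU op when available (falling back to SwiGLUFFN otherwise) and scales the hidden dim by 2/3 before rounding it up to a multiple of 8.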
+ +import os +from typing import Callable, Optional + +import torch.nn.functional as F +from torch import nn, Tensor + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (SwiGLU)") + else: + # warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + # warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/mapanything/models/external/dinov2/models/__init__.py b/mapanything/models/external/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87f48dae91bb29ef0d256487fda2a16652dd2a3d --- /dev/null +++ b/mapanything/models/external/dinov2/models/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
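+ +# Helpers that build DINOv2 vision transformer student/teacher models from a training config; build_model_from_cfg reads the student settings and the global crop size from the config.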
+ +import logging + +import mapanything.models.external.dinov2.models.vision_transformer as vits + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model( + cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size + ) diff --git a/mapanything/models/external/dinov2/models/vision_transformer.py b/mapanything/models/external/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..32f7032990f214946e95aaae43d4fd6a86b1c982 --- /dev/null +++ b/mapanything/models/external/dinov2/models/vision_transformer.py @@ -0,0 +1,448 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import math +from functools import partial +from typing import Callable, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.utils.checkpoint import checkpoint + +from mapanything.models.external.dinov2.layers import ( + MemEffAttention, + Mlp, + NestedTensorBlock as Block, + PatchEmbed, + SwiGLUFFNFused, +) +from mapanything.models.external.pi3.layers.attention import FlashAttention + +# logger = logging.getLogger("dinov2") + + +def named_apply( + fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + 
img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim) + ) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) + if num_register_tokens + else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + if ffn_layer == "mlp": + # logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + # logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + # logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + attn_class=FlashAttention, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append( + [nn.Identity()] * i + blocks_list[i : i + chunksize] + ) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm 
= norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( + previous_dtype + ) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where( + masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x + ) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [ + self.prepare_tokens_with_masks(x, masks) + for x, masks in zip(x_list, masks_list) + ] + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n 
is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), ( + f"only {len(output)} / {len(blocks_to_take)} blocks found" + ) + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), ( + f"only {len(output)} / {len(blocks_to_take)} blocks found" + ) + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + 
depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/mapanything/models/external/dinov2/utils/__init__.py b/mapanything/models/external/dinov2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40afb43678d1db842a67445d79260c338a1c1ab5 --- /dev/null +++ b/mapanything/models/external/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/mapanything/models/external/dinov2/utils/cluster.py b/mapanything/models/external/dinov2/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..502a7552612fe3f4146cec43f197c16be6c54365 --- /dev/null +++ b/mapanything/models/external/dinov2/utils/cluster.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type( + cluster_type: Optional[ClusterType] = None, +) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path( + cluster_type: Optional[ClusterType] = None, +) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, + num_gpus_per_node: int, + cluster_type: Optional[ClusterType] = None, + **kwargs, +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = 
get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/mapanything/models/external/dinov2/utils/config.py b/mapanything/models/external/dinov2/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..87cbfd5848a165b82f3756b7ad8fff71ed4d2366 --- /dev/null +++ b/mapanything/models/external/dinov2/utils/config.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import math +import os + +import dinov2.distributed as distributed +from dinov2.configs import dinov2_default_config +from dinov2.logging import setup_logging +from dinov2.utils import utils +from omegaconf import OmegaConf + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt( + cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0 + ) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info( + "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())) + ) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/mapanything/models/external/dinov2/utils/dtype.py b/mapanything/models/external/dinov2/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..5a91f3f084ba6e1c1db19708cddf7e4389a87c01 --- /dev/null +++ b/mapanything/models/external/dinov2/utils/dtype.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
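+# Editorial note: a minimal, illustrative sketch of the dtype helper defined below
+# (not part of the upstream DINOv2 file); the example values are assumptions:
+#     as_torch_dtype("float32")          # -> torch.float32
+#     as_torch_dtype(np.dtype("int64"))  # -> torch.int64
+#     as_torch_dtype(torch.float16)      # -> torch.float16 (returned unchanged)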
+ + +from typing import Dict, Union + +import numpy as np +import torch + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), ( + f"Expected an instance of numpy dtype, got {type(dtype)}" + ) + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/mapanything/models/external/dinov2/utils/param_groups.py b/mapanything/models/external/dinov2/utils/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..df3df118869b3e09214951f0bad4fed32ae79bc7 --- /dev/null +++ b/mapanything/models/external/dinov2/utils/param_groups.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +from collections import defaultdict + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate( + name, + lr_decay_rate=1.0, + num_layers=12, + force_is_backbone=False, + chunked_blocks=False, +): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual."
not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, + lr_decay_rate, + num_layers=n_blocks, + force_is_backbone=n_blocks > 0, + chunked_blocks=chunked_blocks, + ) + d = { + "params": param, + "is_last_layer": False, + "lr_multiplier": decay_rate, + "wd_multiplier": 1.0, + "name": name, + } + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info( + f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""" + ) + + return all_param_groups + + +def fuse_params_groups( + all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer") +): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git a/mapanything/models/external/dinov2/utils/utils.py b/mapanything/models/external/dinov2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6afe42656020daf57db783da8d76363b4b8e72a2 --- /dev/null +++ b/mapanything/models/external/dinov2/utils/utils.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
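+# Editorial note: an illustrative sketch of the iteration-indexed CosineScheduler
+# defined below (not part of the upstream DINOv2 file); the hyperparameter values
+# are assumptions for the example only:
+#     sched = CosineScheduler(base_value=1e-3, final_value=1e-5, total_iters=1000, warmup_iters=100)
+#     lr = sched[250]  # linear warmup over the first 100 iters, cosine decay afterwards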
+ +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + +# logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url( + pretrained_weights, map_location="cpu" + ) + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + # logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + _ = model.load_state_dict(state_dict, strict=False) + # logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__( + self, + base_value, + final_value, + total_iters, + warmup_iters=0, + start_warmup_value=0, + freeze_iters=0, + ): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * ( + 1 + np.cos(np.pi * iters / len(iters)) + ) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/mapanything/models/external/dust3r/__init__.py b/mapanything/models/external/dust3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c0bb39ddbcbeff5deefaefe69c570353e8e4bdb6 --- /dev/null +++ b/mapanything/models/external/dust3r/__init__.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
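+# Editorial note: a minimal usage sketch of the DUSt3RBAWrapper defined below (not
+# part of the upstream file). The checkpoint path and image tensors are placeholders;
+# each view holds a batch-size-1, DUSt3R-normalized image tensor:
+#     wrapper = DUSt3RBAWrapper(name="dust3r", ckpt_path="<path/to/dust3r_ckpt.pth>")
+#     views = [{"img": img0, "data_norm_type": ["dust3r"]},
+#              {"img": img1, "data_norm_type": ["dust3r"]}]
+#     preds = wrapper(views)  # per-view dicts with "pts3d", "depth_along_ray", etc.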
+ +""" +Inference wrapper for DUSt3R +""" + +import warnings + +import torch +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode +from dust3r.image_pairs import make_pairs +from dust3r.inference import inference +from dust3r.model import AsymmetricCroCo3DStereo # noqa + +from mapanything.models.external.vggt.utils.rotation import mat_to_quat +from mapanything.utils.geometry import ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + convert_z_depth_to_depth_along_ray, + depthmap_to_camera_frame, + get_rays_in_camera_frame, +) + +inf = float("inf") + + +def load_model(model_path, device, verbose=True): + if verbose: + print("Loading model from", model_path) + ckpt = torch.load(model_path, map_location="cpu", weights_only=False) + args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R") + if "landscape_only" not in args: + args = args[:-1] + ", landscape_only=False)" + else: + args = args.replace(" ", "").replace( + "landscape_only=True", "landscape_only=False" + ) + assert "landscape_only=False" in args + if verbose: + print(f"Instantiating: {args}") + try: + net = eval(args) + except NameError: + net = AsymmetricCroCo3DStereo( + enc_depth=24, + dec_depth=12, + enc_embed_dim=1024, + dec_embed_dim=768, + enc_num_heads=16, + dec_num_heads=12, + pos_embed="RoPE100", + patch_embed_cls="PatchEmbedDust3R", + img_size=(512, 512), + head_type="dpt", + output_mode="pts3d", + depth_mode=("exp", -inf, inf), + conf_mode=("exp", 1, inf), + landscape_only=False, + ) + s = net.load_state_dict(ckpt["model"], strict=False) + if verbose: + print(s) + return net.to(device) + + +class DUSt3RBAWrapper(torch.nn.Module): + def __init__( + self, + name, + ckpt_path, + scene_graph="complete", + inference_batch_size=32, + global_optim_schedule="cosine", + global_optim_lr=0.01, + global_optim_niter=300, + **kwargs, + ): + super().__init__() + self.name = name + self.ckpt_path = ckpt_path + self.scene_graph = scene_graph + self.inference_batch_size = inference_batch_size + self.global_optim_schedule = global_optim_schedule + self.global_optim_lr = global_optim_lr + self.global_optim_niter = global_optim_niter + + # Init the model and load the checkpoint + self.model = load_model(self.ckpt_path, device="cpu") + + # Init the global aligner mode + self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer + + def forward(self, views): + """ + Forward pass wrapper for DUSt3R using the global aligner. + + Assumption: + - The batch size of input views is 1. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys, where B is the batch size and is 1: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["dust3r"] + + Returns: + List[dict]: A list containing the final outputs for the input views. + """ + # Check the batch size of input views + batch_size_per_view, _, height, width = views[0]["img"].shape + device = views[0]["img"].device + num_views = len(views) + assert batch_size_per_view == 1, ( + f"Batch size of input views should be 1, but got {batch_size_per_view}." 
+ ) + + # Check the data norm type + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "dust3r", ( + "DUSt3R expects a normalized image with the DUSt3R normalization scheme applied" + ) + + # Convert the input views to the expected input format + images = [] + for view in views: + images.append( + dict( + img=view["img"], + idx=len(images), + instance=str(len(images)), + ) + ) + + # Make image pairs and run inference pair-wise + pairs = make_pairs( + images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + output = inference( + pairs, + self.model, + device, + batch_size=self.inference_batch_size, + verbose=False, + ) + + # Global optimization + with torch.enable_grad(): + scene = global_aligner( + output, device=device, mode=self.global_aligner_mode, verbose=False + ) + _ = scene.compute_global_alignment( + init="mst", + niter=self.global_optim_niter, + schedule=self.global_optim_schedule, + lr=self.global_optim_lr, + ) + + # Make sure scene is not None + if scene is None: + raise RuntimeError("Global optimization failed.") + + # Get the predictions + intrinsics = scene.get_intrinsics() + c2w_poses = scene.get_im_poses() + depths = scene.get_depthmaps() + + # Convert the output to the MapAnything format + with torch.autocast("cuda", enabled=False): + res = [] + for view_idx in range(num_views): + # Get the current view predictions + curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0) + curr_view_pose = c2w_poses[view_idx].unsqueeze(0) + curr_view_depth_z = depths[view_idx].unsqueeze(0) + + # Convert the pose to quaternions and translation + curr_view_cam_translations = curr_view_pose[..., :3, 3] + curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3]) + + # Get the camera frame pointmaps + curr_view_pts3d_cam, _ = depthmap_to_camera_frame( + curr_view_depth_z, curr_view_intrinsic + ) + + # Convert the z depth to depth along ray + curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray( + curr_view_depth_z, curr_view_intrinsic + ) + curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1) + + # Get the ray directions on the unit sphere in the camera frame + _, curr_view_ray_dirs = get_rays_in_camera_frame( + curr_view_intrinsic, height, width, normalize_to_unit_sphere=True + ) + + # Get the pointmaps + curr_view_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + curr_view_ray_dirs, + curr_view_depth_along_ray, + curr_view_cam_translations, + curr_view_cam_quats, + ) + ) + + # Append the outputs to the result list + res.append( + { + "pts3d": curr_view_pts3d, + "pts3d_cam": curr_view_pts3d_cam, + "ray_directions": curr_view_ray_dirs, + "depth_along_ray": curr_view_depth_along_ray, + "cam_trans": curr_view_cam_translations, + "cam_quats": curr_view_cam_quats, + } + ) + + return res diff --git a/mapanything/models/external/mast3r/__init__.py b/mapanything/models/external/mast3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b064255791b95ffaefa804f4c580cafe37d68300 --- /dev/null +++ b/mapanything/models/external/mast3r/__init__.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
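+# Editorial note: a minimal usage sketch of the MASt3RSGAWrapper defined below (not
+# part of the upstream file). Paths and tensors are placeholders; compared to the
+# DUSt3R wrapper it additionally needs a cache directory and per-view "label"/"instance" names:
+#     wrapper = MASt3RSGAWrapper(name="mast3r", ckpt_path="<ckpt.pth>", cache_dir="<cache_dir>")
+#     views = [{"img": img0, "data_norm_type": ["dust3r"], "label": ["scene"], "instance": ["000000.png"]},
+#              {"img": img1, "data_norm_type": ["dust3r"], "label": ["scene"], "instance": ["000001.png"]}]
+#     preds = wrapper(views)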
+ +""" +Inference wrapper for MASt3R + Sparse GA +""" + +import os +import tempfile +import warnings + +import torch +from dust3r.image_pairs import make_pairs +from mast3r.cloud_opt.sparse_ga import sparse_global_alignment +from mast3r.model import load_model + +from mapanything.models.external.vggt.utils.rotation import mat_to_quat +from mapanything.utils.geometry import ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + convert_z_depth_to_depth_along_ray, + depthmap_to_camera_frame, + get_rays_in_camera_frame, +) + + +class MASt3RSGAWrapper(torch.nn.Module): + def __init__( + self, + name, + ckpt_path, + cache_dir, + scene_graph="complete", + sparse_ga_lr1=0.07, + sparse_ga_niter1=300, + sparse_ga_lr2=0.01, + sparse_ga_niter2=300, + sparse_ga_optim_level="refine+depth", + sparse_ga_shared_intrinsics=False, + sparse_ga_matching_conf_thr=5.0, + **kwargs, + ): + super().__init__() + self.name = name + self.ckpt_path = ckpt_path + self.cache_dir = cache_dir + self.scene_graph = scene_graph + self.sparse_ga_lr1 = sparse_ga_lr1 + self.sparse_ga_niter1 = sparse_ga_niter1 + self.sparse_ga_lr2 = sparse_ga_lr2 + self.sparse_ga_niter2 = sparse_ga_niter2 + self.sparse_ga_optim_level = sparse_ga_optim_level + self.sparse_ga_shared_intrinsics = sparse_ga_shared_intrinsics + self.sparse_ga_matching_conf_thr = sparse_ga_matching_conf_thr + + # Init the model and load the checkpoint + self.model = load_model(self.ckpt_path, device="cpu") + + def forward(self, views): + """ + Forward pass wrapper for MASt3R using the sparse global aligner. + + Assumption: + - The batch size of input views is 1. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys, where B is the batch size and is 1: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["dust3r"] + "label" (list): ["scene_name"] + "instance" (list): ["image_name"] + + Returns: + List[dict]: A list containing the final outputs for the input views. + """ + # Check the batch size of input views + batch_size_per_view, _, height, width = views[0]["img"].shape + device = views[0]["img"].device + num_views = len(views) + assert batch_size_per_view == 1, ( + f"Batch size of input views should be 1, but got {batch_size_per_view}." 
+ ) + + # Check the data norm type + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "dust3r", ( + "MASt3R expects a normalized image with the DUSt3R normalization scheme applied" + ) + + # Convert the input views to the expected input format + images = [] + image_paths = [] + for view in views: + images.append( + dict( + img=view["img"].cpu(), + idx=len(images), + instance=str(len(images)), + true_shape=torch.tensor(view["img"].shape[-2:])[None] + .repeat(batch_size_per_view, 1) + .numpy(), + ) + ) + view_name = os.path.join(view["label"][0], view["instance"][0]) + image_paths.append(view_name) + + # Make image pairs and run inference + # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation) + pairs = make_pairs( + images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True + ) + with torch.enable_grad(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + tempfile.mkdtemp(dir=self.cache_dir) + scene = sparse_global_alignment( + image_paths, + pairs, + self.cache_dir, + self.model, + lr1=self.sparse_ga_lr1, + niter1=self.sparse_ga_niter1, + lr2=self.sparse_ga_lr2, + niter2=self.sparse_ga_niter2, + device=device, + opt_depth="depth" in self.sparse_ga_optim_level, + shared_intrinsics=self.sparse_ga_shared_intrinsics, + matching_conf_thr=self.sparse_ga_matching_conf_thr, + verbose=False, + ) + + # Make sure scene is not None + if scene is None: + raise RuntimeError("Global optimization failed.") + + # Get the predictions + intrinsics = scene.intrinsics + c2w_poses = scene.get_im_poses() + _, depths, _ = scene.get_dense_pts3d() + + # Convert the output to the MapAnything format + with torch.autocast("cuda", enabled=False): + res = [] + for view_idx in range(num_views): + # Get the current view predictions + curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0) + curr_view_pose = c2w_poses[view_idx].unsqueeze(0) + curr_view_depth_z = ( + depths[view_idx].reshape((height, width)).unsqueeze(0) + ) + + # Convert the pose to quaternions and translation + curr_view_cam_translations = curr_view_pose[..., :3, 3] + curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3]) + + # Get the camera frame pointmaps + curr_view_pts3d_cam, _ = depthmap_to_camera_frame( + curr_view_depth_z, curr_view_intrinsic + ) + + # Convert the z depth to depth along ray + curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray( + curr_view_depth_z, curr_view_intrinsic + ) + curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1) + + # Get the ray directions on the unit sphere in the camera frame + _, curr_view_ray_dirs = get_rays_in_camera_frame( + curr_view_intrinsic, height, width, normalize_to_unit_sphere=True + ) + + # Get the pointmaps + curr_view_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + curr_view_ray_dirs, + curr_view_depth_along_ray, + curr_view_cam_translations, + curr_view_cam_quats, + ) + ) + + # Append the outputs to the result list + res.append( + { + "pts3d": curr_view_pts3d, + "pts3d_cam": curr_view_pts3d_cam, + "ray_directions": curr_view_ray_dirs, + "depth_along_ray": curr_view_depth_along_ray, + "cam_trans": curr_view_cam_translations, + "cam_quats": curr_view_cam_quats, + } + ) + + return res diff --git a/mapanything/models/external/moge/__init__.py b/mapanything/models/external/moge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b3de1d11134902026514dca26320a666f8a802 --- /dev/null +++ 
b/mapanything/models/external/moge/__init__.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Inference wrapper for MoGe +""" + +import torch + +from mapanything.models.external.moge.models.v1 import MoGeModel as MoGeModelV1 +from mapanything.models.external.moge.models.v2 import MoGeModel as MoGeModelV2 + + +class MoGeWrapper(torch.nn.Module): + def __init__( + self, + name, + model_string="Ruicheng/moge-2-vitl", + torch_hub_force_reload=False, + load_custom_ckpt=False, + custom_ckpt_path=None, + ): + super().__init__() + self.name = name + self.model_string = model_string + self.torch_hub_force_reload = torch_hub_force_reload + self.load_custom_ckpt = load_custom_ckpt + self.custom_ckpt_path = custom_ckpt_path + + # Mapping of MoGe model version to checkpoint strings + self.moge_model_map = { + "v1": ["Ruicheng/moge-vitl"], + "v2": [ + "Ruicheng/moge-2-vits-normal", + "Ruicheng/moge-2-vitb-normal", + "Ruicheng/moge-2-vitl-normal", + "Ruicheng/moge-2-vitl", + ], + } + + # Initialize the model + if self.model_string in self.moge_model_map["v1"]: + self.model = MoGeModelV1.from_pretrained(self.model_string) + elif self.model_string in self.moge_model_map["v2"]: + self.model = MoGeModelV2.from_pretrained(self.model_string) + else: + raise ValueError( + f"Invalid model string: {self.model_string}. Valid strings are: {self.moge_model_map}" + ) + + # Load custom checkpoint if requested + if self.load_custom_ckpt: + print(f"Loading checkpoint from {self.custom_ckpt_path} ...") + assert self.custom_ckpt_path is not None, ( + "custom_ckpt_path must be provided if load_custom_ckpt is set to True" + ) + custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False) + print(self.model.load_state_dict(custom_ckpt, strict=True)) + del custom_ckpt # in case it occupies memory + + def forward(self, views): + """ + Forward pass wrapper for MoGe-2. + The predicted MoGe-2 mask is not applied to the outputs. + The number of tokens for inference is determined by the image shape. + + Assumption: + - The number of input views is 1. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Length of the list should be 1. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["identity"] + + Returns: + List[dict]: A list containing the final outputs for the single view. Length of the list will be 1. + """ + # Check that the number of input views is 1 + assert len(views) == 1, "MoGe only supports 1 input view." 
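+ # Editorial note: an illustrative call pattern (not part of the upstream file);
+ # `moge_wrapper` and `img` (a (1, 3, H, W) tensor in [0, 1]) are placeholders:
+ #     preds = moge_wrapper([{"img": img, "data_norm_type": ["identity"]}])
+ #     depth = preds[0]["depth_z"]  # z-depth with a trailing channel dimension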
+ + # Get input shape of the images, number of tokens for inference, and batch size per view + _, _, height, width = views[0]["img"].shape + num_tokens = int(height // 14) * int(width // 14) + + # Check the data norm type + # MoGe expects a normalized image but without the DINOv2 mean and std applied ("identity") + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "identity", ( + "MoGe expects a normalized image but without the DINOv2 mean and std applied" + ) + + # Run MoGe inference + # Output dict contains: "points", "depth", "mask", "intrinsics", "normal" (based on model config) + model_outputs = self.model.infer( + image=views[0]["img"], num_tokens=num_tokens, apply_mask=False + ) + + # Get the ray directions and depth along ray + with torch.autocast("cuda", enabled=False): + depth_along_ray = torch.norm(model_outputs["points"], dim=-1, keepdim=True) + ray_directions = model_outputs["points"] / depth_along_ray + + # Convert the output to MapAnything format + result_dict = { + "pts3d": model_outputs["points"], + "pts3d_cam": model_outputs["points"], + "depth_z": model_outputs["depth"].unsqueeze(-1), + "intrinsics": model_outputs["intrinsics"], + "non_ambiguous_mask": model_outputs["mask"], + "ray_directions": ray_directions, + "depth_along_ray": depth_along_ray, + } + res = [result_dict] + + return res diff --git a/mapanything/models/external/moge/models/modules.py b/mapanything/models/external/moge/models/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..7a55ea3a93d6a554c0923b6796db3ce489ad73dd --- /dev/null +++ b/mapanything/models/external/moge/models/modules.py @@ -0,0 +1,493 @@ +import functools +import importlib +import itertools +from typing import List, Literal, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mapanything.models.external.dinov2.models.vision_transformer import ( + DinoVisionTransformer, +) +from mapanything.models.external.moge.models.utils import ( + wrap_dinov2_attention_with_sdpa, + wrap_module_with_gradient_checkpointing, +) + + +class ResidualConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int = None, + hidden_channels: int = None, + kernel_size: int = 3, + padding_mode: str = "replicate", + activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu", + in_norm: Literal[ + "group_norm", "layer_norm", "instance_norm", "none" + ] = "layer_norm", + hidden_norm: Literal[ + "group_norm", "layer_norm", "instance_norm" + ] = "group_norm", + ): + super(ResidualConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + + if activation == "relu": + activation_cls = nn.ReLU + elif activation == "leaky_relu": + activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2) + elif activation == "silu": + activation_cls = nn.SiLU + elif activation == "elu": + activation_cls = nn.ELU + else: + raise ValueError(f"Unsupported activation function: {activation}") + + self.layers = nn.Sequential( + ( + nn.GroupNorm(in_channels // 32, in_channels) + if in_norm == "group_norm" + else ( + nn.GroupNorm(1, in_channels) + if in_norm == "layer_norm" + else ( + nn.InstanceNorm2d(in_channels) + if in_norm == "instance_norm" + else nn.Identity() + ) + ) + ), + activation_cls(), + nn.Conv2d( + in_channels, + hidden_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + padding_mode=padding_mode, + ), + ( + 
nn.GroupNorm(hidden_channels // 32, hidden_channels) + if hidden_norm == "group_norm" + else ( + nn.GroupNorm(1, hidden_channels) + if hidden_norm == "layer_norm" + else ( + nn.InstanceNorm2d(hidden_channels) + if hidden_norm == "instance_norm" + else nn.Identity() + ) + ) + ), + activation_cls(), + nn.Conv2d( + hidden_channels, + out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + padding_mode=padding_mode, + ), + ) + + self.skip_connection = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x): + skip = self.skip_connection(x) + x = self.layers(x) + x = x + skip + return x + + +class DINOv2Encoder(nn.Module): + "Wrapped DINOv2 encoder supporting gradient checkpointing. Input is RGB image in range [0, 1]." + + backbone: DinoVisionTransformer + image_mean: torch.Tensor + image_std: torch.Tensor + dim_features: int + + def __init__( + self, + backbone: str, + intermediate_layers: Union[int, List[int]], + dim_out: int, + **deprecated_kwargs, + ): + super(DINOv2Encoder, self).__init__() + + self.intermediate_layers = intermediate_layers + + # Load the backbone + self.hub_loader = getattr( + importlib.import_module( + "mapanything.models.external.dinov2.hub.backbones", __package__ + ), + backbone, + ) + self.backbone_name = backbone + self.backbone = self.hub_loader(pretrained=False) + + self.dim_features = self.backbone.blocks[0].attn.qkv.in_features + self.num_features = ( + intermediate_layers + if isinstance(intermediate_layers, int) + else len(intermediate_layers) + ) + + self.output_projections = nn.ModuleList( + [ + nn.Conv2d( + in_channels=self.dim_features, + out_channels=dim_out, + kernel_size=1, + stride=1, + padding=0, + ) + for _ in range(self.num_features) + ] + ) + + self.register_buffer( + "image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + ) + self.register_buffer( + "image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + ) + + @property + def onnx_compatible_mode(self): + return getattr(self, "_onnx_compatible_mode", False) + + @onnx_compatible_mode.setter + def onnx_compatible_mode(self, value: bool): + self._onnx_compatible_mode = value + self.backbone.onnx_compatible_mode = value + + def init_weights(self): + pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict() + self.backbone.load_state_dict(pretrained_backbone_state_dict) + + def enable_gradient_checkpointing(self): + for i in range(len(self.backbone.blocks)): + wrap_module_with_gradient_checkpointing(self.backbone.blocks[i]) + + def enable_pytorch_native_sdpa(self): + for i in range(len(self.backbone.blocks)): + wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn) + + def forward( + self, + image: torch.Tensor, + token_rows: Union[int, torch.LongTensor], + token_cols: Union[int, torch.LongTensor], + return_class_token: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + image_14 = F.interpolate( + image, + (token_rows * 14, token_cols * 14), + mode="bilinear", + align_corners=False, + antialias=not self.onnx_compatible_mode, + ) + image_14 = (image_14 - self.image_mean) / self.image_std + + # Get intermediate layers from the backbone + features = self.backbone.get_intermediate_layers( + image_14, n=self.intermediate_layers, return_class_token=True + ) + + # Project features to the desired dimensionality + x = torch.stack( + [ + proj( + feat.permute(0, 2, 1) + .unflatten(2, (token_rows, token_cols)) + .contiguous() + ) + for proj, (feat, 
clstoken) in zip(self.output_projections, features) + ], + dim=1, + ).sum(dim=1) + + if return_class_token: + return x, features[-1][1] + else: + return x + + +class Resampler(nn.Sequential): + def __init__( + self, + in_channels: int, + out_channels: int, + type_: Literal[ + "pixel_shuffle", + "nearest", + "bilinear", + "conv_transpose", + "pixel_unshuffle", + "avg_pool", + "max_pool", + ], + scale_factor: int = 2, + ): + if type_ == "pixel_shuffle": + nn.Sequential.__init__( + self, + nn.Conv2d( + in_channels, + out_channels * (scale_factor**2), + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + nn.PixelShuffle(scale_factor), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + ) + for i in range(1, scale_factor**2): + self[0].weight.data[i :: scale_factor**2] = self[0].weight.data[ + 0 :: scale_factor**2 + ] + self[0].bias.data[i :: scale_factor**2] = self[0].bias.data[ + 0 :: scale_factor**2 + ] + elif type_ in ["nearest", "bilinear"]: + nn.Sequential.__init__( + self, + nn.Upsample( + scale_factor=scale_factor, + mode=type_, + align_corners=False if type_ == "bilinear" else None, + ), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + ) + elif type_ == "conv_transpose": + nn.Sequential.__init__( + self, + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=scale_factor, + stride=scale_factor, + ), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + ) + self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1] + elif type_ == "pixel_unshuffle": + nn.Sequential.__init__( + self, + nn.PixelUnshuffle(scale_factor), + nn.Conv2d( + in_channels * (scale_factor**2), + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + ) + elif type_ == "avg_pool": + nn.Sequential.__init__( + self, + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor), + ) + elif type_ == "max_pool": + nn.Sequential.__init__( + self, + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor), + ) + else: + raise ValueError(f"Unsupported resampler type: {type_}") + + +class MLP(nn.Sequential): + def __init__(self, dims: Sequence[int]): + nn.Sequential.__init__( + self, + *itertools.chain( + *[ + (nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True)) + for dim_in, dim_out in zip(dims[:-2], dims[1:-1]) + ] + ), + nn.Linear(dims[-2], dims[-1]), + ) + + +class ConvStack(nn.Module): + def __init__( + self, + dim_in: List[Optional[int]], + dim_res_blocks: List[int], + dim_out: List[Optional[int]], + resamplers: Union[ + Literal[ + "pixel_shuffle", + "nearest", + "bilinear", + "conv_transpose", + "pixel_unshuffle", + "avg_pool", + "max_pool", + ], + List, + ], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + res_block_in_norm: Literal[ + "layer_norm", "group_norm", "instance_norm", "none" + ] = "layer_norm", + res_block_hidden_norm: Literal[ + "layer_norm", "group_norm", "instance_norm", "none" + ] = "group_norm", + activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu", + ): + super().__init__() + self.input_blocks = nn.ModuleList( + [ + ( + nn.Conv2d( + dim_in_, dim_res_block_, 
kernel_size=1, stride=1, padding=0 + ) + if dim_in_ is not None + else nn.Identity() + ) + for dim_in_, dim_res_block_ in zip( + ( + dim_in + if isinstance(dim_in, Sequence) + else itertools.repeat(dim_in) + ), + dim_res_blocks, + ) + ] + ) + self.resamplers = nn.ModuleList( + [ + Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler) + for i, (dim_prev, dim_succ, resampler) in enumerate( + zip( + dim_res_blocks[:-1], + dim_res_blocks[1:], + ( + resamplers + if isinstance(resamplers, Sequence) + else itertools.repeat(resamplers) + ), + ) + ) + ] + ) + self.res_blocks = nn.ModuleList( + [ + nn.Sequential( + *( + ResidualConvBlock( + dim_res_block_, + dim_res_block_, + dim_times_res_block_hidden * dim_res_block_, + activation=activation, + in_norm=res_block_in_norm, + hidden_norm=res_block_hidden_norm, + ) + for _ in range( + num_res_blocks[i] + if isinstance(num_res_blocks, list) + else num_res_blocks + ) + ) + ) + for i, dim_res_block_ in enumerate(dim_res_blocks) + ] + ) + self.output_blocks = nn.ModuleList( + [ + ( + nn.Conv2d( + dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0 + ) + if dim_out_ is not None + else nn.Identity() + ) + for dim_out_, dim_res_block_ in zip( + ( + dim_out + if isinstance(dim_out, Sequence) + else itertools.repeat(dim_out) + ), + dim_res_blocks, + ) + ] + ) + + def enable_gradient_checkpointing(self): + for i in range(len(self.resamplers)): + self.resamplers[i] = wrap_module_with_gradient_checkpointing( + self.resamplers[i] + ) + for i in range(len(self.res_blocks)): + for j in range(len(self.res_blocks[i])): + self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing( + self.res_blocks[i][j] + ) + + def forward(self, in_features: List[torch.Tensor]): + out_features = [] + for i in range(len(self.res_blocks)): + feature = self.input_blocks[i](in_features[i]) + if i == 0: + x = feature + elif feature is not None: + x = x + feature + x = self.res_blocks[i](x) + out_features.append(self.output_blocks[i](x)) + if i < len(self.res_blocks) - 1: + x = self.resamplers[i](x) + return out_features diff --git a/mapanything/models/external/moge/models/utils.py b/mapanything/models/external/moge/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9075226ebc4bf05ce614af287e13f0a9fa84b059 --- /dev/null +++ b/mapanything/models/external/moge/models/utils.py @@ -0,0 +1,477 @@ +import inspect +from functools import partial, wraps +from numbers import Number +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def wrap_module_with_gradient_checkpointing(module: nn.Module): + from torch.utils.checkpoint import checkpoint + + class _CheckpointingWrapper(module.__class__): + _restore_cls = module.__class__ + + def forward(self, *args, **kwargs): + return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) + + module.__class__ = _CheckpointingWrapper + return module + + +def unwrap_module_with_gradient_checkpointing(module: nn.Module): + module.__class__ = module.__class__._restore_cls + + +def wrap_dinov2_attention_with_sdpa(module: nn.Module): + assert torch.__version__ >= "2.0", "SDPA requires PyTorch 2.0 or later" + + class _AttentionWrapper(module.__class__): + def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) # (3, B, H, N, C // H) + + q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // 
H) + + x = F.scaled_dot_product_attention(q, k, v, attn_bias) + x = x.permute(0, 2, 1, 3).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + module.__class__ = _AttentionWrapper + return module + + +def sync_ddp_hook( + state, bucket: torch.distributed.GradBucket +) -> torch.futures.Future[torch.Tensor]: + group_to_use = torch.distributed.group.WORLD + world_size = group_to_use.size() + grad = bucket.buffer() + grad.div_(world_size) + torch.distributed.all_reduce(grad, group=group_to_use) + fut = torch.futures.Future() + fut.set_result(grad) + return fut + + +def normalized_view_plane_uv( + width: int, + height: int, + aspect_ratio: float = None, + dtype: torch.dtype = None, + device: torch.device = None, +) -> torch.Tensor: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio**2) ** 0.5 + span_y = 1 / (1 + aspect_ratio**2) ** 0.5 + + u = torch.linspace( + -span_x * (width - 1) / width, + span_x * (width - 1) / width, + width, + dtype=dtype, + device=device, + ) + v = torch.linspace( + -span_y * (height - 1) / height, + span_y * (height - 1) / height, + height, + dtype=dtype, + device=device, + ) + u, v = torch.meshgrid(u, v, indexing="xy") + uv = torch.stack([u, v], dim=-1) + return uv + + +def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray): + "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal" + from scipy.optimize import least_squares + + uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) + + def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): + xy_proj = xy / (z + shift)[:, None] + f = (xy_proj * uv).sum() / np.square(xy_proj).sum() + err = (f * xy_proj - uv).ravel() + return err + + solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method="lm") + optim_shift = solution["x"].squeeze().astype(np.float32) + + xy_proj = xy / (z + optim_shift)[:, None] + optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum() + + return optim_shift, optim_focal + + +def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float): + "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift" + from scipy.optimize import least_squares + + uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) + + def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): + xy_proj = xy / (z + shift)[:, None] + err = (focal * xy_proj - uv).ravel() + return err + + solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method="lm") + optim_shift = solution["x"].squeeze().astype(np.float32) + + return optim_shift + + +def recover_focal_shift( + points: torch.Tensor, + mask: torch.Tensor = None, + focal: torch.Tensor = None, + downsample_size: Tuple[int, int] = (64, 64), +): + """ + Recover the depth map and FoV from a point map with unknown z shift and focal. + + Note that it assumes: + - the optical center is at the center of the map + - the map is undistorted + - the map is isometric in the x and y directions + + ### Parameters: + - `points: torch.Tensor` of shape (..., H, W, 3) + - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps. + + ### Returns: + - `focal`: torch.Tensor of shape (...) 
the estimated focal length, relative to the half diagonal of the map + - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space + """ + shape = points.shape + height, width = points.shape[-3], points.shape[-2] + + points = points.reshape(-1, *shape[-3:]) + mask = None if mask is None else mask.reshape(-1, *shape[-3:-1]) + focal = focal.reshape(-1) if focal is not None else None + uv = normalized_view_plane_uv( + width, height, dtype=points.dtype, device=points.device + ) # (H, W, 2) + + points_lr = F.interpolate( + points.permute(0, 3, 1, 2), downsample_size, mode="nearest" + ).permute(0, 2, 3, 1) + uv_lr = ( + F.interpolate( + uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode="nearest" + ) + .squeeze(0) + .permute(1, 2, 0) + ) + mask_lr = ( + None + if mask is None + else F.interpolate( + mask.to(torch.float32).unsqueeze(1), downsample_size, mode="nearest" + ).squeeze(1) + > 0 + ) + + uv_lr_np = uv_lr.cpu().numpy() + points_lr_np = points_lr.detach().cpu().numpy() + focal_np = focal.cpu().numpy() if focal is not None else None + mask_lr_np = None if mask is None else mask_lr.cpu().numpy() + optim_shift, optim_focal = [], [] + for i in range(points.shape[0]): + points_lr_i_np = ( + points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]] + ) + uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]] + if uv_lr_i_np.shape[0] < 2: + optim_focal.append(1) + optim_shift.append(0) + continue + if focal is None: + optim_shift_i, optim_focal_i = solve_optimal_focal_shift( + uv_lr_i_np, points_lr_i_np + ) + optim_focal.append(float(optim_focal_i)) + else: + optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i]) + optim_shift.append(float(optim_shift_i)) + optim_shift = torch.tensor( + optim_shift, device=points.device, dtype=points.dtype + ).reshape(shape[:-3]) + + if focal is None: + optim_focal = torch.tensor( + optim_focal, device=points.device, dtype=points.dtype + ).reshape(shape[:-3]) + else: + optim_focal = focal.reshape(shape[:-3]) + + return optim_focal, optim_shift + + +def suppress_traceback(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + e.__traceback__ = e.__traceback__.tb_next.tb_next + raise + + return wrapper + + +def get_device(args, kwargs): + device = None + for arg in list(args) + list(kwargs.values()): + if isinstance(arg, torch.Tensor): + if device is None: + device = arg.device + elif device != arg.device: + raise ValueError("All tensors must be on the same device.") + return device + + +def get_args_order(func, args, kwargs): + """ + Get the order of the arguments of a function. 
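+
+ Returns (args_order, kwargs_order): signature indices of the positional arguments
+ and a {keyword name: signature index} mapping, as consumed by the `batched`
+ decorator below. Illustrative example (assumed function `f`):
+     def f(a, b, c): ...
+     get_args_order(f, (1,), {"c": 3})  # -> ([0], {"c": 2})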
+ """ + names = inspect.getfullargspec(func).args + names_idx = {name: i for i, name in enumerate(names)} + args_order = [] + kwargs_order = {} + for name, arg in kwargs.items(): + if name in names: + kwargs_order[name] = names_idx[name] + names.remove(name) + for i, arg in enumerate(args): + if i < len(names): + args_order.append(names_idx[names[i]]) + return args_order, kwargs_order + + +def broadcast_args(args, kwargs, args_dim, kwargs_dim): + spatial = [] + for arg, arg_dim in zip( + args + list(kwargs.values()), args_dim + list(kwargs_dim.values()) + ): + if isinstance(arg, torch.Tensor) and arg_dim is not None: + arg_spatial = arg.shape[: arg.ndim - arg_dim] + if len(arg_spatial) > len(spatial): + spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial + for j in range(len(arg_spatial)): + if spatial[-j] < arg_spatial[-j]: + if spatial[-j] == 1: + spatial[-j] = arg_spatial[-j] + else: + raise ValueError("Cannot broadcast arguments.") + for i, arg in enumerate(args): + if isinstance(arg, torch.Tensor) and args_dim[i] is not None: + args[i] = torch.broadcast_to( + arg, [*spatial, *arg.shape[arg.ndim - args_dim[i] :]] + ) + for key, arg in kwargs.items(): + if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None: + kwargs[key] = torch.broadcast_to( + arg, [*spatial, *arg.shape[arg.ndim - kwargs_dim[key] :]] + ) + return args, kwargs, spatial + + +@suppress_traceback +def batched(*dims): + """ + Decorator that allows a function to be called with batched arguments. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, device=torch.device("cpu"), **kwargs): + args = list(args) + # get arguments dimensions + args_order, kwargs_order = get_args_order(func, args, kwargs) + args_dim = [dims[i] for i in args_order] + kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()} + # convert to torch tensor + device = get_device(args, kwargs) or device + for i, arg in enumerate(args): + if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None: + args[i] = torch.tensor(arg, device=device) + for key, arg in kwargs.items(): + if ( + isinstance(arg, (Number, list, tuple)) + and kwargs_dim[key] is not None + ): + kwargs[key] = torch.tensor(arg, device=device) + # broadcast arguments + args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim) + for i, (arg, arg_dim) in enumerate(zip(args, args_dim)): + if isinstance(arg, torch.Tensor) and arg_dim is not None: + args[i] = arg.reshape([-1, *arg.shape[arg.ndim - arg_dim :]]) + for key, arg in kwargs.items(): + if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None: + kwargs[key] = arg.reshape( + [-1, *arg.shape[arg.ndim - kwargs_dim[key] :]] + ) + # call function + results = func(*args, **kwargs) + type_results = type(results) + results = list(results) if isinstance(results, (tuple, list)) else [results] + # restore spatial dimensions + for i, result in enumerate(results): + results[i] = result.reshape([*spatial, *result.shape[1:]]) + if type_results is tuple: + results = tuple(results) + elif type_results is list: + results = list(results) + else: + results = results[0] + return results + + return wrapper + + return decorator + + +def image_uv( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + device: torch.device = None, + dtype: torch.dtype = None, +) -> torch.Tensor: + """ + Get image space UV grid, ranging in [0, 1]. 
+ + >>> image_uv(10, 10): + [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]], + [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]], + ... ... ... + [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]] + + Args: + width (int): image width + height (int): image height + + Returns: + torch.Tensor: shape (height, width, 2) + """ + if left is None: + left = 0 + if top is None: + top = 0 + if right is None: + right = width + if bottom is None: + bottom = height + u = torch.linspace( + (left + 0.5) / width, + (right - 0.5) / width, + right - left, + device=device, + dtype=dtype, + ) + v = torch.linspace( + (top + 0.5) / height, + (bottom - 0.5) / height, + bottom - top, + device=device, + dtype=dtype, + ) + u, v = torch.meshgrid(u, v, indexing="xy") + uv = torch.stack([u, v], dim=-1) + + return uv + + +@batched(2, 1, 2, 2) +def unproject_cv( + uv_coord: torch.Tensor, + depth: torch.Tensor = None, + extrinsics: torch.Tensor = None, + intrinsics: torch.Tensor = None, +) -> torch.Tensor: + """ + Unproject uv coordinates to 3D view space following the OpenCV convention + + Args: + uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + depth (torch.Tensor): [..., N] depth value + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix + + Returns: + points (torch.Tensor): [..., N, 3] 3d points + """ + assert intrinsics is not None, "intrinsics matrix is required" + points = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1) + points = points @ torch.inverse(intrinsics).transpose(-2, -1) + if depth is not None: + points = points * depth[..., None] + if extrinsics is not None: + points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + points = (points @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3] + return points + + +def depth_to_points( + depth: torch.Tensor, intrinsics: torch.Tensor, extrinsics: torch.Tensor = None +): + height, width = depth.shape[-2:] + uv = image_uv(width=width, height=height, dtype=depth.dtype, device=depth.device) + pts = unproject_cv( + uv, + depth, + intrinsics=intrinsics[..., None, :, :], + extrinsics=extrinsics[..., None, :, :] if extrinsics is not None else None, + ) + + return pts + + +@batched(0, 0, 0, 0, 0, 0) +def intrinsics_from_focal_center( + fx: Union[float, torch.Tensor], + fy: Union[float, torch.Tensor], + cx: Union[float, torch.Tensor], + cy: Union[float, torch.Tensor], +) -> torch.Tensor: + """ + Get OpenCV intrinsics matrix + + Args: + focal_x (float | torch.Tensor): focal length in x axis + focal_y (float | torch.Tensor): focal length in y axis + cx (float | torch.Tensor): principal point in x axis + cy (float | torch.Tensor): principal point in y axis + + Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + """ + N = fx.shape[0] + ret = torch.zeros((N, 3, 3), dtype=fx.dtype, device=fx.device) + zeros, ones = ( + torch.zeros(N, dtype=fx.dtype, device=fx.device), + torch.ones(N, dtype=fx.dtype, device=fx.device), + ) + ret = torch.stack( + [fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1 + ).unflatten(-1, (3, 3)) + return ret diff --git a/mapanything/models/external/moge/models/v1.py b/mapanything/models/external/moge/models/v1.py new file mode 100644 index 0000000000000000000000000000000000000000..3f2f0d1cee59a94d80f9c5a07bc10e40de36a42f --- /dev/null +++ b/mapanything/models/external/moge/models/v1.py @@ -0,0 +1,597 @@ +import importlib +from numbers import Number 
+from pathlib import Path +from typing import Any, Dict, IO, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.version +from huggingface_hub import hf_hub_download + +from mapanything.models.external.moge.models.utils import ( + depth_to_points, + intrinsics_from_focal_center, + normalized_view_plane_uv, + recover_focal_shift, + wrap_module_with_gradient_checkpointing, +) + + +class ResidualConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int = None, + hidden_channels: int = None, + padding_mode: str = "replicate", + activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu", + norm: Literal["group_norm", "layer_norm"] = "group_norm", + ): + super(ResidualConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + + if activation == "relu": + activation_cls = lambda: nn.ReLU(inplace=True) # noqa + elif activation == "leaky_relu": + activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True) # noqa + elif activation == "silu": + activation_cls = lambda: nn.SiLU(inplace=True) # noqa + elif activation == "elu": + activation_cls = lambda: nn.ELU(inplace=True) # noqa + else: + raise ValueError(f"Unsupported activation function: {activation}") + + self.layers = nn.Sequential( + nn.GroupNorm(1, in_channels), + activation_cls(), + nn.Conv2d( + in_channels, + hidden_channels, + kernel_size=3, + padding=1, + padding_mode=padding_mode, + ), + nn.GroupNorm( + hidden_channels // 32 if norm == "group_norm" else 1, hidden_channels + ), + activation_cls(), + nn.Conv2d( + hidden_channels, + out_channels, + kernel_size=3, + padding=1, + padding_mode=padding_mode, + ), + ) + + self.skip_connection = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x): + skip = self.skip_connection(x) + x = self.layers(x) + x = x + skip + return x + + +class Head(nn.Module): + def __init__( + self, + num_features: int, + dim_in: int, + dim_out: List[int], + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm", + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1, + ): + super().__init__() + + self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=dim_in, + out_channels=dim_proj, + kernel_size=1, + stride=1, + padding=0, + ) + for _ in range(num_features) + ] + ) + + self.upsample_blocks = nn.ModuleList( + [ + nn.Sequential( + self._make_upsampler(in_ch + 2, out_ch), + *( + ResidualConvBlock( + out_ch, + out_ch, + dim_times_res_block_hidden * out_ch, + activation="relu", + norm=res_block_norm, + ) + for _ in range(num_res_blocks) + ), + ) + for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample) + ] + ) + + self.output_block = nn.ModuleList( + [ + self._make_output_block( + dim_upsample[-1] + 2, + dim_out_, + dim_times_res_block_hidden, + last_res_blocks, + last_conv_channels, + last_conv_size, + res_block_norm, + ) + for dim_out_ in dim_out + ] + ) + + def _make_upsampler(self, in_channels: int, out_channels: int): + upsampler = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), + nn.Conv2d( + out_channels, + out_channels, + 
kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + ) + upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1] + return upsampler + + def _make_output_block( + self, + dim_in: int, + dim_out: int, + dim_times_res_block_hidden: int, + last_res_blocks: int, + last_conv_channels: int, + last_conv_size: int, + res_block_norm: Literal["group_norm", "layer_norm"], + ): + return nn.Sequential( + nn.Conv2d( + dim_in, + last_conv_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + *( + ResidualConvBlock( + last_conv_channels, + last_conv_channels, + dim_times_res_block_hidden * last_conv_channels, + activation="relu", + norm=res_block_norm, + ) + for _ in range(last_res_blocks) + ), + nn.ReLU(inplace=True), + nn.Conv2d( + last_conv_channels, + dim_out, + kernel_size=last_conv_size, + stride=1, + padding=last_conv_size // 2, + padding_mode="replicate", + ), + ) + + def forward(self, hidden_states: torch.Tensor, image: torch.Tensor): + img_h, img_w = image.shape[-2:] + patch_h, patch_w = img_h // 14, img_w // 14 + + # Process the hidden states + x = torch.stack( + [ + proj( + feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous() + ) + for proj, (feat, clstoken) in zip(self.projects, hidden_states) + ], + dim=1, + ).sum(dim=1) + + # Upsample stage + # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8) + for i, block in enumerate(self.upsample_blocks): + # UV coordinates is for awareness of image aspect ratio + uv = normalized_view_plane_uv( + width=x.shape[-1], + height=x.shape[-2], + aspect_ratio=img_w / img_h, + dtype=x.dtype, + device=x.device, + ) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + for layer in block: + x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False) + + # (patch_h * 8, patch_w * 8) -> (img_h, img_w) + x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False) + uv = normalized_view_plane_uv( + width=x.shape[-1], + height=x.shape[-2], + aspect_ratio=img_w / img_h, + dtype=x.dtype, + device=x.device, + ) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + + if isinstance(self.output_block, nn.ModuleList): + output = [ + torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) + for block in self.output_block + ] + else: + output = torch.utils.checkpoint.checkpoint( + self.output_block, x, use_reentrant=False + ) + + return output + + +class MoGeModel(nn.Module): + image_mean: torch.Tensor + image_std: torch.Tensor + + def __init__( + self, + encoder: str = "dinov2_vitb14", + intermediate_layers: Union[int, List[int]] = 4, + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + remap_output: Literal[ + False, True, "linear", "sinh", "exp", "sinh_exp" + ] = "linear", + res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm", + num_tokens_range: Tuple[Number, Number] = [1200, 2500], + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1, + mask_threshold: float = 0.5, + **deprecated_kwargs, + ): + super(MoGeModel, self).__init__() + + if deprecated_kwargs: + # Process legacy arguments + if "trained_area_range" in deprecated_kwargs: + num_tokens_range = [ + deprecated_kwargs["trained_area_range"][0] // 14**2, + deprecated_kwargs["trained_area_range"][1] // 14**2, + ] + del 
deprecated_kwargs["trained_area_range"] + # warnings.warn( + # f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}" + # ) + + self.encoder = encoder + self.remap_output = remap_output + self.intermediate_layers = intermediate_layers + self.num_tokens_range = num_tokens_range + self.mask_threshold = mask_threshold + + # NOTE: We have copied the DINOv2 code in torchhub to this repository. + # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues. + hub_loader = getattr( + importlib.import_module( + "mapanything.models.external.dinov2.hub.backbones", __package__ + ), + encoder, + ) + self.backbone = hub_loader(pretrained=False) + dim_feature = self.backbone.blocks[0].attn.qkv.in_features + + self.head = Head( + num_features=( + intermediate_layers + if isinstance(intermediate_layers, int) + else len(intermediate_layers) + ), + dim_in=dim_feature, + dim_out=[3, 1], + dim_proj=dim_proj, + dim_upsample=dim_upsample, + dim_times_res_block_hidden=dim_times_res_block_hidden, + num_res_blocks=num_res_blocks, + res_block_norm=res_block_norm, + last_res_blocks=last_res_blocks, + last_conv_channels=last_conv_channels, + last_conv_size=last_conv_size, + ) + + image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + + self.register_buffer("image_mean", image_mean) + self.register_buffer("image_std", image_std) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, Path, IO[bytes]], + model_kwargs: Optional[Dict[str, Any]] = None, + **hf_kwargs, + ) -> "MoGeModel": + """ + Load a model from a checkpoint file. + + ### Parameters: + - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. + - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. + - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. + + ### Returns: + - A new instance of `MoGe` with the parameters loaded from the checkpoint. 
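+
+        ### Example (illustrative sketch; the repo id and input convention are assumptions):
+        >>> model = MoGeModel.from_pretrained("Ruicheng/moge-vitl")
+        >>> output = model.infer(image)  # image: (3, H, W) RGB tensor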
+ """ + if Path(pretrained_model_name_or_path).exists(): + checkpoint = torch.load( + pretrained_model_name_or_path, map_location="cpu", weights_only=True + ) + else: + cached_checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.pt", + **hf_kwargs, + ) + checkpoint = torch.load( + cached_checkpoint_path, map_location="cpu", weights_only=True + ) + model_config = checkpoint["model_config"] + if model_kwargs is not None: + model_config.update(model_kwargs) + model = cls(**model_config) + model.load_state_dict(checkpoint["model"]) + return model + + def init_weights(self): + "Load the backbone with pretrained dinov2 weights from torch hub" + state_dict = torch.hub.load( + "facebookresearch/dinov2", self.encoder, pretrained=True + ).state_dict() + self.backbone.load_state_dict(state_dict) + + def enable_gradient_checkpointing(self): + for i in range(len(self.backbone.blocks)): + self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing( + self.backbone.blocks[i] + ) + + def _remap_points(self, points: torch.Tensor) -> torch.Tensor: + if self.remap_output == "linear": + pass + elif self.remap_output == "sinh": + points = torch.sinh(points) + elif self.remap_output == "exp": + xy, z = points.split([2, 1], dim=-1) + z = torch.exp(z) + points = torch.cat([xy * z, z], dim=-1) + elif self.remap_output == "sinh_exp": + xy, z = points.split([2, 1], dim=-1) + points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) + else: + raise ValueError(f"Invalid remap output type: {self.remap_output}") + return points + + def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]: + original_height, original_width = image.shape[-2:] + + # Resize to expected resolution defined by num_tokens + resize_factor = ( + (num_tokens * 14**2) / (original_height * original_width) + ) ** 0.5 + resized_width, resized_height = ( + int(original_width * resize_factor), + int(original_height * resize_factor), + ) + image = F.interpolate( + image, + (resized_height, resized_width), + mode="bicubic", + align_corners=False, + antialias=True, + ) + + # Apply image transformation for DINOv2 + image = (image - self.image_mean) / self.image_std + image_14 = F.interpolate( + image, + (resized_height // 14 * 14, resized_width // 14 * 14), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + # Get intermediate layers from the backbone + features = self.backbone.get_intermediate_layers( + image_14, self.intermediate_layers, return_class_token=True + ) + + # Predict points (and mask) + output = self.head(features, image) + points, mask = output + + # Make sure fp32 precision for output + with torch.autocast(device_type=image.device.type, dtype=torch.float32): + # Resize to original resolution + points = F.interpolate( + points, + (original_height, original_width), + mode="bilinear", + align_corners=False, + antialias=False, + ) + mask = F.interpolate( + mask, + (original_height, original_width), + mode="bilinear", + align_corners=False, + antialias=False, + ) + + # Post-process points and mask + points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1) + points = self._remap_points( + points + ) # slightly improves the performance in case of very large output values + + return_dict = {"points": points, "mask": mask} + return return_dict + + # @torch.inference_mode() + def infer( + self, + image: torch.Tensor, + fov_x: Union[Number, torch.Tensor] = None, + resolution_level: int = 9, + num_tokens: int = None, + apply_mask: bool = 
True, + force_projection: bool = True, + use_fp16: bool = True, + ) -> Dict[str, torch.Tensor]: + """ + User-friendly inference function + + ### Parameters + - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)\ + - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None + - `resolution_level`: An integer [0-9] for the resolution level for inference. + The higher, the finer details will be captured, but slower. Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size. + `resolution_level` actually controls `num_tokens`. See `num_tokens` for more details. + - `num_tokens`: number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`. + `resolution_level` will be ignored if `num_tokens` is provided. Default: None + - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True + - `force_projection`: if True, the output point map will be recomputed to match the projection constraint. Default: True + - `use_fp16`: if True, use mixed precision to speed up inference. Default: True + + ### Returns + + A dictionary containing the following keys: + - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3). + - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map. + - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics. + """ + if image.dim() == 3: + omit_batch_dim = True + image = image.unsqueeze(0) + else: + omit_batch_dim = False + image = image.to(dtype=self.dtype, device=self.device) + + original_height, original_width = image.shape[-2:] + aspect_ratio = original_width / original_height + + if num_tokens is None: + min_tokens, max_tokens = self.num_tokens_range + num_tokens = int( + min_tokens + (resolution_level / 9) * (max_tokens - min_tokens) + ) + + with torch.autocast( + device_type=self.device.type, + dtype=torch.float16, + enabled=use_fp16 and self.dtype != torch.float16, + ): + output = self.forward(image, num_tokens) + points, mask = output["points"], output["mask"] + + # Always process the output in fp32 precision + with torch.autocast(device_type=self.device.type, dtype=torch.float32): + points, mask, fov_x = map( + lambda x: x.float() if isinstance(x, torch.Tensor) else x, + [points, mask, fov_x], + ) + + mask_binary = mask > self.mask_threshold + + # Get camera-space point map. 
(Focal here is the focal length relative to half the image diagonal) + if fov_x is None: + focal, shift = recover_focal_shift(points, mask_binary) + else: + focal = ( + aspect_ratio + / (1 + aspect_ratio**2) ** 0.5 + / torch.tan( + torch.deg2rad( + torch.as_tensor( + fov_x, device=points.device, dtype=points.dtype + ) + / 2 + ) + ) + ) + if focal.ndim == 0: + focal = focal[None].expand(points.shape[0]) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + fx = focal / 2 * (1 + aspect_ratio**2) ** 0.5 / aspect_ratio + fy = focal / 2 * (1 + aspect_ratio**2) ** 0.5 + intrinsics = intrinsics_from_focal_center(fx, fy, 0.5, 0.5) + depth = points[..., 2] + shift[..., None, None] + + # If projection constraint is forced, recompute the point map using the actual depth map + if force_projection: + points = depth_to_points(depth, intrinsics=intrinsics) + else: + points = ( + points + + torch.stack( + [torch.zeros_like(shift), torch.zeros_like(shift), shift], + dim=-1, + )[..., None, None, :] + ) + + # Apply mask if needed + if apply_mask: + points = torch.where(mask_binary[..., None], points, torch.inf) + depth = torch.where(mask_binary, depth, torch.inf) + + return_dict = { + "points": points, + "intrinsics": intrinsics, + "depth": depth, + "mask": mask_binary, + } + + if omit_batch_dim: + return_dict = {k: v.squeeze(0) for k, v in return_dict.items()} + + return return_dict diff --git a/mapanything/models/external/moge/models/v2.py b/mapanything/models/external/moge/models/v2.py new file mode 100644 index 0000000000000000000000000000000000000000..521295e4ea208a73b0f2ea57e95d87fb90972aa6 --- /dev/null +++ b/mapanything/models/external/moge/models/v2.py @@ -0,0 +1,385 @@ +import warnings +from numbers import Number +from pathlib import Path +from typing import Any, Dict, IO, List, Literal, Optional, Union + +import torch +import torch.amp +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.version +from huggingface_hub import hf_hub_download + +from mapanything.models.external.moge.models.modules import ( + ConvStack, + DINOv2Encoder, + MLP, +) +from mapanything.models.external.moge.models.utils import ( + depth_to_points, + intrinsics_from_focal_center, + normalized_view_plane_uv, + recover_focal_shift, +) + + +class MoGeModel(nn.Module): + encoder: DINOv2Encoder + neck: ConvStack + points_head: ConvStack + mask_head: ConvStack + scale_head: MLP + onnx_compatible_mode: bool + + def __init__( + self, + encoder: Dict[str, Any], + neck: Dict[str, Any], + points_head: Dict[str, Any] = None, + mask_head: Dict[str, Any] = None, + normal_head: Dict[str, Any] = None, + scale_head: Dict[str, Any] = None, + remap_output: Literal["linear", "sinh", "exp", "sinh_exp"] = "linear", + num_tokens_range: List[int] = [1200, 3600], + **deprecated_kwargs, + ): + super(MoGeModel, self).__init__() + if deprecated_kwargs: + warnings.warn( + f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}" + ) + + self.remap_output = remap_output + self.num_tokens_range = num_tokens_range + + self.encoder = DINOv2Encoder(**encoder) + self.neck = ConvStack(**neck) + if points_head is not None: + self.points_head = ConvStack(**points_head) + if mask_head is not None: + self.mask_head = ConvStack(**mask_head) + if normal_head is not None: + self.normal_head = ConvStack(**normal_head) + if scale_head is not None: + self.scale_head = MLP(**scale_head) + + @property + def device(self) -> torch.device: + return 
next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @property + def onnx_compatible_mode(self) -> bool: + return getattr(self, "_onnx_compatible_mode", False) + + @onnx_compatible_mode.setter + def onnx_compatible_mode(self, value: bool): + self._onnx_compatible_mode = value + self.encoder.onnx_compatible_mode = value + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, Path, IO[bytes]], + model_kwargs: Optional[Dict[str, Any]] = None, + **hf_kwargs, + ) -> "MoGeModel": + """ + Load a model from a checkpoint file. + + ### Parameters: + - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. + - `compiled` + - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. + - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. + + ### Returns: + - A new instance of `MoGe` with the parameters loaded from the checkpoint. + """ + if Path(pretrained_model_name_or_path).exists(): + checkpoint_path = pretrained_model_name_or_path + else: + checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.pt", + **hf_kwargs, + ) + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + + model_config = checkpoint["model_config"] + if model_kwargs is not None: + model_config.update(model_kwargs) + model = cls(**model_config) + model.load_state_dict(checkpoint["model"], strict=False) + + return model + + def init_weights(self): + self.encoder.init_weights() + + def enable_gradient_checkpointing(self): + self.encoder.enable_gradient_checkpointing() + self.neck.enable_gradient_checkpointing() + for head in ["points_head", "normal_head", "mask_head"]: + if hasattr(self, head): + getattr(self, head).enable_gradient_checkpointing() + + def enable_pytorch_native_sdpa(self): + self.encoder.enable_pytorch_native_sdpa() + + def _remap_points(self, points: torch.Tensor) -> torch.Tensor: + if self.remap_output == "linear": + pass + elif self.remap_output == "sinh": + points = torch.sinh(points) + elif self.remap_output == "exp": + xy, z = points.split([2, 1], dim=-1) + z = torch.exp(z) + points = torch.cat([xy * z, z], dim=-1) + elif self.remap_output == "sinh_exp": + xy, z = points.split([2, 1], dim=-1) + points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) + else: + raise ValueError(f"Invalid remap output type: {self.remap_output}") + return points + + def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]: + batch_size, _, img_h, img_w = image.shape + device, dtype = image.device, image.dtype + + aspect_ratio = img_w / img_h + base_h, base_w = ( + int((num_tokens / aspect_ratio) ** 0.5), + int((num_tokens * aspect_ratio) ** 0.5), + ) + num_tokens = base_h * base_w + + # Backbones encoding + features, cls_token = self.encoder( + image, base_h, base_w, return_class_token=True + ) + features = [features, None, None, None, None] + + # Concat UVs for aspect ratio input + for level in range(5): + uv = normalized_view_plane_uv( + width=base_w * 2**level, + height=base_h * 2**level, + aspect_ratio=aspect_ratio, + dtype=dtype, + device=device, + ) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1) + if features[level] is None: + features[level] = uv + else: + features[level] = torch.concat([features[level], uv], dim=1) + + # Shared neck + 
features = self.neck(features) + + # Heads decoding + points, normal, mask = ( + getattr(self, head)(features)[-1] if hasattr(self, head) else None + for head in ["points_head", "normal_head", "mask_head"] + ) + metric_scale = ( + self.scale_head(cls_token) if hasattr(self, "scale_head") else None + ) + + # Resize + points, normal, mask = ( + ( + F.interpolate( + v, + (img_h, img_w), + mode="bilinear", + align_corners=False, + antialias=False, + ) + if v is not None + else None + ) + for v in [points, normal, mask] + ) + + # Remap output + if points is not None: + points = points.permute(0, 2, 3, 1) + points = self._remap_points( + points + ) # slightly improves the performance in case of very large output values + if normal is not None: + normal = normal.permute(0, 2, 3, 1) + normal = F.normalize(normal, dim=-1) + if mask is not None: + mask = mask.squeeze(1).sigmoid() + if metric_scale is not None: + metric_scale = metric_scale.squeeze(1).exp() + + return_dict = { + "points": points, + "normal": normal, + "mask": mask, + "metric_scale": metric_scale, + } + return_dict = {k: v for k, v in return_dict.items() if v is not None} + + return return_dict + + # @torch.inference_mode() + def infer( + self, + image: torch.Tensor, + num_tokens: int = None, + resolution_level: int = 9, + force_projection: bool = True, + apply_mask: Literal[False, True, "blend"] = True, + fov_x: Optional[Union[Number, torch.Tensor]] = None, + use_fp16: bool = True, + ) -> Dict[str, torch.Tensor]: + """ + User-friendly inference function + + ### Parameters + - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W) + - `num_tokens`: the number of base ViT tokens to use for inference, `'least'` or `'most'` or an integer. Suggested range: 1200 ~ 2500. + More tokens will result in significantly higher accuracy and finer details, but slower inference time. Default: `'most'`. + - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True + - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True + - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None + - `use_fp16`: if True, use mixed precision to speed up inference. Default: True + + ### Returns + + A dictionary containing the following keys: + - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3). + - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map. + - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics. 
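+        - `normal`: tensor of shape (B, H, W, 3) or (H, W, 3) containing the normal map, present only if the model has a normal head.
+        - `mask`: tensor of shape (B, H, W) or (H, W) containing the binary validity mask, present only if the model has a mask head.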
+ """ + if image.dim() == 3: + omit_batch_dim = True + image = image.unsqueeze(0) + else: + omit_batch_dim = False + image = image.to(dtype=self.dtype, device=self.device) + + original_height, original_width = image.shape[-2:] + aspect_ratio = original_width / original_height + + # Determine the number of base tokens to use + if num_tokens is None: + min_tokens, max_tokens = self.num_tokens_range + num_tokens = int( + min_tokens + (resolution_level / 9) * (max_tokens - min_tokens) + ) + + # Forward pass + with torch.autocast( + device_type=self.device.type, + dtype=torch.float16, + enabled=use_fp16 and self.dtype != torch.float16, + ): + output = self.forward(image, num_tokens=num_tokens) + points, normal, mask, metric_scale = ( + output.get(k, None) for k in ["points", "normal", "mask", "metric_scale"] + ) + + # Always process the output in fp32 precision + points, normal, mask, metric_scale, fov_x = map( + lambda x: x.float() if isinstance(x, torch.Tensor) else x, + [points, normal, mask, metric_scale, fov_x], + ) + with torch.autocast(device_type=self.device.type, dtype=torch.float32): + if mask is not None: + mask_binary = mask > 0.5 + else: + mask_binary = None + + if points is not None: + # Convert affine point map to camera-space. Recover depth and intrinsics from point map. + # NOTE: Focal here is the focal length relative to half the image diagonal + if fov_x is None: + # Recover focal and shift from predicted point map + focal, shift = recover_focal_shift(points, mask_binary) + else: + # Focal is known, recover shift only + focal = ( + aspect_ratio + / (1 + aspect_ratio**2) ** 0.5 + / torch.tan( + torch.deg2rad( + torch.as_tensor( + fov_x, device=points.device, dtype=points.dtype + ) + / 2 + ) + ) + ) + if focal.ndim == 0: + focal = focal[None].expand(points.shape[0]) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + fx, fy = ( + focal / 2 * (1 + aspect_ratio**2) ** 0.5 / aspect_ratio, + focal / 2 * (1 + aspect_ratio**2) ** 0.5, + ) + intrinsics = intrinsics_from_focal_center(fx, fy, 0.5, 0.5) + points[..., 2] += shift[..., None, None] + if mask_binary is not None: + mask_binary &= ( + points[..., 2] > 0 + ) # in case depth is contains negative values (which should never happen in practice) + depth = points[..., 2].clone() + else: + depth, intrinsics = None, None + + # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics + if force_projection and depth is not None: + points = depth_to_points(depth, intrinsics=intrinsics) + + # Apply metric scale + if metric_scale is not None: + if points is not None: + points *= metric_scale[:, None, None, None] + if depth is not None: + depth *= metric_scale[:, None, None] + + # Apply mask + if apply_mask and mask_binary is not None: + points = ( + torch.where(mask_binary[..., None], points, torch.inf) + if points is not None + else None + ) + depth = ( + torch.where(mask_binary, depth, torch.inf) + if depth is not None + else None + ) + normal = ( + torch.where( + mask_binary[..., None], normal, torch.zeros_like(normal) + ) + if normal is not None + else None + ) + + return_dict = { + "points": points, + "intrinsics": intrinsics, + "depth": depth, + "mask": mask_binary, + "normal": normal, + } + return_dict = {k: v for k, v in return_dict.items() if v is not None} + + if omit_batch_dim: + return_dict = {k: v.squeeze(0) for k, v in return_dict.items()} + + return return_dict diff --git a/mapanything/models/external/must3r/__init__.py 
b/mapanything/models/external/must3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..022036a34727844f4886b382740747b7e509dc6e --- /dev/null +++ b/mapanything/models/external/must3r/__init__.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Inference wrapper for MUSt3R +""" + +import datetime +import os + +import numpy as np +import torch +from dust3r.viz import rgb +from must3r.demo.inference import SceneState +from must3r.engine.inference import inference_multi_ar, postprocess +from must3r.model import get_pointmaps_activation, load_model + +from mapanything.models.external.vggt.utils.rotation import mat_to_quat + + +def must3r_inference( + views, + filelist, + model, + retrieval, + device, + amp, + num_mem_images, + max_bs, + init_num_images=2, + batch_num_views=1, + render_once=False, + is_sequence=False, + viser_server=None, + num_refinements_iterations=2, + verbose=True, +): + if amp == "fp16": + dtype = torch.float16 + elif amp == "bf16": + assert torch.cuda.is_bf16_supported() + dtype = torch.bfloat16 + else: + assert not amp + dtype = torch.float32 + + max_bs = None if max_bs == 0 else max_bs + encoder, decoder = model + pointmaps_activation = get_pointmaps_activation(decoder, verbose=verbose) + + def post_process_function(x): + return postprocess( + x, pointmaps_activation=pointmaps_activation, compute_cam=True + ) + + if verbose: + print("loading images") + time_start = datetime.datetime.now() + nimgs = len(views) + + ellapsed = datetime.datetime.now() - time_start + if verbose: + print(f"loaded in {ellapsed}") + print("running inference") + time_start = datetime.datetime.now() + if viser_server is not None: + viser_server.reset(nimgs) + + imgs = [b["img"].to("cpu") for b in views] + true_shape = [torch.from_numpy(b["true_shape"]).to("cpu") for b in views] + true_shape = torch.stack(true_shape, dim=0) + nimgs = true_shape.shape[0] + + # Use all images as keyframes + keyframes = np.linspace(0, len(imgs) - 1, num_mem_images, dtype=int).tolist() + encoder_precomputed_features = None + + not_keyframes = sorted(set(range(nimgs)).difference(set(keyframes))) + assert (len(keyframes) + len(not_keyframes)) == nimgs + # reorder images + views = [views[i] for i in keyframes] + [views[i] for i in not_keyframes] + imgs = [b["img"].to(device) for b in views] + true_shape = [torch.from_numpy(b["true_shape"]).to(device) for b in views] + filenames = [filelist[i] for i in keyframes + not_keyframes] + img_ids = [torch.tensor(v) for v in keyframes + not_keyframes] + + if encoder_precomputed_features is not None: + x_start, pos_start = encoder_precomputed_features + x = [x_start[i] for i in keyframes] + [x_start[i] for i in not_keyframes] + pos = [pos_start[i] for i in keyframes] + [pos_start[i] for i in not_keyframes] + encoder_precomputed_features = (x, pos) + + mem_batches = [init_num_images] + while (sum_b := sum(mem_batches)) != max(num_mem_images, init_num_images): + size_b = min(batch_num_views, num_mem_images - sum_b) + mem_batches.append(size_b) + + if render_once: + to_render = list(range(num_mem_images, nimgs)) + else: + to_render = None + + with torch.autocast("cuda", dtype=dtype): + x_out_0, x_out = inference_multi_ar( + encoder, + decoder, + imgs, + img_ids, + true_shape, + mem_batches, + max_bs=max_bs, + verbose=verbose, + to_render=to_render, + 
encoder_precomputed_features=encoder_precomputed_features, + device=device, + preserve_gpu_mem=True, + post_process_function=post_process_function, + viser_server=viser_server, + num_refinements_iterations=num_refinements_iterations, + ) + if to_render is not None: + x_out = x_out_0 + x_out + + ellapsed = datetime.datetime.now() - time_start + if verbose: + print(f"inference in {ellapsed}") + try: + print(str(int(torch.cuda.max_memory_reserved(device) / (1024**2))) + " MB") + except Exception: + pass + + if viser_server is not None: + viser_server.reset_cam_visility() + viser_server.send_message("Finished") + + if verbose: + print("preparing pointcloud") + time_start = datetime.datetime.now() + focals = [] + cams2world = [] + for i in range(nimgs): + focals.append(float(x_out[i]["focal"].cpu())) + cams2world.append(x_out[i]["c2w"].cpu()) + + # x_out to cpu + for i in range(len(x_out)): + for k in x_out[i].keys(): + x_out[i][k] = x_out[i][k].cpu() + + rgbimg = [rgb(imgs[i], true_shape[i]) for i in range(nimgs)] + scene = SceneState(x_out, rgbimg, true_shape, focals, cams2world, filenames) + + ellapsed = datetime.datetime.now() - time_start + if verbose: + print(f"pointcloud prepared in {ellapsed}") + + return scene + + +class MUSt3RWrapper(torch.nn.Module): + def __init__( + self, + name, + ckpt_path, + retrieval_ckpt_path, + img_size=512, + amp="bf16", + max_bs=1, + **kwargs, + ): + super().__init__() + self.name = name + self.ckpt_path = ckpt_path + self.retrieval_ckpt_path = retrieval_ckpt_path + self.amp = amp + self.max_bs = max_bs + + # Init the model and load the checkpoint + self.model = load_model(self.ckpt_path, img_size=512) + + def forward(self, views): + """ + Forward pass wrapper for MUSt3R. + + Assumption: + - The batch size of input views is 1. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys, where B is the batch size and is 1: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["dust3r"] + "label" (list): ["scene_name"] + "instance" (list): ["image_name"] + + Returns: + List[dict]: A list containing the final outputs for the input views. + """ + # Check the batch size of input views + batch_size_per_view, _, height, width = views[0]["img"].shape + device = views[0]["img"].device + num_views = len(views) + assert batch_size_per_view == 1, ( + f"Batch size of input views should be 1, but got {batch_size_per_view}." 
+ ) + + # Check the data norm type + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "dust3r", ( + "MUSt3R expects a normalized image with the DUSt3R normalization scheme applied" + ) + + # Convert the input views to the expected input format + images = [] + image_paths = [] + for view in views: + images.append( + dict( + img=view["img"][0].cpu(), + idx=len(images), + instance=str(len(images)), + true_shape=np.int32([view["img"].shape[-2], view["img"].shape[-1]]), + ) + ) + view_name = os.path.join(view["label"][0], view["instance"][0]) + image_paths.append(view_name) + + # Run MUSt3R inference + scene = must3r_inference( + images, + image_paths, + self.model, + self.retrieval_ckpt_path, + device, + self.amp, + num_views, + self.max_bs, + verbose=False, + ) + + # Make sure scene is not None + if scene is None: + raise RuntimeError("MUSt3R failed.") + + # Get the predictions + predictions = scene.x_out + + # Convert the output to the MapAnything format + with torch.autocast("cuda", enabled=False): + res = [] + for view_idx in range(num_views): + # Get the current view predictions + curr_view_prediction = predictions[view_idx] + curr_view_conf = curr_view_prediction["conf"] + curr_view_pose = curr_view_prediction["c2w"].unsqueeze(0) + + # Convert the pose to quaternions and translation + curr_view_cam_translations = curr_view_pose[..., :3, 3] + curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3]) + + # Get the camera frame pointmaps + curr_view_pts3d_cam = curr_view_prediction["pts3d_local"].unsqueeze(0) + + # Get the depth along ray and ray directions + curr_view_depth_along_ray = torch.norm( + curr_view_pts3d_cam, dim=-1, keepdim=True + ) + curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray + + # Get the pointmaps + curr_view_pts3d = curr_view_prediction["pts3d"].unsqueeze(0) + + # Append the outputs to the result list + res.append( + { + "pts3d": curr_view_pts3d.to(device), + "pts3d_cam": curr_view_pts3d_cam.to(device), + "ray_directions": curr_view_ray_dirs.to(device), + "depth_along_ray": curr_view_depth_along_ray.to(device), + "cam_trans": curr_view_cam_translations.to(device), + "cam_quats": curr_view_cam_quats.to(device), + "conf": curr_view_conf.to(device), + } + ) + + return res diff --git a/mapanything/models/external/pi3/__init__.py b/mapanything/models/external/pi3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3814640daf72b98e699f70a250215cd34039dadf --- /dev/null +++ b/mapanything/models/external/pi3/__init__.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +""" +Inference wrapper for Pi3 +""" + +import torch + +from mapanything.models.external.pi3.models.pi3 import Pi3 +from mapanything.models.external.vggt.utils.rotation import mat_to_quat + + +class Pi3Wrapper(torch.nn.Module): + def __init__( + self, + name, + torch_hub_force_reload, + load_pretrained_weights=True, + pos_type="rope100", + decoder_size="large", + ): + super().__init__() + self.name = name + self.torch_hub_force_reload = torch_hub_force_reload + + if load_pretrained_weights: + # Load pre-trained weights + if not torch_hub_force_reload: + # Initialize the Pi3 model from huggingface hub cache + print("Loading Pi3 from huggingface cache ...") + self.model = Pi3.from_pretrained( + "yyfz233/Pi3", + ) + else: + # Initialize the Pi3 model + self.model = Pi3.from_pretrained("yyfz233/Pi3", force_download=True) + else: + # Load the Pi3 class + self.model = Pi3( + pos_type=pos_type, + decoder_size=decoder_size, + ) + + # Get the dtype for Pi3 inference + # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) + self.dtype = ( + torch.bfloat16 + if torch.cuda.get_device_capability()[0] >= 8 + else torch.float16 + ) + + def forward(self, views): + """ + Forward pass wrapper for Pi3 + + Assumption: + - All the input views have the same image shape. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["identity"] + + Returns: + List[dict]: A list containing the final outputs for all N views. + """ + # Get input shape of the images, number of views, and batch size per view + batch_size_per_view, _, height, width = views[0]["img"].shape + num_views = len(views) + + # Check the data norm type + # Pi3 expects a normalized image but without the DINOv2 mean and std applied ("identity") + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "identity", ( + "Pi3 expects a normalized image but without the DINOv2 mean and std applied" + ) + + # Concatenate the images to create a single (B, V, C, H, W) tensor + img_list = [view["img"] for view in views] + images = torch.stack(img_list, dim=1) + + # Run the Pi3 aggregator + with torch.autocast("cuda", dtype=self.dtype): + results = self.model(images) + + # Need high precision for transformations + with torch.autocast("cuda", enabled=False): + # Convert the output to MapAnything format + res = [] + for view_idx in range(num_views): + # Get the extrinsics + curr_view_extrinsic = results["camera_poses"][:, view_idx, ...] + curr_view_cam_translations = curr_view_extrinsic[..., :3, 3] + curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3]) + + # Get the depth along ray, ray directions, local point cloud & global point cloud + curr_view_pts3d_cam = results["local_points"][:, view_idx, ...] + curr_view_depth_along_ray = torch.norm( + curr_view_pts3d_cam, dim=-1, keepdim=True + ) + curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray + curr_view_pts3d = results["points"][:, view_idx, ...] + + # Get the confidence + curr_view_confidence = results["conf"][:, view_idx, ...] 
+ + # Append the outputs to the result list + res.append( + { + "pts3d": curr_view_pts3d, + "pts3d_cam": curr_view_pts3d_cam, + "ray_directions": curr_view_ray_dirs, + "depth_along_ray": curr_view_depth_along_ray, + "cam_trans": curr_view_cam_translations, + "cam_quats": curr_view_cam_quats, + "conf": curr_view_confidence, + } + ) + + return res diff --git a/mapanything/models/external/pi3/layers/__init__.py b/mapanything/models/external/pi3/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/models/external/pi3/layers/attention.py b/mapanything/models/external/pi3/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..28ca9413f3aa61bf55c88c29e734b230844f6bb6 --- /dev/null +++ b/mapanything/models/external/pi3/layers/attention.py @@ -0,0 +1,429 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + + +import os + +import torch +from torch import nn, Tensor +from torch.nn.attention import SDPBackend +from torch.nn.functional import scaled_dot_product_attention + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:, :, i] for i in range(3)] + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class FlashAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 
3, self.num_heads, C // self.num_heads) + .transpose(1, 3) + ) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:, :, i] for i in range(3)] + + if q.dtype == torch.bfloat16: + with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + x = scaled_dot_product_attention(q, k, v) + else: + with nn.attention.sdpa_kernel( + [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION] + ): + x = scaled_dot_product_attention(q, k, v) + + x = x.transpose(1, 2).reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +""" +Following is written by GPT-4o +""" + + +class CrossAttentionRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + qk_norm: bool = False, + norm_layer: nn.Module = nn.LayerNorm, + rope=None, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + # Separate projection layers for query, key, and value + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) + + self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.rope = rope + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias=None, + qpos=None, + kpos=None, + ) -> Tensor: + """ + Args: + query: Tensor of shape (B, N, C), input query + key: Tensor of shape (B, M, C), input key + value: Tensor of shape (B, M, C), input value + attn_bias: Optional tensor for attention bias + Returns: + Tensor of shape (B, N, C), output of cross-attention + """ + B, N, C = query.shape + _, M, _ = key.shape + + # Project query, key, and value + q = ( + self.q_proj(query) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + k = ( + self.k_proj(key) + .reshape(B, M, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + v = ( + self.v_proj(value) + .reshape(B, M, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, qpos) + k = self.rope(k, kpos) + + # Scale query + q = q * self.scale + + # Compute attention scores + attn = q @ k.transpose(-2, -1) # (B, num_heads, N, M) + if attn_bias is not None: + attn = attn + attn_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # Compute attention output + x = (attn @ v).transpose(1, 2).reshape(B, N, C) # (B, N, C) + + # Final projection + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffCrossAttentionRope(CrossAttentionRope): + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias=None, + qpos=None, + kpos=None, + ) -> Tensor: + """ + Args: + query: Tensor of shape (B, N, C), input query + key: Tensor of shape (B, M, C), input key + value: Tensor of shape (B, M, C), input value + attn_bias: Optional tensor for attention bias + Returns: + Tensor of shape (B, N, C), output of cross-attention + """ + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(query, key, value, attn_bias) + + B, N, C = query.shape + _, M, _ = key.shape + + # 
Project query, key, and value + q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads) + k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads) + v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, qpos) + k = self.rope(k, kpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + # Compute memory-efficient attention + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape(B, N, C) + + # Final projection + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class AttentionRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + qk_norm: bool = False, + norm_layer: nn.Module = nn.LayerNorm, + rope=None, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + + self.rope = rope + + def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttentionRope(AttentionRope): + def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + qkv = qkv.transpose(1, 3) + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:, :, i] for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(frame_num, 261, frame_num, 261).mean(dim=[1, 3]).sum(1) # for frame attention matrix + # global_valid_id = torch.where(score_matrix > 0) + # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class FlashAttentionRope(AttentionRope): + def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .transpose(1, 3) + ) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:, :, i] 
for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + if q.dtype == torch.bfloat16: + with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + x = scaled_dot_product_attention(q, k, v) + else: + with nn.attention.sdpa_kernel( + [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION] + ): + x = scaled_dot_product_attention(q, k, v) + + x = x.transpose(1, 2).reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def get_attn_score(blk_class, x, frame_num, token_length, xpos=None): + x = blk_class.norm1(x) + + B, N, C = x.shape + qkv = blk_class.attn.qkv(x).reshape( + B, N, 3, blk_class.attn.num_heads, C // blk_class.attn.num_heads + ) + + qkv = qkv.transpose(1, 3) + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:, :, i] for i in range(3)] + q, k = blk_class.attn.q_norm(q).to(v.dtype), blk_class.attn.k_norm(k).to(v.dtype) + + if blk_class.attn.rope is not None: + q = blk_class.attn.rope(q, xpos) + k = blk_class.attn.rope(k, xpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + score = ( + ( + q.permute(0, 2, 1, 3) + * blk_class.attn.scale + @ k.permute(0, 2, 1, 3).transpose(-2, -1) + ) + .sum(dim=1) + .reshape(B, frame_num, token_length, frame_num, token_length) + .mean(dim=[2, 4]) + .sum(-1) + ) + + return score diff --git a/mapanything/models/external/pi3/layers/block.py b/mapanything/models/external/pi3/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..1eac717874a6b10dee909c88c16faa8e23059525 --- /dev/null +++ b/mapanything/models/external/pi3/layers/block.py @@ -0,0 +1,448 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
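+
+# Editor's note: the Block / BlockRope classes below implement the standard
+# pre-norm transformer residual structure. A minimal sketch of the inference
+# path (names mirror the attributes defined in Block.__init__; nothing new
+# is assumed):
+#
+#   x = x + ls1(attn(norm1(x)))   # self-attention residual
+#   x = x + ls2(mlp(norm2(x)))    # feed-forward residual
+#
+# During training with drop_path > 0.1, the same two residuals go through
+# drop_add_residual_stochastic_depth, which evaluates the residual on a
+# random subset of the batch and rescales it by batch_size / subset_size.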
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import os +from typing import Any, Callable, Dict, List, Tuple + +import torch +from torch import nn, Tensor + +from mapanything.models.external.dinov2.layers.drop_path import DropPath +from mapanything.models.external.dinov2.layers.layer_scale import LayerScale +from mapanything.models.external.dinov2.layers.mlp import Mlp +from mapanything.models.external.pi3.layers.attention import ( + Attention, + CrossAttentionRope, + MemEffAttention, +) + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, index_select_cat, scaled_index_add + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - 
sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + else: + x_plus_residual = scaled_index_add( + x, + brange, + residual.to(dtype=x.dtype), + scaling=scaling_vector, + alpha=residual_scale_factor, + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + 1, -1, x_list[0].shape[-1] + ) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + 
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma + if isinstance(self.ls1, LayerScale) + else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma + if isinstance(self.ls1, LayerScale) + else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError + + +class BlockRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + rope=None, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + rope=rope, + ) + + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, xpos=None) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x), xpos=xpos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + 
self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +class CrossBlockRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + cross_attn_class: Callable[..., nn.Module] = CrossAttentionRope, + ffn_layer: Callable[..., nn.Module] = Mlp, + init_values=None, + qk_norm: bool = False, + rope=None, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + rope=rope, + qk_norm=qk_norm, + ) + + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.ls_y = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.norm2 = norm_layer(dim) + self.norm_y = norm_layer(dim) + self.cross_attn = cross_attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + rope=rope, + qk_norm=qk_norm, + ) + + self.norm3 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + bias=ffn_bias, + ) + + def forward(self, x: Tensor, y: Tensor, xpos=None, ypos=None) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x), xpos=xpos)) + + def cross_attn_residual_func(x: Tensor, y: Tensor) -> Tensor: + return self.ls_y(self.cross_attn(self.norm2(x), y, y, qpos=xpos, kpos=ypos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm3(x))) + + x = x + attn_residual_func(x) + y_ = self.norm_y(y) + x = x + cross_attn_residual_func(x, y_) + x = x + ffn_residual_func(x) + + return x diff --git a/mapanything/models/external/pi3/layers/camera_head.py b/mapanything/models/external/pi3/layers/camera_head.py new file mode 100644 index 0000000000000000000000000000000000000000..37e12354b2e7a5a799c883b12710d6528f39241b --- /dev/null +++ b/mapanything/models/external/pi3/layers/camera_head.py @@ -0,0 +1,106 @@ +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172' +class ResConvBlock(nn.Module): + """ + 1x1 convolution residual block + """ + + def __init__(self, in_channels, out_channels): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.head_skip = ( + nn.Identity() + if self.in_channels == self.out_channels + else nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0) + ) + # self.res_conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0) + # self.res_conv2 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0) + # self.res_conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0) + + # change 1x1 convolution to linear + self.res_conv1 = nn.Linear(self.in_channels, self.out_channels) + self.res_conv2 = nn.Linear(self.out_channels, 
self.out_channels) + self.res_conv3 = nn.Linear(self.out_channels, self.out_channels) + + def forward(self, res): + x = F.relu(self.res_conv1(res)) + x = F.relu(self.res_conv2(x)) + x = F.relu(self.res_conv3(x)) + res = self.head_skip(res) + x + return res + + +class CameraHead(nn.Module): + def __init__(self, dim=512): + super().__init__() + output_dim = dim + self.res_conv = nn.ModuleList( + [deepcopy(ResConvBlock(output_dim, output_dim)) for _ in range(2)] + ) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.more_mlps = nn.Sequential( + nn.Linear(output_dim, output_dim), + nn.ReLU(), + nn.Linear(output_dim, output_dim), + nn.ReLU(), + ) + self.fc_t = nn.Linear(output_dim, 3) + self.fc_rot = nn.Linear(output_dim, 9) + + def forward(self, feat, patch_h, patch_w): + BN, hw, c = feat.shape + + for i in range(2): + feat = self.res_conv[i](feat) + + # feat = self.avgpool(feat) + feat = self.avgpool( + feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous() + ) ########## + feat = feat.view(feat.size(0), -1) + + feat = self.more_mlps(feat) # [B, D_] + with torch.amp.autocast(device_type="cuda", enabled=False): + out_t = self.fc_t(feat.float()) # [B,3] + out_r = self.fc_rot(feat.float()) # [B,9] + pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device) + + return pose + + def convert_pose_to_4x4(self, B, out_r, out_t, device): + out_r = self.svd_orthogonalize(out_r) # [N,3,3] + pose = torch.zeros((B, 4, 4), device=device) + pose[:, :3, :3] = out_r + pose[:, :3, 3] = out_t + pose[:, 3, 3] = 1.0 + return pose + + def svd_orthogonalize(self, m): + """Convert 9D representation to SO(3) using SVD orthogonalization. + + Args: + m: [BATCH, 3, 3] 3x3 matrices. + + Returns: + [BATCH, 3, 3] SO(3) rotation matrices. + """ + if m.dim() < 3: + m = m.reshape((-1, 3, 3)) + m_transpose = torch.transpose( + torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2 + ) + u, s, v = torch.svd(m_transpose) + det = torch.det(torch.matmul(v, u.transpose(-2, -1))) + # Check orientation reflection. + r = torch.matmul( + torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2), + u.transpose(-2, -1), + ) + return r diff --git a/mapanything/models/external/pi3/layers/pos_embed.py b/mapanything/models/external/pi3/layers/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd500f4b353646c3ba8cafe2598186e8c97031e --- /dev/null +++ b/mapanything/models/external/pi3/layers/pos_embed.py @@ -0,0 +1,190 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
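+
+# Editor's note: a minimal shape sketch for the sine-cosine helpers defined
+# below (values are illustrative, not taken from any checkpoint):
+#
+#   pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)
+#   # -> numpy array of shape (14 * 14, 768); with n_cls_token=1 a row of
+#   #    zeros is prepended for the class token, giving (197, 768).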
+ + +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + + +import numpy as np +import torch + + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if n_cls_token > 0: + pos_embed = np.concatenate( + [np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size, orig_size, new_size, new_size) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = 
torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +# ---------------------------------------------------------- +# RoPE2D: RoPE implementation in 2D +# ---------------------------------------------------------- + +try: + from models.curope import cuRoPE2D + + RoPE2D = cuRoPE2D +except ImportError: + + class RoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, D, 2).float().to(device) / D) + ) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert tokens.size(3) % 2 == 0, ( + "number of dimensions should be a multiple of two" + ) + D = tokens.size(3) // 2 + assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2 + cos, sin = self.get_cos_sin( + D, int(positions.max()) + 1, tokens.device, tokens.dtype + ) + # split features into two along the feature dimension, and apply rope1d on each half + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:, :, 0], cos, sin) + x = self.apply_rope1d(x, positions[:, :, 1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens + + +# patch embedding +class PositionGetter(object): + """return positions of patches""" + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, h, w, device): + if (h, w) not in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + self.cache_positions[h, w] = torch.cartesian_prod(y, x) # (h, w, 2) + pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone() + return pos diff --git a/mapanything/models/external/pi3/layers/transformer_head.py b/mapanything/models/external/pi3/layers/transformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d411938709d11c68b6618bff9205d001feb91829 --- /dev/null +++ b/mapanything/models/external/pi3/layers/transformer_head.py @@ -0,0 +1,98 @@ +from functools import partial + +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from mapanything.models.external.dinov2.layers import Mlp +from mapanything.models.external.pi3.layers.attention import FlashAttentionRope +from 
mapanything.models.external.pi3.layers.block import BlockRope + + +class TransformerDecoder(nn.Module): + def __init__( + self, + in_dim, + out_dim, + dec_embed_dim=512, + depth=5, + dec_num_heads=8, + mlp_ratio=4, + rope=None, + need_project=True, + use_checkpoint=False, + ): + super().__init__() + + self.projects = ( + nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity() + ) + self.use_checkpoint = use_checkpoint + + self.blocks = nn.ModuleList( + [ + BlockRope( + dim=dec_embed_dim, + num_heads=dec_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + ffn_layer=Mlp, + init_values=None, + qk_norm=False, + # attn_class=MemEffAttentionRope, + attn_class=FlashAttentionRope, + rope=rope, + ) + for _ in range(depth) + ] + ) + + self.linear_out = nn.Linear(dec_embed_dim, out_dim) + + def forward(self, hidden, xpos=None): + hidden = self.projects(hidden) + for i, blk in enumerate(self.blocks): + if self.use_checkpoint and self.training: + hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False) + else: + hidden = blk(hidden, xpos=xpos) + out = self.linear_out(hidden) + return out + + +class LinearPts3d(nn.Module): + """ + Linear head for dust3r + Each token outputs: - 16x16 3D points (+ confidence) + """ + + def __init__( + self, + patch_size, + dec_embed_dim, + output_dim=3, + ): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Linear(dec_embed_dim, (output_dim) * self.patch_size**2) + + def forward(self, decout, img_shape): + H, W = img_shape + tokens = decout[-1] + B, S, D = tokens.shape + + # extract 3D points + feat = self.proj(tokens) # B,S,D + feat = feat.transpose(-1, -2).view( + B, -1, H // self.patch_size, W // self.patch_size + ) + feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W + + # permute + norm depth + return feat.permute(0, 2, 3, 1) diff --git a/mapanything/models/external/pi3/models/__init__.py b/mapanything/models/external/pi3/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/models/external/pi3/models/pi3.py b/mapanything/models/external/pi3/models/pi3.py new file mode 100644 index 0000000000000000000000000000000000000000..1b79e71bf21e468f966eb11841febf6b23ec7160 --- /dev/null +++ b/mapanything/models/external/pi3/models/pi3.py @@ -0,0 +1,251 @@ +from copy import deepcopy +from functools import partial + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin + +from mapanything.models.external.dinov2.hub.backbones import dinov2_vitl14_reg +from mapanything.models.external.dinov2.layers import Mlp +from mapanything.models.external.pi3.layers.attention import FlashAttentionRope +from mapanything.models.external.pi3.layers.block import BlockRope +from mapanything.models.external.pi3.layers.camera_head import CameraHead +from mapanything.models.external.pi3.layers.pos_embed import PositionGetter, RoPE2D +from mapanything.models.external.pi3.layers.transformer_head import ( + LinearPts3d, + TransformerDecoder, +) + + +def homogenize_points( + points, +): + """Convert batched points (xyz) to (xyz1).""" + return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + + +class Pi3(nn.Module, PyTorchModelHubMixin): + def __init__( + self, + pos_type="rope100", + decoder_size="large", + ): + super().__init__() + + # ---------------------- + # Encoder + # ---------------------- 
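+        # Editor's note: DINOv2 ViT-L/14 with register tokens; its 1024-dim
+        # patch tokens match the default "large" decoder width configured below.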
+ self.encoder = dinov2_vitl14_reg(pretrained=False) + self.patch_size = 14 + del self.encoder.mask_token + + # ---------------------- + # Positonal Encoding + # ---------------------- + self.pos_type = pos_type if pos_type is not None else "none" + self.rope = None + if self.pos_type.startswith("rope"): # eg rope100 + if RoPE2D is None: + raise ImportError( + "Cannot find cuRoPE2D, please install it following the README instructions" + ) + freq = float(self.pos_type[len("rope") :]) + self.rope = RoPE2D(freq=freq) + self.position_getter = PositionGetter() + else: + raise NotImplementedError + + # ---------------------- + # Decoder + # ---------------------- + if decoder_size == "small": + dec_embed_dim = 384 + dec_num_heads = 6 + mlp_ratio = 4 + dec_depth = 24 + elif decoder_size == "base": + dec_embed_dim = 768 + dec_num_heads = 12 + mlp_ratio = 4 + dec_depth = 24 + elif decoder_size == "large": + dec_embed_dim = 1024 + dec_num_heads = 16 + mlp_ratio = 4 + dec_depth = 36 + else: + raise NotImplementedError + self.decoder = nn.ModuleList( + [ + BlockRope( + dim=dec_embed_dim, + num_heads=dec_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + ffn_layer=Mlp, + init_values=0.01, + qk_norm=True, + attn_class=FlashAttentionRope, + rope=self.rope, + ) + for _ in range(dec_depth) + ] + ) + self.dec_embed_dim = dec_embed_dim + + # ---------------------- + # Register_token + # ---------------------- + num_register_tokens = 5 + self.patch_start_idx = num_register_tokens + self.register_token = nn.Parameter( + torch.randn(1, 1, num_register_tokens, self.dec_embed_dim) + ) + nn.init.normal_(self.register_token, std=1e-6) + + # ---------------------- + # Local Points Decoder + # ---------------------- + self.point_decoder = TransformerDecoder( + in_dim=2 * self.dec_embed_dim, + dec_embed_dim=1024, + dec_num_heads=16, + out_dim=1024, + rope=self.rope, + ) + self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3) + + # ---------------------- + # Conf Decoder + # ---------------------- + self.conf_decoder = deepcopy(self.point_decoder) + self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1) + + # ---------------------- + # Camera Pose Decoder + # ---------------------- + self.camera_decoder = TransformerDecoder( + in_dim=2 * self.dec_embed_dim, + dec_embed_dim=1024, + dec_num_heads=16, # 8 + out_dim=512, + rope=self.rope, + use_checkpoint=False, + ) + self.camera_head = CameraHead(dim=512) + + # For ImageNet Normalize + image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + + self.register_buffer("image_mean", image_mean) + self.register_buffer("image_std", image_std) + + def decode(self, hidden, N, H, W): + BN, hw, _ = hidden.shape + B = BN // N + + final_output = [] + + hidden = hidden.reshape(B * N, hw, -1) + + register_token = self.register_token.repeat(B, N, 1, 1).reshape( + B * N, *self.register_token.shape[-2:] + ) + + # Concatenate special tokens with patch tokens + hidden = torch.cat([register_token, hidden], dim=1) + hw = hidden.shape[1] + + if self.pos_type.startswith("rope"): + pos = self.position_getter( + B * N, H // self.patch_size, W // self.patch_size, hidden.device + ) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos = pos + 1 + 
pos_special = ( + torch.zeros(B * N, self.patch_start_idx, 2) + .to(hidden.device) + .to(pos.dtype) + ) + pos = torch.cat([pos_special, pos], dim=1) + + for i in range(len(self.decoder)): + blk = self.decoder[i] + + if i % 2 == 0: + pos = pos.reshape(B * N, hw, -1) + hidden = hidden.reshape(B * N, hw, -1) + else: + pos = pos.reshape(B, N * hw, -1) + hidden = hidden.reshape(B, N * hw, -1) + + hidden = blk(hidden, xpos=pos) + + if i + 1 in [len(self.decoder) - 1, len(self.decoder)]: + final_output.append(hidden.reshape(B * N, hw, -1)) + + return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape( + B * N, hw, -1 + ) + + def forward(self, imgs): + imgs = (imgs - self.image_mean) / self.image_std + + B, N, _, H, W = imgs.shape + patch_h, patch_w = H // 14, W // 14 + + # encode by dinov2 + imgs = imgs.reshape(B * N, _, H, W) + hidden = self.encoder(imgs, is_training=True) + + if isinstance(hidden, dict): + hidden = hidden["x_norm_patchtokens"] + + hidden, pos = self.decode(hidden, N, H, W) + + point_hidden = self.point_decoder(hidden, xpos=pos) + conf_hidden = self.conf_decoder(hidden, xpos=pos) + camera_hidden = self.camera_decoder(hidden, xpos=pos) + + with torch.amp.autocast(device_type="cuda", enabled=False): + # local points + point_hidden = point_hidden.float() + ret = self.point_head( + [point_hidden[:, self.patch_start_idx :]], (H, W) + ).reshape(B, N, H, W, -1) + xy, z = ret.split([2, 1], dim=-1) + z = torch.exp(z) + local_points = torch.cat([xy * z, z], dim=-1) + + # confidence + conf_hidden = conf_hidden.float() + conf = self.conf_head( + [conf_hidden[:, self.patch_start_idx :]], (H, W) + ).reshape(B, N, H, W, -1) + + # camera + camera_hidden = camera_hidden.float() + camera_poses = self.camera_head( + camera_hidden[:, self.patch_start_idx :], patch_h, patch_w + ).reshape(B, N, 4, 4) + + # unproject local points using camera poses + points = torch.einsum( + "bnij, bnhwj -> bnhwi", camera_poses, homogenize_points(local_points) + )[..., :3] + + return dict( + points=points, + local_points=local_points, + conf=conf, + camera_poses=camera_poses, + ) diff --git a/mapanything/models/external/pow3r/__init__.py b/mapanything/models/external/pow3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25d9c09f83bb37fda12c767bdb1f8f110463a792 --- /dev/null +++ b/mapanything/models/external/pow3r/__init__.py @@ -0,0 +1,865 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
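+
+# Editor's note: a minimal usage sketch for the wrappers defined below,
+# mirroring their forward() docstrings (the tensors and the instantiated
+# wrapper are illustrative placeholders, not values from the original repo):
+#
+#   views = [
+#       {"img": img1, "data_norm_type": ["dust3r"]},  # (B, 3, H, W), DUSt3R-normalized
+#       {"img": img2, "data_norm_type": ["dust3r"]},
+#   ]
+#   preds = pow3r_wrapper(views)  # pow3r_wrapper: an instantiated Pow3RWrapper
+#   # preds[i] contains "pts3d", "pts3d_cam", "ray_directions",
+#   # "depth_along_ray", "cam_trans", "cam_quats" and "conf" for view i.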
+ +""" +Inference wrapper for Pow3R +""" + +import warnings +from copy import deepcopy + +import pow3r.model.blocks # noqa +import roma +import torch +import torch.nn as nn +import tqdm +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode +from dust3r.image_pairs import make_pairs +from dust3r.inference import check_if_same_size +from dust3r.model import CroCoNet +from dust3r.patch_embed import get_patch_embed as dust3r_patch_embed +from dust3r.utils.device import collate_with_cat, to_cpu +from dust3r.utils.misc import ( + fill_default_args, + freeze_all_params, + interleave, + is_symmetrized, + transpose_to_landscape, +) +from pow3r.model.blocks import Block, BlockInject, DecoderBlock, DecoderBlockInject, Mlp +from pow3r.model.heads import head_factory +from pow3r.model.inference import ( + add_depth, + add_intrinsics, + add_relpose, +) +from pow3r.model.patch_embed import get_patch_embed + +from mapanything.models.external.vggt.utils.rotation import mat_to_quat +from mapanything.utils.geometry import ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + convert_z_depth_to_depth_along_ray, + depthmap_to_camera_frame, + get_rays_in_camera_frame, +) + + +class Pow3R(CroCoNet): + """Two siamese encoders, followed by two decoders. + The goal is to output 3d points directly, both images in view1's frame + (hence the asymmetry). + """ + + def __init__( + self, + mode="embed", + head_type="linear", + patch_embed_cls="PatchEmbedDust3R", + freeze="none", + landscape_only=True, + **croco_kwargs, + ): + # retrieve all default arguments using python magic + self.croco_args = fill_default_args(croco_kwargs, super().__init__) + super().__init__(**croco_kwargs) + del self.mask_token # useless + del self.prediction_head + + dec_dim, enc_dim = self.decoder_embed.weight.shape + self.enc_embed_dim = enc_dim + self.dec_embed_dim = dec_dim + + self.mode = mode + # additional parameters in the encoder + img_size = self.patch_embed.img_size + patch_size = self.patch_embed.patch_size[0] + self.patch_embed = dust3r_patch_embed( + patch_embed_cls, img_size, patch_size, self.enc_embed_dim + ) + self.patch_embed_rays = get_patch_embed( + patch_embed_cls + "_Mlp", + img_size, + patch_size, + self.enc_embed_dim, + in_chans=3, + ) + self.patch_embed_depth = get_patch_embed( + patch_embed_cls + "_Mlp", + img_size, + patch_size, + self.enc_embed_dim, + in_chans=2, + ) + self.pose_embed = Mlp(12, 4 * dec_dim, dec_dim) + + # additional parameters in the decoder + self.dec_cls = "_cls" in self.mode + self.dec_num_cls = 0 + if self.dec_cls: + # use a CLS token in the decoder only + self.mode = self.mode.replace("_cls", "") + self.cls_token1 = nn.Parameter(torch.zeros((dec_dim,))) + self.cls_token2 = nn.Parameter(torch.zeros((dec_dim,))) + self.dec_num_cls = 1 # affects all blocks + + use_ln = "_ln" in self.mode # TODO remove? 
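+        # Editor's note: the mode string is parsed piecewise. "_cls" (handled
+        # above) adds per-decoder CLS tokens, "_ln" enables the extra LayerNorms
+        # just below, and the remaining prefix (e.g. "inject[0,0.5]") selects
+        # which encoder/decoder blocks replace_some_blocks() converts to the
+        # *Inject variants.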
+ self.patch_ln = nn.LayerNorm(enc_dim) if use_ln else nn.Identity() + self.dec1_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity() + self.dec2_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity() + + self.dec_blocks2 = deepcopy(self.dec_blocks) + + # here we modify some of the blocks + self.replace_some_blocks() + + self.set_downstream_head(head_type, landscape_only, **croco_kwargs) + self.set_freeze(freeze) + + def replace_some_blocks(self): + assert self.mode.startswith("inject") # inject[0,0.5] + NewBlock = BlockInject + DecoderNewBlock = DecoderBlockInject + + all_layers = { + i / n + for i in range(len(self.enc_blocks)) + for n in [len(self.enc_blocks), len(self.dec_blocks)] + } + which_layers = eval(self.mode[self.mode.find("[") :]) or all_layers + assert isinstance(which_layers, (set, list)) + + n = 0 + for i, block in enumerate(self.enc_blocks): + if i / len(self.enc_blocks) in which_layers: + block.__class__ = NewBlock + block.init(self.enc_embed_dim) + n += 1 + else: + block.__class__ = Block + assert n == len(which_layers), breakpoint() + + n = 0 + for i in range(len(self.dec_blocks)): + for blocks in [self.dec_blocks, self.dec_blocks2]: + block = blocks[i] + if i / len(self.dec_blocks) in which_layers: + block.__class__ = DecoderNewBlock + block.init(self.dec_embed_dim) + n += 1 + else: + block.__class__ = DecoderBlock + assert n == 2 * len(which_layers), breakpoint() + + @classmethod + def from_pretrained(cls, pretrained_model_path, **kw): + return _load_model(pretrained_model_path, device="cpu") + + def load_state_dict(self, ckpt, **kw): + # duplicate all weights for the second decoder if not present + new_ckpt = dict(ckpt) + if not any(k.startswith("dec_blocks2") for k in ckpt): + for key, value in ckpt.items(): + if key.startswith("dec_blocks"): + new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value + # remove layers that have different shapes + cur_ckpt = self.state_dict() + for key, val in ckpt.items(): + if key.startswith("downstream_head2.proj"): + if key in cur_ckpt and cur_ckpt[key].shape != val.shape: + print(f" (removing ckpt[{key}] because wrong shape)") + del new_ckpt[key] + return super().load_state_dict(new_ckpt, **kw) + + def set_freeze(self, freeze): # this is for use by downstream models + self.freeze = freeze + to_be_frozen = { + "none": [], + "encoder": [self.patch_embed, self.enc_blocks], + } + freeze_all_params(to_be_frozen[freeze]) + + def set_prediction_head(self, *args, **kwargs): + """No prediction head""" + return + + def set_downstream_head( + self, + head_type, + landscape_only, + patch_size, + img_size, + mlp_ratio, + dec_depth, + **kw, + ): + assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, ( + f"{img_size=} must be multiple of {patch_size=}" + ) + + # split heads if different + heads = head_type.split(";") + assert len(heads) in (1, 2) + head1_type, head2_type = (heads + heads)[:2] + + # allocate heads + self.downstream_head1 = head_factory(head1_type, self) + self.downstream_head2 = head_factory(head2_type, self) + + # magic wrapper + self.head1 = transpose_to_landscape( + self.downstream_head1, activate=landscape_only + ) + self.head2 = transpose_to_landscape( + self.downstream_head2, activate=landscape_only + ) + + def _encode_image(self, image, true_shape, rays=None, depth=None): + # embed the image into patches (x has size B x Npatches x C) + x, pos = self.patch_embed(image, true_shape=true_shape) + + if rays is not None: # B,3,H,W + rays_emb, pos2 = self.patch_embed_rays(rays, 
true_shape=true_shape) + assert (pos == pos2).all() + if self.mode.startswith("embed"): + x = x + rays_emb + else: + rays_emb = None + + if depth is not None: # B,2,H,W + depth_emb, pos2 = self.patch_embed_depth(depth, true_shape=true_shape) + assert (pos == pos2).all() + if self.mode.startswith("embed"): + x = x + depth_emb + else: + depth_emb = None + + x = self.patch_ln(x) + + # add positional embedding without cls token + assert self.enc_pos_embed is None + + # now apply the transformer encoder and normalization + for blk in self.enc_blocks: + x = blk(x, pos, rays=rays_emb, depth=depth_emb) + + x = self.enc_norm(x) + return x, pos + + def encode_symmetrized(self, view1, view2): + img1 = view1["img"] + img2 = view2["img"] + B = img1.shape[0] + # Recover true_shape when available, otherwise assume that the img shape is the true one + shape1 = view1.get( + "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1) + ) + shape2 = view2.get( + "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1) + ) + # warning! maybe the images have different portrait/landscape orientations + + # privileged information + rays1 = view1.get("known_rays", None) + rays2 = view2.get("known_rays", None) + depth1 = view1.get("known_depth", None) + depth2 = view2.get("known_depth", None) + + if is_symmetrized(view1, view2): + # computing half of forward pass!' + def hsub(x): + return None if x is None else x[::2] + + feat1, pos1 = self._encode_image( + img1[::2], shape1[::2], rays=hsub(rays1), depth=hsub(depth1) + ) + feat2, pos2 = self._encode_image( + img2[::2], shape2[::2], rays=hsub(rays2), depth=hsub(depth2) + ) + + feat1, feat2 = interleave(feat1, feat2) + pos1, pos2 = interleave(pos1, pos2) + else: + feat1, pos1 = self._encode_image(img1, shape1, rays=rays1, depth=depth1) + feat2, pos2 = self._encode_image(img2, shape2, rays=rays2, depth=depth2) + + return (shape1, shape2), (feat1, feat2), (pos1, pos2) + + def _decoder(self, f1, pos1, f2, pos2, relpose1=None, relpose2=None): + final_output = [(f1, f2)] # before projection + + # project to decoder dim + f1 = self.decoder_embed(f1) + f2 = self.decoder_embed(f2) + + # add CLS token for the decoder + if self.dec_cls: + cls1 = self.cls_token1[None, None].expand(len(f1), 1, -1).clone() + cls2 = self.cls_token2[None, None].expand(len(f2), 1, -1).clone() + + if relpose1 is not None: # shape = (B, 4, 4) + pose_emb1 = self.pose_embed(relpose1[:, :3].flatten(1)).unsqueeze(1) + if self.mode.startswith("embed"): + if self.dec_cls: + cls1 = cls1 + pose_emb1 + else: + f1 = f1 + pose_emb1 + else: + pose_emb1 = None + + if relpose2 is not None: # shape = (B, 4, 4) + pose_emb2 = self.pose_embed(relpose2[:, :3].flatten(1)).unsqueeze(1) + if self.mode.startswith("embed"): + if self.dec_cls: + cls2 = cls2 + pose_emb2 + else: + f2 = f2 + pose_emb2 + else: + pose_emb2 = None + + if self.dec_cls: + f1, pos1 = cat_cls(cls1, f1, pos1) + f2, pos2 = cat_cls(cls2, f2, pos2) + + f1 = self.dec1_pre_ln(f1) + f2 = self.dec2_pre_ln(f2) + + final_output.append((f1, f2)) # to be removed later + for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2): + # img1 side + f1, _ = blk1( + *final_output[-1][::+1], + pos1, + pos2, + relpose=pose_emb1, + num_cls=self.dec_num_cls, + ) + # img2 side + f2, _ = blk2( + *final_output[-1][::-1], + pos2, + pos1, + relpose=pose_emb2, + num_cls=self.dec_num_cls, + ) + # store the result + final_output.append((f1, f2)) + + del final_output[1] # duplicate with final_output[0] (after decoder proj) + if self.dec_cls: # remove cls token for decoder 
layers + final_output[1:] = [(f1[:, 1:], f2[:, 1:]) for f1, f2 in final_output[1:]] + # normalize last output + final_output[-1] = tuple(map(self.dec_norm, final_output[-1])) + return zip(*final_output) + + def _downstream_head(self, head_num, decout, img_shape): + B, S, D = decout[-1].shape + head = getattr(self, f"head{head_num}") + return head(decout, img_shape) + + def forward(self, view1, view2): + # encode the two images --> B,S,D + (shape1, shape2), (feat1, feat2), (pos1, pos2) = self.encode_symmetrized( + view1, view2 + ) + + # combine all ref images into object-centric representation + dec1, dec2 = self._decoder( + feat1, + pos1, + feat2, + pos2, + relpose1=view1.get("known_pose"), + relpose2=view2.get("known_pose"), + ) + with torch.autocast("cuda", enabled=False): + res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1) + res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2) + + res2["pts3d_in_other_view"] = res2.pop( + "pts3d" + ) # predict view2's pts3d in view1's frame + return res1, res2 + + +def convert_release_dust3r_args(args): + args.model = ( + args.model.replace("patch_embed_cls", "patch_embed") + .replace("AsymmetricMASt3R", "AsymmetricCroCo3DStereo") + .replace("PatchEmbedDust3R", "convManyAR") + .replace( + "pos_embed='RoPE100'", + "enc_pos_embed='cuRoPE100', dec_pos_embed='cuRoPE100'", + ) + ) + return args + + +def _load_model(model_path, device): + print("... loading model from", model_path) + ckpt = torch.load(model_path, map_location="cpu") + try: + net = eval( + ckpt["args"].model[:-1].replace("convManyAR", "convP") + + ", landscape_only=False)" + ) + except Exception: + args = convert_release_dust3r_args(ckpt["args"]) + net = eval( + args.model[:-1].replace("convManyAR", "convP") + ", landscape_only=False)" + ) + ckpt["model"] = { + k.replace("_downstream_head", "downstream_head"): v + for k, v in ckpt["model"].items() + } + print(net.load_state_dict(ckpt["model"], strict=False)) + return net.to(device) + + +def cat_cls(cls, tokens, pos): + tokens = torch.cat((cls, tokens), dim=1) + pos = torch.cat((-pos.new_ones(len(cls), 1, 2), pos), dim=1) + return tokens, pos + + +class Pow3RWrapper(torch.nn.Module): + def __init__( + self, + name, + ckpt_path, + geometric_input_config, + **kwargs, + ): + super().__init__() + self.name = name + self.ckpt_path = ckpt_path + self.geometric_input_config = geometric_input_config + + # Init the model and load the checkpoint + print(f"Loading checkpoint from {self.ckpt_path} ...") + ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False) + model = ckpt["definition"] + print(f"Creating model = {model}") + self.model = eval(model) + print(self.model.load_state_dict(ckpt["weights"])) + + def forward(self, views): + """ + Forward pass wrapper for Pow3R. + + Assumption: + - The number of input views is 2. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Length of the list should be 2. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["dust3r"] + Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs: + "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3). + "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation. + "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1). 
+ + Returns: + List[dict]: A list containing the final outputs for the two views. Length of the list will be 2. + """ + # Check that the number of input views is 2 + assert len(views) == 2, "Pow3R requires 2 input views." + + # Check the data norm type + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "dust3r", ( + "Pow3R expects a normalized image with the DUSt3R normalization scheme applied" + ) + + # Get the batch size per view, device and two views + batch_size_per_view = views[0]["img"].shape[0] + device = views[0]["img"].device + view1, view2 = views + + # Decide if we need to use the geometric inputs + if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]: + # Decide if we need to use the camera intrinsics + if ( + torch.rand(1, device=device) + < self.geometric_input_config["ray_dirs_prob"] + ): + add_intrinsics(view1, view1.get("camera_intrinsics")) + add_intrinsics(view2, view2.get("camera_intrinsics")) + + # Decide if we need to use the depth map + if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]: + depthmap1 = view1.get("depthmap") + depthmap2 = view2.get("depthmap") + if depthmap1 is not None: + depthmap1 = depthmap1.squeeze(-1).to(device) + if depthmap2 is not None: + depthmap2 = depthmap2.squeeze(-1).to(device) + add_depth(view1, depthmap1) + add_depth(view2, depthmap2) + + # Decide if we need to use the camera pose + if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]: + cam1 = view1.get("camera_pose") + cam2 = view2.get("camera_pose") + add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1) + add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1) + + # Get the model predictions + preds = self.model(view1, view2) + + # Convert the output to MapAnything format + with torch.autocast("cuda", enabled=False): + res = [] + for view_idx in range(2): + # Get the model predictions for the current view + curr_view_pred = preds[view_idx] + + # For the first view + if view_idx == 0: + # Get the global frame and camera frame pointmaps + global_pts = curr_view_pred["pts3d"] + cam_pts = curr_view_pred["pts3d"] + conf = curr_view_pred["conf"] + + # Get the ray directions and depth along ray + depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True) + ray_directions = cam_pts / depth_along_ray + + # Initalize identity camera pose + cam_rot = torch.eye(3, device=device) + cam_quat = mat_to_quat(cam_rot) + cam_trans = torch.zeros(3, device=device) + cam_quat = cam_quat.unsqueeze(0).repeat(batch_size_per_view, 1) + cam_trans = cam_trans.unsqueeze(0).repeat(batch_size_per_view, 1) + # For the second view + elif view_idx == 1: + # Get the global frame and camera frame pointmaps + pred_global_pts = curr_view_pred["pts3d_in_other_view"] + cam_pts = curr_view_pred["pts3d2"] + conf = (curr_view_pred["conf"] * curr_view_pred["conf2"]).sqrt() + + # Get the ray directions and depth along ray + depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True) + ray_directions = cam_pts / depth_along_ray + + # Compute the camera pose using the pointmaps + cam_rot, cam_trans, scale = roma.rigid_points_registration( + cam_pts.reshape(batch_size_per_view, -1, 3), + pred_global_pts.reshape(batch_size_per_view, -1, 3), + weights=conf.reshape(batch_size_per_view, -1), + compute_scaling=True, + ) + cam_quat = mat_to_quat(cam_rot) + + # Scale the predicted camera frame pointmap and compute the new global frame pointmap + cam_pts = scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * cam_pts + global_pts = 
cam_pts.reshape( + batch_size_per_view, -1, 3 + ) @ cam_rot.permute(0, 2, 1) + cam_trans.unsqueeze(1) + global_pts = global_pts.view(pred_global_pts.shape) + + # Append the result in MapAnything format + res.append( + { + "pts3d": global_pts, + "pts3d_cam": cam_pts, + "ray_directions": ray_directions, + "depth_along_ray": depth_along_ray, + "cam_trans": cam_trans, + "cam_quats": cam_quat, + "conf": conf, + } + ) + + return res + + +class Pow3RBAWrapper(torch.nn.Module): + def __init__( + self, + name, + ckpt_path, + geometric_input_config, + scene_graph="complete", + inference_batch_size=32, + global_optim_schedule="cosine", + global_optim_lr=0.01, + global_optim_niter=300, + **kwargs, + ): + super().__init__() + self.name = name + self.ckpt_path = ckpt_path + self.geometric_input_config = geometric_input_config + self.scene_graph = scene_graph + self.inference_batch_size = inference_batch_size + self.global_optim_schedule = global_optim_schedule + self.global_optim_lr = global_optim_lr + self.global_optim_niter = global_optim_niter + + # Init the model and load the checkpoint + print(f"Loading checkpoint from {self.ckpt_path} ...") + ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False) + model = ckpt["definition"] + print(f"Creating model = {model}") + self.model = eval(model) + print(self.model.load_state_dict(ckpt["weights"])) + + # Init the global aligner mode + self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer + + def infer_two_views(self, views): + """ + Wrapper for Pow3R 2-View inference. + + Assumption: + - The number of input views is 2. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Length of the list should be 2. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["dust3r"] + Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs: + "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3). + "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation. + "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1). + + Returns: + List[dict]: A list containing the final outputs for the two views. Length of the list will be 2. + """ + # Check that the number of input views is 2 + assert len(views) == 2, "Pow3R requires 2 input views." 
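+
+        # Editor's note: the optional geometric-input handling below mirrors
+        # Pow3RWrapper.forward(); unlike that wrapper, the raw pairwise Pow3R
+        # predictions are returned unconverted and later consumed by the DUSt3R
+        # global aligner in forward().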
+ + # Check the data norm type + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "dust3r", ( + "Pow3R expects a normalized image with the DUSt3R normalization scheme applied" + ) + + # Get the device and two views + device = views[0]["img"].device + view1, view2 = views + + # Decide if we need to use the geometric inputs + if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]: + # Decide if we need to use the camera intrinsics + if ( + torch.rand(1, device=device) + < self.geometric_input_config["ray_dirs_prob"] + ): + add_intrinsics(view1, view1.get("camera_intrinsics")) + add_intrinsics(view2, view2.get("camera_intrinsics")) + + # Decide if we need to use the depth map + if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]: + depthmap1 = view1.get("depthmap") + depthmap2 = view2.get("depthmap") + if depthmap1 is not None: + depthmap1 = depthmap1.squeeze(-1).to(device) + if depthmap2 is not None: + depthmap2 = depthmap2.squeeze(-1).to(device) + add_depth(view1, depthmap1) + add_depth(view2, depthmap2) + + # Decide if we need to use the camera pose + if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]: + cam1 = view1.get("camera_pose") + cam2 = view2.get("camera_pose") + add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1) + add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1) + + # Get the model predictions + preds = self.model(view1, view2) + + return preds + + def loss_of_one_batch(self, batch, device): + """ + Compute prediction for two views. + """ + view1, view2 = batch + ignore_keys = set( + [ + "dataset", + "label", + "instance", + "idx", + "true_shape", + "rng", + "name", + "data_norm_type", + ] + ) + for view in batch: + for name in view.keys(): # pseudo_focal + if name in ignore_keys: + continue + view[name] = view[name].to(device, non_blocking=True) + + pred1, pred2 = self.infer_two_views([view1, view2]) + + result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2) + + return result + + @torch.no_grad() + def inference(self, pairs, device, verbose=False): + """ + Wrapper for multi-pair inference using Pow3R. + """ + if verbose: + print(f">> Inference with model on {len(pairs)} image pairs") + result = [] + + multiple_shapes = not (check_if_same_size(pairs)) + if multiple_shapes: + self.inference_batch_size = 1 + + for i in tqdm.trange( + 0, len(pairs), self.inference_batch_size, disable=not verbose + ): + res = self.loss_of_one_batch( + collate_with_cat(pairs[i : i + self.inference_batch_size]), device + ) + result.append(to_cpu(res)) + + result = collate_with_cat(result, lists=multiple_shapes) + + return result + + def forward(self, views): + """ + Forward pass wrapper for Pow3R using the global aligner. + + Assumption: + - The batch size of input views is 1. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys, where B is the batch size and is 1: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["dust3r"] + + Returns: + List[dict]: A list containing the final outputs for the input views. + """ + # Check the batch size of input views + batch_size_per_view, _, height, width = views[0]["img"].shape + device = views[0]["img"].device + num_views = len(views) + assert batch_size_per_view == 1, ( + f"Batch size of input views should be 1, but got {batch_size_per_view}." 
+ ) + + # Check the data norm type + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "dust3r", ( + "Pow3R-BA expects a normalized image with the DUSt3R normalization scheme applied" + ) + + # Convert the input views to the expected input format + images = [] + for view in views: + images.append( + dict( + img=view["img"], + camera_intrinsics=view["camera_intrinsics"], + depthmap=view["depthmap"], + camera_pose=view["camera_pose"], + data_norm_type=view["data_norm_type"], + true_shape=view["true_shape"], + idx=len(images), + instance=str(len(images)), + ) + ) + + # Make image pairs and run inference pair-wise + pairs = make_pairs( + images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + output = self.inference( + pairs, + device, + verbose=False, + ) + + # Global optimization + with torch.enable_grad(): + scene = global_aligner( + output, device=device, mode=self.global_aligner_mode, verbose=False + ) + _ = scene.compute_global_alignment( + init="mst", + niter=self.global_optim_niter, + schedule=self.global_optim_schedule, + lr=self.global_optim_lr, + ) + + # Make sure scene is not None + if scene is None: + raise RuntimeError("Global optimization failed.") + + # Get the predictions + intrinsics = scene.get_intrinsics() + c2w_poses = scene.get_im_poses() + depths = scene.get_depthmaps() + + # Convert the output to the MapAnything format + with torch.autocast("cuda", enabled=False): + res = [] + for view_idx in range(num_views): + # Get the current view predictions + curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0) + curr_view_pose = c2w_poses[view_idx].unsqueeze(0) + curr_view_depth_z = depths[view_idx].unsqueeze(0) + + # Convert the pose to quaternions and translation + curr_view_cam_translations = curr_view_pose[..., :3, 3] + curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3]) + + # Get the camera frame pointmaps + curr_view_pts3d_cam, _ = depthmap_to_camera_frame( + curr_view_depth_z, curr_view_intrinsic + ) + + # Convert the z depth to depth along ray + curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray( + curr_view_depth_z, curr_view_intrinsic + ) + curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1) + + # Get the ray directions on the unit sphere in the camera frame + _, curr_view_ray_dirs = get_rays_in_camera_frame( + curr_view_intrinsic, height, width, normalize_to_unit_sphere=True + ) + + # Get the pointmaps + curr_view_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + curr_view_ray_dirs, + curr_view_depth_along_ray, + curr_view_cam_translations, + curr_view_cam_quats, + ) + ) + + # Append the outputs to the result list + res.append( + { + "pts3d": curr_view_pts3d, + "pts3d_cam": curr_view_pts3d_cam, + "ray_directions": curr_view_ray_dirs, + "depth_along_ray": curr_view_depth_along_ray, + "cam_trans": curr_view_cam_translations, + "cam_quats": curr_view_cam_quats, + } + ) + + return res diff --git a/mapanything/models/external/vggt/__init__.py b/mapanything/models/external/vggt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a182f2e6af72f1fb5667fce91ef83c86cb1fb0cf --- /dev/null +++ b/mapanything/models/external/vggt/__init__.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +""" +Inference wrapper for VGGT +""" + +import torch + +from mapanything.models.external.vggt.models.vggt import VGGT +from mapanything.models.external.vggt.utils.geometry import closed_form_inverse_se3 +from mapanything.models.external.vggt.utils.pose_enc import pose_encoding_to_extri_intri +from mapanything.models.external.vggt.utils.rotation import mat_to_quat +from mapanything.utils.geometry import ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + convert_z_depth_to_depth_along_ray, + depthmap_to_camera_frame, + get_rays_in_camera_frame, +) + + +class VGGTWrapper(torch.nn.Module): + def __init__( + self, + name, + torch_hub_force_reload, + load_pretrained_weights=True, + depth=24, + num_heads=16, + intermediate_layer_idx=[4, 11, 17, 23], + load_custom_ckpt=False, + custom_ckpt_path=None, + ): + super().__init__() + self.name = name + self.torch_hub_force_reload = torch_hub_force_reload + self.load_custom_ckpt = load_custom_ckpt + self.custom_ckpt_path = custom_ckpt_path + + if load_pretrained_weights: + # Load pre-trained weights + if not torch_hub_force_reload: + # Initialize the 1B VGGT model from huggingface hub cache + print("Loading facebook/VGGT-1B from huggingface cache ...") + self.model = VGGT.from_pretrained( + "facebook/VGGT-1B", + ) + else: + # Initialize the 1B VGGT model + print("Re-downloading facebook/VGGT-1B ...") + self.model = VGGT.from_pretrained( + "facebook/VGGT-1B", force_download=True + ) + else: + # Load the VGGT class + self.model = VGGT( + depth=depth, + num_heads=num_heads, + intermediate_layer_idx=intermediate_layer_idx, + ) + + # Get the dtype for VGGT inference + # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) + self.dtype = ( + torch.bfloat16 + if torch.cuda.get_device_capability()[0] >= 8 + else torch.float16 + ) + + # Load custom checkpoint if requested + if self.load_custom_ckpt: + print(f"Loading checkpoint from {self.custom_ckpt_path} ...") + assert self.custom_ckpt_path is not None, ( + "custom_ckpt_path must be provided if load_custom_ckpt is set to True" + ) + custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False) + print(self.model.load_state_dict(custom_ckpt, strict=True)) + del custom_ckpt # in case it occupies memory + + def forward(self, views): + """ + Forward pass wrapper for VGGT + + Assumption: + - All the input views have the same image shape. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["identity"] + + Returns: + List[dict]: A list containing the final outputs for all N views. 
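+                Each dictionary contains the following keys:
+                    "pts3d" (tensor): Predicted points in the world frame.
+                    "pts3d_cam" (tensor): Predicted points in the camera frame.
+                    "ray_directions" (tensor): Ray directions on the unit sphere in the camera frame.
+                    "depth_along_ray" (tensor): Depth along the rays.
+                    "cam_trans" (tensor): Camera translations (cam2world).
+                    "cam_quats" (tensor): Camera rotation quaternions (cam2world).
+                    "conf" (tensor): Depth confidence.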
+ """ + # Get input shape of the images, number of views, and batch size per view + batch_size_per_view, _, height, width = views[0]["img"].shape + num_views = len(views) + + # Check the data norm type + # VGGT expects a normalized image but without the DINOv2 mean and std applied ("identity") + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "identity", ( + "VGGT expects a normalized image but without the DINOv2 mean and std applied" + ) + + # Concatenate the images to create a single (B, V, C, H, W) tensor + img_list = [view["img"] for view in views] + images = torch.stack(img_list, dim=1) + + # Run the VGGT aggregator + with torch.autocast("cuda", dtype=self.dtype): + aggregated_tokens_list, ps_idx = self.model.aggregator(images) + + # Run the Camera + Pose Branch of VGGT + with torch.autocast("cuda", enabled=False): + # Predict Cameras + pose_enc = self.model.camera_head(aggregated_tokens_list)[-1] + # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world) + # Extrinsics Shape: (B, V, 3, 4) + # Intrinsics Shape: (B, V, 3, 3) + extrinsic, intrinsic = pose_encoding_to_extri_intri( + pose_enc, images.shape[-2:] + ) + + # Predict Depth Maps + # Depth Shape: (B, V, H, W, 1) + # Depth Confidence Shape: (B, V, H, W) + depth_map, depth_conf = self.model.depth_head( + aggregated_tokens_list, images, ps_idx + ) + + # Convert the output to MapAnything format + res = [] + for view_idx in range(num_views): + # Get the extrinsics, intrinsics, depth map for the current view + curr_view_extrinsic = extrinsic[:, view_idx, ...] + curr_view_extrinsic = closed_form_inverse_se3( + curr_view_extrinsic + ) # Convert to cam2world + curr_view_intrinsic = intrinsic[:, view_idx, ...] + curr_view_depth_z = depth_map[:, view_idx, ...] + curr_view_depth_z = curr_view_depth_z.squeeze(-1) + curr_view_confidence = depth_conf[:, view_idx, ...] 
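+            # Shapes at this point (per view): intrinsics (B, 3, 3), extrinsics
+            # inverted to cam2world, z-depth (B, H, W) after squeezing the
+            # trailing channel, and depth confidence (B, H, W).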
+ + # Get the camera frame pointmaps + curr_view_pts3d_cam, _ = depthmap_to_camera_frame( + curr_view_depth_z, curr_view_intrinsic + ) + + # Convert the extrinsics to quaternions and translations + curr_view_cam_translations = curr_view_extrinsic[..., :3, 3] + curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3]) + + # Convert the z depth to depth along ray + curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray( + curr_view_depth_z, curr_view_intrinsic + ) + curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1) + + # Get the ray directions on the unit sphere in the camera frame + _, curr_view_ray_dirs = get_rays_in_camera_frame( + curr_view_intrinsic, height, width, normalize_to_unit_sphere=True + ) + + # Get the pointmaps + curr_view_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + curr_view_ray_dirs, + curr_view_depth_along_ray, + curr_view_cam_translations, + curr_view_cam_quats, + ) + ) + + # Append the outputs to the result list + res.append( + { + "pts3d": curr_view_pts3d, + "pts3d_cam": curr_view_pts3d_cam, + "ray_directions": curr_view_ray_dirs, + "depth_along_ray": curr_view_depth_along_ray, + "cam_trans": curr_view_cam_translations, + "cam_quats": curr_view_cam_quats, + "conf": curr_view_confidence, + } + ) + + return res diff --git a/mapanything/models/external/vggt/heads/__init__.py b/mapanything/models/external/vggt/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/models/external/vggt/heads/camera_head.py b/mapanything/models/external/vggt/heads/camera_head.py new file mode 100644 index 0000000000000000000000000000000000000000..17efa894a60d8842ff6c5e789c3f95fc6331c791 --- /dev/null +++ b/mapanything/models/external/vggt/heads/camera_head.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn as nn + +from mapanything.models.external.vggt.heads.head_act import activate_pose +from mapanything.models.external.vggt.layers import Mlp +from mapanything.models.external.vggt.layers.block import Block + + +class CameraHead(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. + """ + + def __init__( + self, + dim_in: int = 2048, + trunk_depth: int = 4, + pose_encoding_type: str = "absT_quaR_FoV", + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + trans_act: str = "linear", + quat_act: str = "linear", + fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. + ): + super().__init__() + + if pose_encoding_type == "absT_quaR_FoV": + self.target_dim = 9 + else: + raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") + + self.trans_act = trans_act + self.quat_act = quat_act + self.fl_act = fl_act + self.trunk_depth = trunk_depth + + # Build the trunk using a sequence of transformer blocks. + self.trunk = nn.Sequential( + *[ + Block( + dim=dim_in, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + init_values=init_values, + ) + for _ in range(trunk_depth) + ] + ) + + # Normalizations for camera token and trunk output. 
+ self.token_norm = nn.LayerNorm(dim_in) + self.trunk_norm = nn.LayerNorm(dim_in) + + # Learnable empty camera pose token. + self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) + self.embed_pose = nn.Linear(self.target_dim, dim_in) + + # Module for producing modulation parameters: shift, scale, and a gate. + self.poseLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True) + ) + + # Adaptive layer normalization without affine parameters. + self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) + self.pose_branch = Mlp( + in_features=dim_in, + hidden_features=dim_in // 2, + out_features=self.target_dim, + drop=0, + ) + + def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list: + """ + Forward pass to predict camera parameters. + + Args: + aggregated_tokens_list (list): List of token tensors from the network; + the last tensor is used for prediction. + num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. + + Returns: + list: A list of predicted camera encodings (post-activation) from each iteration. + """ + # Use tokens from the last block for camera prediction. + tokens = aggregated_tokens_list[-1] + + # Extract the camera tokens + pose_tokens = tokens[:, :, 0] + pose_tokens = self.token_norm(pose_tokens) + + pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) + return pred_pose_enc_list + + def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: + """ + Iteratively refine camera pose predictions. + + Args: + pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C]. + num_iterations (int): Number of refinement iterations. + + Returns: + list: List of activated camera encodings from each iteration. + """ + B, S, C = pose_tokens.shape # S is expected to be 1. + pred_pose_enc = None + pred_pose_enc_list = [] + + for _ in range(num_iterations): + # Use a learned empty pose for the first iteration. + if pred_pose_enc is None: + module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) + else: + # Detach the previous prediction to avoid backprop through time. + pred_pose_enc = pred_pose_enc.detach() + module_input = self.embed_pose(pred_pose_enc) + + # Generate modulation parameters and split them into shift, scale, and gate components. + shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk( + 3, dim=-1 + ) + + # Adaptive layer normalization and modulation. + pose_tokens_modulated = gate_msa * modulate( + self.adaln_norm(pose_tokens), shift_msa, scale_msa + ) + pose_tokens_modulated = pose_tokens_modulated + pose_tokens + + pose_tokens_modulated = self.trunk(pose_tokens_modulated) + # Compute the delta update for the pose encoding. + pred_pose_enc_delta = self.pose_branch( + self.trunk_norm(pose_tokens_modulated) + ) + + if pred_pose_enc is None: + pred_pose_enc = pred_pose_enc_delta + else: + pred_pose_enc = pred_pose_enc + pred_pose_enc_delta + + # Apply final activation functions for translation, quaternion, and field-of-view. + activated_pose = activate_pose( + pred_pose_enc, + trans_act=self.trans_act, + quat_act=self.quat_act, + fl_act=self.fl_act, + ) + pred_pose_enc_list.append(activated_pose) + + return pred_pose_enc_list + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Modulate the input tensor using scaling and shifting parameters. 
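+
+    Returns x * (1 + scale) + shift.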
+ """ + # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 + return x * (1 + scale) + shift diff --git a/mapanything/models/external/vggt/heads/dpt_head.py b/mapanything/models/external/vggt/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7c138c67c73cbadf615a96ff13bc1d7a09dac980 --- /dev/null +++ b/mapanything/models/external/vggt/heads/dpt_head.py @@ -0,0 +1,600 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Inspired by https://github.com/DepthAnything/Depth-Anything-V2 + + +from typing import List, Tuple, Union + +import torch +import torch.nn as nn + +from .head_act import activate_head +from .utils import create_uv_grid, position_grid_to_embed + + +class DPTHead(nn.Module): + """ + DPT Head for dense prediction tasks. + + This implementation follows the architecture described in "Vision Transformers for Dense Prediction" + (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer + backbone and produces dense predictions by fusing multi-scale features. + + Args: + dim_in (int): Input dimension (channels). + patch_size (int, optional): Patch size. Default is 14. + output_dim (int, optional): Number of output channels. Default is 4. + activation (str, optional): Activation type. Default is "inv_log". + conf_activation (str, optional): Confidence activation type. Default is "expp1". + features (int, optional): Feature channels for intermediate representations. Default is 256. + out_channels (List[int], optional): Output channels for each intermediate layer. + intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. + pos_embed (bool, optional): Whether to use positional embedding. Default is True. + feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. + down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. + """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "inv_log", + conf_activation: str = "expp1", + features: int = 256, + out_channels: List[int] = [256, 512, 1024, 1024], + intermediate_layer_idx: List[int] = [4, 11, 17, 23], + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + ) -> None: + super(DPTHead, self).__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.feature_only = feature_only + self.down_ratio = down_ratio + self.intermediate_layer_idx = intermediate_layer_idx + + self.norm = nn.LayerNorm(dim_in) + + # Projection layers for each output channel from tokens. + self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=dim_in, + out_channels=oc, + kernel_size=1, + stride=1, + padding=0, + ) + for oc in out_channels + ] + ) + + # Resize layers for upsampling feature maps. 
+ self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0, + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1, + ), + ] + ) + + self.scratch = _make_scratch( + out_channels, + features, + expand=False, + ) + + # Attach additional modules to scratch. + self.scratch.stem_transpose = None + self.scratch.refinenet1 = _make_fusion_block(features) + self.scratch.refinenet2 = _make_fusion_block(features) + self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + head_features_1 = features + head_features_2 = 32 + + if feature_only: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1, kernel_size=3, stride=1, padding=1 + ) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, + head_features_1 // 2, + kernel_size=3, + stride=1, + padding=1, + ) + conv2_in_channels = head_features_1 // 2 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d( + conv2_in_channels, + head_features_2, + kernel_size=3, + stride=1, + padding=1, + ), + nn.ReLU(inplace=True), + nn.Conv2d( + head_features_2, output_dim, kernel_size=1, stride=1, padding=0 + ), + ) + + def forward( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_chunk_size: int = 8, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass through the DPT head, supports processing by chunking frames. + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + patch_start_idx (int): Starting index for patch tokens in the token sequence. + Used to separate patch tokens from other tokens (e.g., camera or register tokens). + frames_chunk_size (int, optional): Number of frames to process in each chunk. + If None or larger than S, all frames are processed at once. Default: 8. 
+ + Returns: + Tensor or Tuple[Tensor, Tensor]: + - If feature_only=True: Feature maps with shape [B, S, C, H, W] + - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] + """ + B, S, _, H, W = images.shape + + # If frames_chunk_size is not specified or greater than S, process all frames at once + if frames_chunk_size is None or frames_chunk_size >= S: + return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) + + # Otherwise, process frames in chunks to manage memory usage + assert frames_chunk_size > 0 + + # Process frames in batches + all_preds = [] + all_conf = [] + + for frames_start_idx in range(0, S, frames_chunk_size): + frames_end_idx = min(frames_start_idx + frames_chunk_size, S) + + # Process batch of frames + if self.feature_only: + chunk_output = self._forward_impl( + aggregated_tokens_list, + images, + patch_start_idx, + frames_start_idx, + frames_end_idx, + ) + all_preds.append(chunk_output) + else: + chunk_preds, chunk_conf = self._forward_impl( + aggregated_tokens_list, + images, + patch_start_idx, + frames_start_idx, + frames_end_idx, + ) + all_preds.append(chunk_preds) + all_conf.append(chunk_conf) + + # Concatenate results along the sequence dimension + if self.feature_only: + return torch.cat(all_preds, dim=1) + else: + return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) + + def _forward_impl( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_start_idx: int = None, + frames_end_idx: int = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Implementation of the forward pass through the DPT head. + + This method processes a specific chunk of frames from the sequence. + + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W]. + patch_start_idx (int): Starting index for patch tokens. + frames_start_idx (int, optional): Starting index for frames to process. + frames_end_idx (int, optional): Ending index for frames to process. + + Returns: + Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). + """ + if frames_start_idx is not None and frames_end_idx is not None: + images = images[:, frames_start_idx:frames_end_idx].contiguous() + + B, S, _, H, W = images.shape + + patch_h, patch_w = H // self.patch_size, W // self.patch_size + + out = [] + dpt_idx = 0 + + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + + # Select frames if processing a chunk + if frames_start_idx is not None and frames_end_idx is not None: + x = x[:, frames_start_idx:frames_end_idx] + + x = x.reshape(B * S, -1, x.shape[-1]) + + x = self.norm(x) + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[dpt_idx](x) + if self.pos_embed: + x = self._apply_pos_embed(x, W, H) + x = self.resize_layers[dpt_idx](x) + + out.append(x) + dpt_idx += 1 + + # Fuse features from multiple layers. + out = self.scratch_forward(out) + # Interpolate fused output to match target image resolution. 
+ out = custom_interpolate( + out, + ( + int(patch_h * self.patch_size / self.down_ratio), + int(patch_w * self.patch_size / self.down_ratio), + ), + mode="bilinear", + align_corners=True, + ) + + if self.pos_embed: + out = self._apply_pos_embed(out, W, H) + + if self.feature_only: + return out.view(B, S, *out.shape[1:]) + + out = self.scratch.output_conv2(out) + preds, conf = activate_head( + out, activation=self.activation, conf_activation=self.conf_activation + ) + + preds = preds.view(B, S, *preds.shape[1:]) + conf = conf.view(B, S, *conf.shape[1:]) + return preds, conf + + def _apply_pos_embed( + self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1 + ) -> torch.Tensor: + """ + Apply positional embedding to tensor x. + """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid( + patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device + ) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: + """ + Forward pass through the fusion blocks. + + Args: + features (List[Tensor]): List of feature maps from different layers. + + Returns: + Tensor: Fused feature map. + """ + layer_1, layer_2, layer_3, layer_4 = features + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + del layer_4_rn, layer_4 + + out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) + del layer_3_rn, layer_3 + + out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) + del layer_2_rn, layer_2 + + out = self.scratch.refinenet1(out, layer_1_rn) + del layer_1_rn, layer_1 + + out = self.scratch.output_conv1(out) + return out + + +################################################################################ +# Modules +################################################################################ + + +def _make_fusion_block( + features: int, size: int = None, has_residual: bool = True, groups: int = 1 +) -> nn.Module: + return FeatureFusionBlock( + features, + nn.ReLU(inplace=True), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch( + in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False +) -> nn.Module: + scratch = nn.Module() + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + return scratch + + +class 
ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=self.groups, + ) + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=self.groups, + ) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=self.groups, + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit( + features, activation, bn, groups=self.groups + ) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit( + features, activation, bn, groups=self.groups + ) + + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. 
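+
+    If the requested output would exceed the INT_MAX threshold used below, the
+    input is split into chunks along the batch dimension, each chunk is
+    interpolated separately, and the results are concatenated back together.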
+ """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate( + chunk, size=size, mode=mode, align_corners=align_corners + ) + for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate( + x, size=size, mode=mode, align_corners=align_corners + ) diff --git a/mapanything/models/external/vggt/heads/head_act.py b/mapanything/models/external/vggt/heads/head_act.py new file mode 100644 index 0000000000000000000000000000000000000000..acf073f9a59d5901422a342ea61372591f90fe57 --- /dev/null +++ b/mapanything/models/external/vggt/heads/head_act.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + + +def activate_pose( + pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear" +): + """ + Activate pose parameters with specified activation functions. + + Args: + pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] + trans_act: Activation type for translation component + quat_act: Activation type for quaternion component + fl_act: Activation type for focal length component + + Returns: + Activated pose parameters tensor + """ + T = pred_pose_enc[..., :3] + quat = pred_pose_enc[..., 3:7] + fl = pred_pose_enc[..., 7:] # or fov + + T = base_pose_act(T, trans_act) + quat = base_pose_act(quat, quat_act) + fl = base_pose_act(fl, fl_act) # or fov + + pred_pose_enc = torch.cat([T, quat, fl], dim=-1) + + return pred_pose_enc + + +def base_pose_act(pose_enc, act_type="linear"): + """ + Apply basic activation function to pose parameters. + + Args: + pose_enc: Tensor containing encoded pose parameters + act_type: Activation type ("linear", "inv_log", "exp", "relu") + + Returns: + Activated pose parameters + """ + if act_type == "linear": + return pose_enc + elif act_type == "inv_log": + return inverse_log_transform(pose_enc) + elif act_type == "exp": + return torch.exp(pose_enc) + elif act_type == "relu": + return F.relu(pose_enc) + else: + raise ValueError(f"Unknown act_type: {act_type}") + + +def activate_head(out, activation="norm_exp", conf_activation="expp1"): + """ + Process network output to extract 3D points and confidence values. 
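+
+    Supported point activations: "norm_exp", "norm", "exp", "relu", "inv_log",
+    "xy_inv_log", "sigmoid" and "linear". Supported confidence activations:
+    "expp1", "expp0" and "sigmoid".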
+ + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor, confidence tensor) + """ + # Move channels from last dim to the 4th dimension => (B, H, W, C) + fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected + + # Split into xyz (first C-1 channels) and confidence (last channel) + xyz = fmap[:, :, :, :-1] + conf = fmap[:, :, :, -1] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + xyz_normed = xyz / d + pts3d = xyz_normed * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "inv_log": + pts3d = inverse_log_transform(xyz) + elif activation == "xy_inv_log": + xy, z = xyz.split([2, 1], dim=-1) + z = inverse_log_transform(z) + pts3d = torch.cat([xy * z, z], dim=-1) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + else: + raise ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +def inverse_log_transform(y): + """ + Apply inverse log transform: sign(y) * (exp(|y|) - 1) + + Args: + y: Input tensor + + Returns: + Transformed tensor + """ + return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/mapanything/models/external/vggt/heads/track_head.py b/mapanything/models/external/vggt/heads/track_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3464b871a7b7ad3976daf51f4d08022bd6d71ac8 --- /dev/null +++ b/mapanything/models/external/vggt/heads/track_head.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + +from .dpt_head import DPTHead +from .track_modules.base_track_predictor import BaseTrackerPredictor + + +class TrackHead(nn.Module): + """ + Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking. + The tracking is performed iteratively, refining predictions over multiple iterations. + """ + + def __init__( + self, + dim_in, + patch_size=14, + features=128, + iters=4, + predict_conf=True, + stride=2, + corr_levels=7, + corr_radius=4, + hidden_size=384, + ): + """ + Initialize the TrackHead module. + + Args: + dim_in (int): Input dimension of tokens from the backbone. + patch_size (int): Size of image patches used in the vision transformer. + features (int): Number of feature channels in the feature extractor output. + iters (int): Number of refinement iterations for tracking predictions. + predict_conf (bool): Whether to predict confidence scores for tracked points. + stride (int): Stride value for the tracker predictor. + corr_levels (int): Number of correlation pyramid levels + corr_radius (int): Radius for correlation computation, controlling the search area. + hidden_size (int): Size of hidden layers in the tracker network. 
+ """ + super().__init__() + + self.patch_size = patch_size + + # Feature extractor based on DPT architecture + # Processes tokens into feature maps for tracking + self.feature_extractor = DPTHead( + dim_in=dim_in, + patch_size=patch_size, + features=features, + feature_only=True, # Only output features, no activation + down_ratio=2, # Reduces spatial dimensions by factor of 2 + pos_embed=False, + ) + + # Tracker module that predicts point trajectories + # Takes feature maps and predicts coordinates and visibility + self.tracker = BaseTrackerPredictor( + latent_dim=features, # Match the output_dim of feature extractor + predict_conf=predict_conf, + stride=stride, + corr_levels=corr_levels, + corr_radius=corr_radius, + hidden_size=hidden_size, + ) + + self.iters = iters + + def forward( + self, + aggregated_tokens_list, + images, + patch_start_idx, + query_points=None, + iters=None, + ): + """ + Forward pass of the TrackHead. + + Args: + aggregated_tokens_list (list): List of aggregated tokens from the backbone. + images (torch.Tensor): Input images of shape (B, S, C, H, W) where: + B = batch size, S = sequence length. + patch_start_idx (int): Starting index for patch tokens. + query_points (torch.Tensor, optional): Initial query points to track. + If None, points are initialized by the tracker. + iters (int, optional): Number of refinement iterations. If None, uses self.iters. + + Returns: + tuple: + - coord_preds (torch.Tensor): Predicted coordinates for tracked points. + - vis_scores (torch.Tensor): Visibility scores for tracked points. + - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True). + """ + B, S, _, H, W = images.shape + + # Extract features from tokens + # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2 + feature_maps = self.feature_extractor( + aggregated_tokens_list, images, patch_start_idx + ) + + # Use default iterations if not specified + if iters is None: + iters = self.iters + + # Perform tracking using the extracted features + coord_preds, vis_scores, conf_scores = self.tracker( + query_points=query_points, + fmaps=feature_maps, + iters=iters, + ) + + return coord_preds, vis_scores, conf_scores diff --git a/mapanything/models/external/vggt/heads/track_modules/__init__.py b/mapanything/models/external/vggt/heads/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4196294309799347172dba54a17360698071ca8 --- /dev/null +++ b/mapanything/models/external/vggt/heads/track_modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py b/mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..204a82ac099c7cda89c2cc66e4290b2e3da33542 --- /dev/null +++ b/mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.nn as nn +from einops import rearrange + +from .blocks import CorrBlock, EfficientUpdateFormer +from .modules import Mlp +from .utils import get_2d_embedding, get_2d_sincos_pos_embed, sample_features4d + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=1, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + max_scale=518, + predict_conf=True, + ): + super(BaseTrackerPredictor, self).__init__() + """ + The base template to create a track predictor + + Modified from https://github.com/facebookresearch/co-tracker/ + and https://github.com/facebookresearch/vggsfm + """ + + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.max_scale = max_scale + self.predict_conf = predict_conf + + self.flows_emb_dim = latent_dim // 2 + + self.corr_mlp = Mlp( + in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2, + hidden_features=self.hidden_size, + out_features=self.latent_dim, + ) + + self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4 + + self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim)) + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.fmap_norm = nn.LayerNorm(self.latent_dim) + self.ffeat_norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential( + nn.Linear(self.latent_dim, self.latent_dim), nn.GELU() + ) + + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + if predict_conf: + self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward( + self, + query_points, + fmaps=None, + iters=6, + return_feat=False, + down_ratio=1, + apply_sigmoid=True, + ): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. 
+ note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2, "Input points must be 2D coordinates" + + # apply a layernorm to fmaps here + fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)) + fmaps = fmaps.permute(0, 1, 4, 2, 3) + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + fcorr_fn = CorrBlock( + fmaps, num_levels=self.corr_levels, radius=self.corr_radius + ) + + coord_preds = [] + + # Iterative Refinement + for _ in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + fcorrs = fcorr_fn.corr_sample(track_feats, coords) + + corr_dim = fcorrs.shape[3] + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim) + fcorrs_ = self.corr_mlp(fcorrs_) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = torch.cat( + [flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1 + ) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape( + B * N, S, self.latent_dim + ) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + # 2D positional embed + # TODO: this can be much simplified + pos_embed = get_2d_sincos_pos_embed( + self.transformer_dim, grid_size=(HH, WW) + ).to(query_points.device) + sampled_pos_emb = sample_features4d( + pos_embed.expand(B, -1, -1, -1), coords[:, 0] + ) + + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze( + 1 + ) + + x = transformer_input + sampled_pos_emb + + # Add the query ref token to the track feats + query_ref_token = torch.cat( + [ + self.query_ref_token[:, 0:1], + self.query_ref_token[:, 1:2].expand(-1, S - 1, -1), + ], + dim=1, + ) + x = x + query_ref_token.to(x.device).to(x.dtype) + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta, _ = self.updateformer(x) + + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = ( + self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_ + ) + + track_feats = track_feats_.reshape(B, N, S, 
self.latent_dim).permute( + 0, 2, 1, 3 + ) # BxSxNxC + + # B x S x N x 2 + coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + vis_e = self.vis_predictor( + track_feats.reshape(B * S * N, self.latent_dim) + ).reshape(B, S, N) + if apply_sigmoid: + vis_e = torch.sigmoid(vis_e) + + if self.predict_conf: + conf_e = self.conf_predictor( + track_feats.reshape(B * S * N, self.latent_dim) + ).reshape(B, S, N) + if apply_sigmoid: + conf_e = torch.sigmoid(conf_e) + else: + conf_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat, conf_e + else: + return coord_preds, vis_e, conf_e diff --git a/mapanything/models/external/vggt/heads/track_modules/blocks.py b/mapanything/models/external/vggt/heads/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..4285ce15bac49153ead21fb9ab5ad50c5bf555ac --- /dev/null +++ b/mapanything/models/external/vggt/heads/track_modules/blocks.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Modified from https://github.com/facebookresearch/co-tracker/ + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modules import AttnBlock, CrossAttnBlock +from .utils import bilinear_sampler + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. 
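+
+    Attention is applied along the time dimension for every track and, when
+    add_space_attn is enabled, a set of learnable virtual tracks additionally
+    exchanges information with the point tracks via cross-attention blocks
+    applied along the track dimension.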
+ """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + + # Add input LayerNorm before linear projection + self.input_norm = nn.LayerNorm(input_dim) + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + + # Add output LayerNorm before final projection + self.output_norm = nn.LayerNorm(hidden_size) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter( + torch.randn(1, num_virtual_tracks, 1, hidden_size) + ) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_class=nn.MultiheadAttention, + ) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_class=nn.MultiheadAttention, + ) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [ + CrossAttnBlock( + hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio + ) + for _ in range(space_depth) + ] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [ + CrossAttnBlock( + hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio + ) + for _ in range(space_depth) + ] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) + + self.apply(_basic_init) + + def forward(self, input_tensor, mask=None): + # Apply input LayerNorm + input_tensor = self.input_norm(input_tensor) + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and ( + i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0 + ): + space_tokens = ( + tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) + ) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j]( + virtual_tokens, point_tokens, mask=mask + ) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j]( + point_tokens, virtual_tokens, mask=mask + ) + + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute( + 0, 2, 1, 3 + ) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = 
tokens + init_tokens + + # Apply output LayerNorm before final projection + tokens = self.output_norm(tokens) + flow = self.flow_head(tokens) + + return flow, None + + +class CorrBlock: + def __init__( + self, + fmaps, + num_levels=4, + radius=4, + multiple_track_feats=False, + padding_mode="zeros", + ): + """ + Build a pyramid of feature maps from the input. + + fmaps: Tensor (B, S, C, H, W) + num_levels: number of pyramid levels (each downsampled by factor 2) + radius: search radius for sampling correlation + multiple_track_feats: if True, split the target features per pyramid level + padding_mode: passed to grid_sample / bilinear_sampler + """ + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.num_levels = num_levels + self.radius = radius + self.padding_mode = padding_mode + self.multiple_track_feats = multiple_track_feats + + # Build pyramid: each level is half the spatial resolution of the previous + self.fmaps_pyramid = [fmaps] # level 0 is full resolution + current_fmaps = fmaps + for i in range(num_levels - 1): + B, S, C, H, W = current_fmaps.shape + # Merge batch & sequence dimensions + current_fmaps = current_fmaps.reshape(B * S, C, H, W) + # Avg pool down by factor 2 + current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2) + _, _, H_new, W_new = current_fmaps.shape + current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new) + self.fmaps_pyramid.append(current_fmaps) + + # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling. + # This grid is added to the (scaled) coordinate centroids. + r = self.radius + dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + # delta: for every (dy,dx) displacement (i.e. Δx, Δy) + self.delta = torch.stack( + torch.meshgrid(dy, dx, indexing="ij"), dim=-1 + ) # shape: (2r+1, 2r+1, 2) + + def corr_sample(self, targets, coords): + """ + Instead of storing the entire correlation pyramid, we compute each level's correlation + volume, sample it immediately, then discard it. This saves GPU memory. + + Args: + targets: Tensor (B, S, N, C) — features for the current targets. + coords: Tensor (B, S, N, 2) — coordinates at full resolution. + + Returns: + Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations) + """ + B, S, N, C = targets.shape + + # If you have multiple track features, split them per level. + if self.multiple_track_feats: + targets_split = torch.split(targets, C // self.num_levels, dim=-1) + + out_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + # Get current spatial resolution H, W for this pyramid level. + B, S, C, H, W = fmaps.shape + # Reshape feature maps for correlation computation: + # fmap2s: (B, S, C, H*W) + fmap2s = fmaps.view(B, S, C, H * W) + # Choose appropriate target features. + fmap1 = ( + targets_split[i] if self.multiple_track_feats else targets + ) # shape: (B, S, N, C) + + # Compute correlation directly + corrs = compute_corr_level(fmap1, fmap2s, C) + corrs = corrs.view(B, S, N, H, W) + + # Prepare sampling grid: + # Scale down the coordinates for the current level. + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i) + # Make sure our precomputed delta grid is on the same device/dtype. 
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype) + # Now the grid for grid_sample is: + # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid) + coords_lvl = centroid_lvl + delta_lvl.view( + 1, 2 * self.radius + 1, 2 * self.radius + 1, 2 + ) + + # Sample from the correlation volume using bilinear interpolation. + # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target. + corrs_sampled = bilinear_sampler( + corrs.reshape(B * S * N, 1, H, W), + coords_lvl, + padding_mode=self.padding_mode, + ) + # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims. + corrs_sampled = corrs_sampled.view( + B, S, N, -1 + ) # Now shape: (B, S, N, (2r+1)^2) + out_pyramid.append(corrs_sampled) + + # Concatenate all levels along the last dimension. + out = torch.cat(out_pyramid, dim=-1).contiguous() + return out + + +def compute_corr_level(fmap1, fmap2s, C): + # fmap1: (B, S, N, C) + # fmap2s: (B, S, C, H*W) + corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W) + corrs = corrs.view( + fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1 + ) # (B, S, N, H*W) + return corrs / math.sqrt(C) diff --git a/mapanything/models/external/vggt/heads/track_modules/modules.py b/mapanything/models/external/vggt/heads/track_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..dbfb3825f8bafc8e33e2d33c3d3b390186d59c86 --- /dev/null +++ b/mapanything/models/external/vggt/heads/track_modules/modules.py @@ -0,0 +1,220 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import collections +from functools import partial +from itertools import repeat +from typing import Callable + +import torch.nn as nn + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, + planes, + kernel_size=kernel_size, + padding=1, + stride=stride, + padding_mode="zeros", + ) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=kernel_size, + padding=1, + padding_mode="zeros", + ) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: 
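+            # With stride 1 the shortcut is an identity; otherwise a strided 1x1
+            # convolution (plus normalization) matches the shortcut to the
+            # spatial size and channel count of the main branch.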
+ self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), + self.norm3, + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs, + ): + """ + Self attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.attn = attn_class( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # Prepare the mask for PyTorch's attention (it expects a different format) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__( + self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs + ): + """ + Cross attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git a/mapanything/models/external/vggt/heads/track_modules/utils.py b/mapanything/models/external/vggt/heads/track_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcb15170d6a1639004ed4a7bda8c48e9430ef1b --- /dev/null +++ 
b/mapanything/models/external/vggt/heads/track_modules/utils.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from https://github.com/facebookresearch/vggsfm +# and https://github.com/facebookresearch/co-tracker/tree/main + + +from typing import Tuple, Union + +import torch +import torch.nn.functional as F + + +def get_2d_sincos_pos_embed( + embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False +) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return ( + pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), + grid, + ) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: torch.Tensor +) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: torch.Tensor +) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. + + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. 
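As a quick sanity check on the shapes these sincos helpers produce, here is a minimal sketch; it assumes the repository root is on the Python path so the module added in this diff is importable.

import torch

from mapanything.models.external.vggt.heads.track_modules.utils import (
    get_1d_sincos_pos_embed_from_grid,
    get_2d_sincos_pos_embed,
)

# 1D case: 5 scalar positions embedded into 8 channels (4 sine + 4 cosine bands).
pos = torch.arange(5, dtype=torch.float)
emb_1d = get_1d_sincos_pos_embed_from_grid(8, pos)
assert emb_1d.shape == (1, 5, 8)

# 2D case: a 7x9 patch grid embedded into 64 channels, returned as a (1, C, H, W) map.
emb_2d = get_2d_sincos_pos_embed(64, (7, 9))
assert emb_2d.shape == (1, 64, 7, 9)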
+ """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = ( + torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) + ).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. 
+ """ + coords = coords.detach().clone() + ############################################################ + # IMPORTANT: + coords = coords.to(input.device).to(input.dtype) + ############################################################ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + scale = torch.tensor( + [2 / max(size - 1, 1) for size in reversed(sizes)], + device=coords.device, + dtype=coords.dtype, + ) + else: + scale = torch.tensor( + [2 / size for size in reversed(sizes)], + device=coords.device, + dtype=coords.dtype, + ) + + coords.mul_(scale) # coords = coords * scale + coords.sub_(1) # coords = coords - 1 + + return F.grid_sample( + input, coords, align_corners=align_corners, padding_mode=padding_mode + ) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view( + B, -1, feats.shape[1] * feats.shape[3] + ) # B C R 1 -> B R C diff --git a/mapanything/models/external/vggt/heads/utils.py b/mapanything/models/external/vggt/heads/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c23a6e137f69c262a3e24f438b849c901d21158a --- /dev/null +++ b/mapanything/models/external/vggt/heads/utils.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def position_grid_to_embed( + pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100 +) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed( + embed_dim // 2, pos_flat[:, 0], omega_0=omega_0 + ) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed( + embed_dim // 2, pos_flat[:, 1], omega_0=omega_0 + ) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed( + embed_dim: int, pos: torch.Tensor, omega_0: float = 100 +) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. 
+ - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + device = pos.device + omega = torch.arange( + embed_dim // 2, + dtype=torch.float32 if device.type == "mps" else torch.double, + device=device, + ) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +# Inspired by https://github.com/microsoft/moge + + +def create_uv_grid( + width: int, + height: int, + aspect_ratio: float = None, + dtype: torch.dtype = None, + device: torch.device = None, +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. + """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid diff --git a/mapanything/models/external/vggt/layers/__init__.py b/mapanything/models/external/vggt/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b84d2d861eb72cdb1ebc5e7c001938df0b851486 --- /dev/null +++ b/mapanything/models/external/vggt/layers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed + +__all__ = [ + "Mlp", + "PatchEmbed", +] diff --git a/mapanything/models/external/vggt/layers/attention.py b/mapanything/models/external/vggt/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb536194c0bf49b744ca9c819655bf9127ad0ed --- /dev/null +++ b/mapanything/models/external/vggt/layers/attention.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
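The two head utilities above compose naturally: create_uv_grid yields normalized 2D coordinates and position_grid_to_embed turns them into sinusoidal features. The pairing below is purely illustrative (how the heads actually consume them is not shown here), and a square grid is used so the width/height ordering of the output does not matter.

import torch

from mapanything.models.external.vggt.heads.utils import create_uv_grid, position_grid_to_embed

N, C = 8, 32
uv = create_uv_grid(N, N, dtype=torch.float32)   # (8, 8, 2), spans normalized by the plane diagonal
emb = position_grid_to_embed(uv, C)              # (8, 8, 32) sinusoidal embedding of those coordinates
assert uv.shape == (N, N, 2) and emb.shape == (N, N, C)

# Corner coordinates sit at +/- span * (n - 1) / n along each axis.
print(uv[0, 0], uv[-1, -1])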
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + + +import torch.nn.functional as F +from torch import nn, Tensor + +XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.rope is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +# class MemEffAttention(Attention): +# def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: +# assert pos is None +# if not XFORMERS_AVAILABLE: +# if attn_bias is not None: +# raise AssertionError("xFormers is required for using nested tensors") +# return super().forward(x) + +# B, N, C = x.shape +# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + +# q, k, v = unbind(qkv, 2) + +# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) +# x = x.reshape([B, N, C]) + +# x = self.proj(x) +# x = self.proj_drop(x) +# return x diff --git a/mapanything/models/external/vggt/layers/block.py b/mapanything/models/external/vggt/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6d562dfa009034081062216d4d98b0662c024e --- /dev/null +++ b/mapanything/models/external/vggt/layers/block.py @@ -0,0 +1,280 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
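The fused and manual code paths in the Attention module above should be numerically equivalent at inference time (dropout disabled). A small sketch, assuming the mapanything package is importable and a PyTorch version that provides scaled_dot_product_attention:

import torch

from mapanything.models.external.vggt.layers.attention import Attention

torch.manual_seed(0)
x = torch.randn(2, 10, 64)

attn_fused = Attention(dim=64, num_heads=8, fused_attn=True).eval()
attn_manual = Attention(dim=64, num_heads=8, fused_attn=False).eval()
attn_manual.load_state_dict(attn_fused.state_dict())  # share weights so outputs are comparable

with torch.no_grad():
    y_fused = attn_fused(x)
    y_manual = attn_manual(x)

assert y_fused.shape == (2, 10, 64)
print(torch.allclose(y_fused, y_manual, atol=1e-5))  # should print True (no dropout in eval mode)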
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Any, Callable, Dict, List, Tuple + +import torch +from torch import nn, Tensor + +from .attention import Attention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + +XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + fused_attn=fused_attn, + rope=rope, + ) + + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, pos=None) -> Tensor: + def attn_residual_func(x: Tensor, pos=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + pos=pos, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, pos=pos)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, pos=pos) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, + pos=None, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + if pos is not None: + # if necessary, apply rope to the subset + pos = pos[brange] + residual = residual_func(x_subset, pos=pos) + else: + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add 
the residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + else: + pass + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + # attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + # attn_bias._batch_sizes = batch_sizes + # attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + pass + # cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + # 1, -1, x_list[0].shape[-1] + # ) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + # assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma + if isinstance(self.ls1, LayerScale) + else None, 
+ ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma + if isinstance(self.ls1, LayerScale) + else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/mapanything/models/external/vggt/layers/drop_path.py b/mapanything/models/external/vggt/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..04cb47af065ec3cfdbc8f59854efaa976ea717d5 --- /dev/null +++ b/mapanything/models/external/vggt/layers/drop_path.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/mapanything/models/external/vggt/layers/layer_scale.py b/mapanything/models/external/vggt/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..b32da3bd74a028795f5ee628d2636a45b756b074 --- /dev/null +++ b/mapanything/models/external/vggt/layers/layer_scale.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
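DropPath above implements stochastic depth at the sample level: a no-op in eval mode, and per-sample zeroing rescaled by 1/keep_prob in train mode so the expected value is unchanged. A minimal sketch, assuming the mapanything package is importable:

import torch

from mapanything.models.external.vggt.layers.drop_path import DropPath

torch.manual_seed(0)
x = torch.ones(8, 4, 16)        # (batch, tokens, channels)

dp = DropPath(drop_prob=0.5)

dp.eval()
assert torch.equal(dp(x), x)    # identity at inference time

dp.train()
y = dp(x)
# Each sample is either all zeros or all 1/keep_prob = 2.0, so the mean stays ~1 in expectation.
print(y[:, 0, 0])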
+ +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import nn, Tensor + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/mapanything/models/external/vggt/layers/mlp.py b/mapanything/models/external/vggt/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6d19f53d8562d07d559fe6db93d45b8286420720 --- /dev/null +++ b/mapanything/models/external/vggt/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import nn, Tensor + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/mapanything/models/external/vggt/layers/patch_embed.py b/mapanything/models/external/vggt/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..493774d038c9ee7f0f63b05f80561fa61321e2b7 --- /dev/null +++ b/mapanything/models/external/vggt/layers/patch_embed.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +import torch.nn as nn +from torch import Tensor + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
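LayerScale is a learnable per-channel gain applied to the residual branch, initialized to a small value (1e-5 by default) so each block starts close to the identity. A minimal sketch:

import torch

from mapanything.models.external.vggt.layers.layer_scale import LayerScale

ls = LayerScale(dim=16, init_values=1e-5)
x = torch.randn(2, 10, 16)
y = ls(x)

assert y.shape == x.shape
print(y.abs().max() / x.abs().max())   # ~1e-5: the branch contributes almost nothing at init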
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, ( + f"Input image height {H} is not a multiple of patch height {patch_H}" + ) + assert W % patch_W == 0, ( + f"Input image width {W} is not a multiple of patch width: {patch_W}" + ) + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/mapanything/models/external/vggt/layers/rope.py b/mapanything/models/external/vggt/layers/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd3594bfd736752d5add0edd1986810578b253f --- /dev/null +++ b/mapanything/models/external/vggt/layers/rope.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +# Implementation of 2D Rotary Position Embeddings (RoPE). + +# This module provides a clean implementation of 2D Rotary Position Embeddings, +# which extends the original RoPE concept to handle 2D spatial positions. + +# Inspired by: +# https://github.com/meta-llama/codellama/blob/main/llama/model.py +# https://github.com/naver-ai/rope-vit + + +from typing import Dict, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PositionGetter: + """Generates and caches 2D spatial positions for patches in a grid. + + This class efficiently manages the generation of spatial coordinates for patches + in a 2D grid, caching results to avoid redundant computations. + + Attributes: + position_cache: Dictionary storing precomputed position tensors for different + grid dimensions. + """ + + def __init__(self): + """Initializes the position generator with an empty cache.""" + self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} + + def __call__( + self, batch_size: int, height: int, width: int, device: torch.device + ) -> torch.Tensor: + """Generates spatial positions for a batch of patches. + + Args: + batch_size: Number of samples in the batch. + height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. 
+ + Returns: + Tensor of shape (batch_size, height*width, 2) containing y,x coordinates + for each position in the grid, repeated for each batch item. + """ + if (height, width) not in self.position_cache: + y_coords = torch.arange(height, device=device) + x_coords = torch.arange(width, device=device) + positions = torch.cartesian_prod(y_coords, x_coords) + self.position_cache[height, width] = positions + + cached_positions = self.position_cache[height, width] + return ( + cached_positions.view(1, height * width, 2) + .expand(batch_size, -1, -1) + .clone() + ) + + +class RotaryPositionEmbedding2D(nn.Module): + """2D Rotary Position Embedding implementation. + + This module applies rotary position embeddings to input tokens based on their + 2D spatial positions. It handles the position-dependent rotation of features + separately for vertical and horizontal dimensions. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. + frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 2D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + # Compute frequency bands + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency**exponents) + + # Generate position-dependent frequencies + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + # Compute and cache frequency components + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. + """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, + tokens: torch.Tensor, + positions: torch.Tensor, + cos_comp: torch.Tensor, + sin_comp: torch.Tensor, + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. + cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. 
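PositionGetter above produces the integer (y, x) patch coordinates that the rotary embedding later consumes. A minimal sketch for a 2x3 patch grid:

import torch

from mapanything.models.external.vggt.layers.rope import PositionGetter

get_positions = PositionGetter()
pos = get_positions(batch_size=2, height=2, width=3, device=torch.device("cpu"))

assert pos.shape == (2, 6, 2)   # (batch, height*width, 2)
print(pos[0])
# tensor([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]])  -- row-major (y, x) pairs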
+ """ + # Embed positions with frequency components + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + + # Apply rotation + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 2D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 4. + positions: Position tensor of shape (batch_size, n_tokens, 2) containing + the y and x coordinates for each token. + + Returns: + Tensor of same shape as input with applied 2D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. + """ + # Validate inputs + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert positions.ndim == 3 and positions.shape[-1] == 2, ( + "Positions must have shape (batch_size, n_tokens, 2)" + ) + + # Compute feature dimension for each spatial direction + feature_dim = tokens.size(-1) // 2 + + # Get frequency components + max_position = int(positions.max()) + 1 + cos_comp, sin_comp = self._compute_frequency_components( + feature_dim, max_position, tokens.device, tokens.dtype + ) + + # Split features for vertical and horizontal processing + vertical_features, horizontal_features = tokens.chunk(2, dim=-1) + + # Apply RoPE separately for each dimension + vertical_features = self._apply_1d_rope( + vertical_features, positions[..., 0], cos_comp, sin_comp + ) + horizontal_features = self._apply_1d_rope( + horizontal_features, positions[..., 1], cos_comp, sin_comp + ) + + # Combine processed features + return torch.cat((vertical_features, horizontal_features), dim=-1) diff --git a/mapanything/models/external/vggt/layers/swiglu_ffn.py b/mapanything/models/external/vggt/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..9195a0ed27d4f06b0c0ba97c8a9bd7cc3214642e --- /dev/null +++ b/mapanything/models/external/vggt/layers/swiglu_ffn.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import os +from typing import Callable, Optional + +import torch.nn.functional as F +from torch import nn, Tensor + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +# try: +# if XFORMERS_ENABLED: +# from xformers.ops import SwiGLU + +# XFORMERS_AVAILABLE = True +# warnings.warn("xFormers is available (SwiGLU)") +# else: +# warnings.warn("xFormers is disabled (SwiGLU)") +# raise ImportError +# except ImportError: +SwiGLU = SwiGLUFFN +XFORMERS_AVAILABLE = False + +# warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/mapanything/models/external/vggt/layers/vision_transformer.py b/mapanything/models/external/vggt/layers/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..3d611608c29e528f1c6984768874e117ff660edb --- /dev/null +++ b/mapanything/models/external/vggt/layers/vision_transformer.py @@ -0,0 +1,454 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import math +from functools import partial +from typing import Callable, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.utils.checkpoint import checkpoint + +from . 
import ( + MemEffAttention, + Mlp, + NestedTensorBlock as Block, + PatchEmbed, + SwiGLUFFNFused, +) + +logger = logging.getLogger("dinov2") + + +def named_apply( + fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + qk_norm=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + # tricky but makes it work + self.use_checkpoint = False + # + + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim) + ) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) + if num_register_tokens + else None + ) + + if 
drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append( + [nn.Identity()] * i + blocks_list[i : i + chunksize] + ) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( + previous_dtype + ) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where( + masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x + ) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + 
x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [ + self.prepare_tokens_with_masks(x, masks) + for x, masks in zip(x_list, masks_list) + ] + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), ( + f"only {len(output)} / {len(blocks_to_take)} blocks found" + ) + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), ( + f"only {len(output)} / {len(blocks_to_take)} blocks found" + ) + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=True, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/mapanything/models/external/vggt/models/__init__.py b/mapanything/models/external/vggt/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/models/external/vggt/models/aggregator.py b/mapanything/models/external/vggt/models/aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..e082d34dc97f1dc58e21969de277086a218acdd3 --- 
/dev/null +++ b/mapanything/models/external/vggt/models/aggregator.py @@ -0,0 +1,385 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import List, Tuple + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from mapanything.models.external.vggt.layers import PatchEmbed +from mapanything.models.external.vggt.layers.block import Block +from mapanything.models.external.vggt.layers.rope import ( + PositionGetter, + RotaryPositionEmbedding2D, +) + +logger = logging.getLogger(__name__) + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +class Aggregator(nn.Module): + """ + The Aggregator applies alternating-attention over input frames, + as described in VGGT: Visual Geometry Grounded Transformer. + + + Args: + img_size (int): Image size in pixels. + patch_size (int): Size of each patch for PatchEmbed. + embed_dim (int): Dimension of the token embeddings. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. + num_register_tokens (int): Number of register tokens. + block_fn (nn.Module): The block type used for attention (Block by default). + qkv_bias (bool): Whether to include bias in QKV projections. + proj_bias (bool): Whether to include bias in the output projection. + ffn_bias (bool): Whether to include bias in MLP layers. + patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg". + aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"]. + aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1. + qk_norm (bool): Whether to apply QK normalization. + rope_freq (int): Base frequency for rotary embedding. -1 to disable. + init_values (float): Init scale for layer scale. 
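+
+    Example (illustrative sketch; the call below is hypothetical and simply uses the
+    defaults documented above):
+
+        aggregator = Aggregator()               # DINOv2 ViT-L/14 patch embed by default
+        frames = torch.rand(1, 4, 3, 518, 518)  # (B, S, 3, H, W), values in [0, 1]
+        tokens_list, patch_start_idx = aggregator(frames)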
+ """ + + def __init__( + self, + img_size=518, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + num_register_tokens=4, + block_fn=Block, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + patch_embed="dinov2_vitl14_reg", + aa_order=["frame", "global"], + aa_block_size=1, + qk_norm=True, + rope_freq=100, + init_values=0.01, + ): + super().__init__() + + self.__build_patch_embed__( + patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim + ) + + # Initialize rotary position embedding if frequency > 0 + self.rope = ( + RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None + ) + self.position_getter = PositionGetter() if self.rope is not None else None + + self.frame_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.global_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.depth = depth + self.aa_order = aa_order + self.patch_size = patch_size + self.aa_block_size = aa_block_size + + # Validate that depth is divisible by aa_block_size + if self.depth % self.aa_block_size != 0: + raise ValueError( + f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})" + ) + + self.aa_block_num = self.depth // self.aa_block_size + + # Note: We have two camera tokens, one for the first frame and one for the rest + # The same applies for register tokens + self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) + self.register_token = nn.Parameter( + torch.randn(1, 2, num_register_tokens, embed_dim) + ) + + # The patch tokens start after the camera and register tokens + self.patch_start_idx = 1 + num_register_tokens + + # Initialize parameters with small values + nn.init.normal_(self.camera_token, std=1e-6) + nn.init.normal_(self.register_token, std=1e-6) + + # Register normalization constants as buffers + for name, value in ( + ("_resnet_mean", _RESNET_MEAN), + ("_resnet_std", _RESNET_STD), + ): + self.register_buffer( + name, + torch.FloatTensor(value).view(1, 1, 3, 1, 1), + persistent=False, + ) + + def __build_patch_embed__( + self, + patch_embed, + img_size, + patch_size, + num_register_tokens, + interpolate_antialias=True, + interpolate_offset=0.0, + block_chunks=0, + init_values=1.0, + embed_dim=1024, + ): + """ + Build the patch embed layer. If 'conv', we use a + simple PatchEmbed conv layer. Otherwise, we use a vision transformer. 
+ """ + + if "conv" in patch_embed: + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=3, + embed_dim=embed_dim, + ) + else: + ### From original VGGT codebase: Doesn't load pre-trained DINOv2 weights + # vit_models = { + # "dinov2_vitl14_reg": vit_large, + # "dinov2_vitb14_reg": vit_base, + # "dinov2_vits14_reg": vit_small, + # "dinov2_vitg2_reg": vit_giant2, + # } + + # self.patch_embed = vit_models[patch_embed]( + # img_size=img_size, + # patch_size=patch_size, + # num_register_tokens=num_register_tokens, + # interpolate_antialias=interpolate_antialias, + # interpolate_offset=interpolate_offset, + # block_chunks=block_chunks, + # init_values=init_values, + # ) + + ### Use pre-trained DINOv2 with gradient checkpointing + self.patch_embed = torch.hub.load("facebookresearch/dinov2", patch_embed) + for i in range(len(self.patch_embed.blocks)): + self.patch_embed.blocks[i] = ( + self.wrap_module_with_gradient_checkpointing( + self.patch_embed.blocks[i] + ) + ) + + # Disable gradient updates for mask token + if hasattr(self.patch_embed, "mask_token"): + self.patch_embed.mask_token.requires_grad_(False) + + ### Gradient Checkpointing Wrapper from UniCeption: + def wrap_module_with_gradient_checkpointing(self, module: nn.Module): + """ + Wrapper for Gradient Checkpointing + References: https://github.com/microsoft/MoGe + """ + + class _CheckpointingWrapper(module.__class__): + _restore_cls = module.__class__ + + def forward(self, *args, **kwargs): + return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) + + module.__class__ = _CheckpointingWrapper + return module + + def forward( + self, + images: torch.Tensor, + ) -> Tuple[List[torch.Tensor], int]: + """ + Args: + images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + + Returns: + (list[torch.Tensor], int): + The list of outputs from the attention blocks, + and the patch_start_idx indicating where patch tokens begin. 
+ """ + B, S, C_in, H, W = images.shape + + if C_in != 3: + raise ValueError(f"Expected 3 input channels, got {C_in}") + + # Normalize images and reshape for patch embed + images = (images - self._resnet_mean) / self._resnet_std + + # Reshape to [B*S, C, H, W] for patch embedding + images = images.view(B * S, C_in, H, W) + patch_tokens = self.patch_embed.forward_features(images) + + if isinstance(patch_tokens, dict): + patch_tokens = patch_tokens["x_norm_patchtokens"] + + _, P, C = patch_tokens.shape + + # Expand camera and register tokens to match batch size and sequence length + camera_token = slice_expand_and_flatten(self.camera_token, B, S) + register_token = slice_expand_and_flatten(self.register_token, B, S) + + # Concatenate special tokens with patch tokens + tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) + + pos = None + if self.rope is not None: + pos = self.position_getter( + B * S, H // self.patch_size, W // self.patch_size, device=images.device + ) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos = pos + 1 + pos_special = ( + torch.zeros(B * S, self.patch_start_idx, 2) + .to(images.device) + .to(pos.dtype) + ) + pos = torch.cat([pos_special, pos], dim=1) + + # update P because we added special tokens + _, P, C = tokens.shape + + frame_idx = 0 + global_idx = 0 + output_list = [] + + for _ in range(self.aa_block_num): + for attn_type in self.aa_order: + if attn_type == "frame": + tokens, frame_idx, frame_intermediates = ( + self._process_frame_attention( + tokens, B, S, P, C, frame_idx, pos=pos + ) + ) + elif attn_type == "global": + tokens, global_idx, global_intermediates = ( + self._process_global_attention( + tokens, B, S, P, C, global_idx, pos=pos + ) + ) + else: + raise ValueError(f"Unknown attention type: {attn_type}") + + for i in range(len(frame_intermediates)): + # concat frame and global intermediates, [B x S x P x 2C] + concat_inter = torch.cat( + [frame_intermediates[i], global_intermediates[i]], dim=-1 + ) + output_list.append(concat_inter) + + del concat_inter + del frame_intermediates + del global_intermediates + return output_list, self.patch_start_idx + + def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None): + """ + Process frame attention blocks. We keep tokens in shape (B*S, P, C). + """ + # If needed, reshape tokens or positions: + if tokens.shape != (B * S, P, C): + tokens = tokens.view(B, S, P, C).view(B * S, P, C) + + if pos is not None and pos.shape != (B * S, P, 2): + pos = pos.view(B, S, P, 2).view(B * S, P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + tokens = self.frame_blocks[frame_idx](tokens, pos=pos) + frame_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, frame_idx, intermediates + + def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None): + """ + Process global attention blocks. We keep tokens in shape (B, S*P, C). 
+ """ + if tokens.shape != (B, S * P, C): + tokens = tokens.view(B, S, P, C).view(B, S * P, C) + + if pos is not None and pos.shape != (B, S * P, 2): + pos = pos.view(B, S, P, 2).view(B, S * P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + tokens = self.global_blocks[global_idx](tokens, pos=pos) + global_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, global_idx, intermediates + + +def slice_expand_and_flatten(token_tensor, B, S): + """ + Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: + 1) Uses the first position (index=0) for the first frame only + 2) Uses the second position (index=1) for all remaining frames (S-1 frames) + 3) Expands both to match batch size B + 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token + followed by (S-1) second-position tokens + 5) Flattens to (B*S, X, C) for processing + + Returns: + torch.Tensor: Processed tokens with shape (B*S, X, C) + """ + + # Slice out the "query" tokens => shape (1, 1, ...) + query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) + # Slice out the "other" tokens => shape (1, S-1, ...) + others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) + # Concatenate => shape (B, S, ...) + combined = torch.cat([query, others], dim=1) + + # Finally flatten => shape (B*S, ...) + combined = combined.view(B * S, *combined.shape[2:]) + return combined diff --git a/mapanything/models/external/vggt/models/vggt.py b/mapanything/models/external/vggt/models/vggt.py new file mode 100644 index 0000000000000000000000000000000000000000..dba4b6aeda98ed87eee6ab893346e037abf45095 --- /dev/null +++ b/mapanything/models/external/vggt/models/vggt.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin # used for model hub + +from mapanything.models.external.vggt.heads.camera_head import CameraHead +from mapanything.models.external.vggt.heads.dpt_head import DPTHead +from mapanything.models.external.vggt.heads.track_head import TrackHead +from mapanything.models.external.vggt.models.aggregator import Aggregator + + +class VGGT(nn.Module, PyTorchModelHubMixin): + def __init__( + self, + img_size=518, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + intermediate_layer_idx=[4, 11, 17, 23], + ): + super().__init__() + + self.aggregator = Aggregator( + img_size=img_size, + patch_size=patch_size, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + ) + self.camera_head = CameraHead(dim_in=2 * embed_dim) + self.point_head = DPTHead( + dim_in=2 * embed_dim, + output_dim=4, + activation="inv_log", + conf_activation="expp1", + intermediate_layer_idx=intermediate_layer_idx, + ) + self.depth_head = DPTHead( + dim_in=2 * embed_dim, + output_dim=2, + activation="exp", + conf_activation="expp1", + intermediate_layer_idx=intermediate_layer_idx, + ) + self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) + + def forward( + self, + images: torch.Tensor, + query_points: torch.Tensor = None, + ): + """ + Forward pass of the VGGT model. + + Args: + images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]. 
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates. + Shape: [N, 2] or [B, N, 2], where N is the number of query points. + Default: None + + Returns: + dict: A dictionary containing the following predictions: + - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration) + - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1] + - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W] + - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3] + - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W] + - images (torch.Tensor): Original input images, preserved for visualization + + If query_points is provided, also includes: + - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates + - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N] + - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N] + """ + + # If without batch dimension, add it + if len(images.shape) == 4: + images = images.unsqueeze(0) + if query_points is not None and len(query_points.shape) == 2: + query_points = query_points.unsqueeze(0) + + aggregated_tokens_list, patch_start_idx = self.aggregator(images) + + predictions = {} + + with torch.cuda.amp.autocast(enabled=False): + if self.camera_head is not None: + pose_enc_list = self.camera_head(aggregated_tokens_list) + predictions["pose_enc"] = pose_enc_list[ + -1 + ] # pose encoding of the last iteration + + if self.depth_head is not None: + depth, depth_conf = self.depth_head( + aggregated_tokens_list, + images=images, + patch_start_idx=patch_start_idx, + ) + predictions["depth"] = depth + predictions["depth_conf"] = depth_conf + + if self.point_head is not None: + pts3d, pts3d_conf = self.point_head( + aggregated_tokens_list, + images=images, + patch_start_idx=patch_start_idx, + ) + predictions["world_points"] = pts3d + predictions["world_points_conf"] = pts3d_conf + + if self.track_head is not None and query_points is not None: + track_list, vis, conf = self.track_head( + aggregated_tokens_list, + images=images, + patch_start_idx=patch_start_idx, + query_points=query_points, + ) + predictions["track"] = track_list[-1] # track of the last iteration + predictions["vis"] = vis + predictions["conf"] = conf + + predictions["images"] = images + + return predictions diff --git a/mapanything/models/external/vggt/utils/__init__.py b/mapanything/models/external/vggt/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/models/external/vggt/utils/geometry.py b/mapanything/models/external/vggt/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..17da2b863022a36afc1ad7607e6640120d04b389 --- /dev/null +++ b/mapanything/models/external/vggt/utils/geometry.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
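+
+# Minimal usage sketch (illustrative; shapes follow the docstrings below): given depth
+# maps of shape (S, H, W, 1), extrinsics of shape (S, 3, 4) and intrinsics of shape
+# (S, 3, 3), unproject_depth_map_to_point_map(depth, extrinsics, intrinsics) returns
+# world-frame points of shape (S, H, W, 3).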
+
+
+import numpy as np
+import torch
+
+
+def unproject_depth_map_to_point_map(
+    depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
+) -> np.ndarray:
+    """
+    Unproject a batch of depth maps to 3D world coordinates.
+
+    Args:
+        depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
+        extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
+        intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
+
+    Returns:
+        np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
+    """
+    if isinstance(depth_map, torch.Tensor):
+        depth_map = depth_map.cpu().numpy()
+    if isinstance(extrinsics_cam, torch.Tensor):
+        extrinsics_cam = extrinsics_cam.cpu().numpy()
+    if isinstance(intrinsics_cam, torch.Tensor):
+        intrinsics_cam = intrinsics_cam.cpu().numpy()
+
+    world_points_list = []
+    for frame_idx in range(depth_map.shape[0]):
+        cur_world_points, _, _ = depth_to_world_coords_points(
+            depth_map[frame_idx].squeeze(-1),
+            extrinsics_cam[frame_idx],
+            intrinsics_cam[frame_idx],
+        )
+        world_points_list.append(cur_world_points)
+    world_points_array = np.stack(world_points_list, axis=0)
+
+    return world_points_array
+
+
+def depth_to_world_coords_points(
+    depth_map: np.ndarray,
+    extrinsic: np.ndarray,
+    intrinsic: np.ndarray,
+    eps=1e-8,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Convert a depth map to world coordinates.
+
+    Args:
+        depth_map (np.ndarray): Depth map of shape (H, W).
+        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+        extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
+
+    Returns:
+        tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3), camera coordinates (H, W, 3), and valid depth mask (H, W).
+    """
+    if depth_map is None:
+        return None, None, None
+
+    # Valid depth mask
+    point_mask = depth_map > eps
+
+    # Convert depth map to camera coordinates
+    cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
+
+    # Multiply with the inverse of extrinsic matrix to transform to world coordinates
+    # extrinsic_inv is 4x4 (note closed_form_inverse_se3 is batched, the output is (N, 4, 4))
+    cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
+
+    R_cam_to_world = cam_to_world_extrinsic[:3, :3]
+    t_cam_to_world = cam_to_world_extrinsic[:3, 3]
+
+    # Apply the rotation and translation to the camera coordinates
+    world_coords_points = (
+        np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world
+    )  # HxWx3, 3x3 -> HxWx3
+    # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
+
+    return world_coords_points, cam_coords_points, point_mask
+
+
+def depth_to_cam_coords_points(
+    depth_map: np.ndarray, intrinsic: np.ndarray
+) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Convert a depth map to camera coordinates.
+
+    Args:
+        depth_map (np.ndarray): Depth map of shape (H, W).
+        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+ + Returns: + tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) + """ + H, W = depth_map.shape + assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" + assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, ( + "Intrinsic matrix must have zero skew" + ) + + # Intrinsic parameters + fu, fv = intrinsic[0, 0], intrinsic[1, 1] + cu, cv = intrinsic[0, 2], intrinsic[1, 2] + + # Generate grid of pixel coordinates + u, v = np.meshgrid(np.arange(W), np.arange(H)) + + # Unproject to camera coordinates + x_cam = (u - cu) * depth_map / fu + y_cam = (v - cv) * depth_map / fv + z_cam = depth_map + + # Stack to form camera coordinates + cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + return cam_coords + + +def closed_form_inverse_se3(se3, R=None, T=None): + """ + Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. + + If `R` and `T` are provided, they must correspond to the rotation and translation + components of `se3`. Otherwise, they will be extracted from `se3`. + + Args: + se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. + R (optional): Nx3x3 array or tensor of rotation matrices. + T (optional): Nx3x1 array or tensor of translation vectors. + + Returns: + Inverted SE3 matrices with the same type and device as `se3`. + + Shapes: + se3: (N, 4, 4) + R: (N, 3, 3) + T: (N, 3, 1) + """ + # Check if se3 is a numpy array or a torch tensor + is_numpy = isinstance(se3, np.ndarray) + + # Validate shapes + if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): + raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") + + # Extract R and T if not provided + if R is None: + R = se3[:, :3, :3] # (N,3,3) + if T is None: + T = se3[:, :3, 3:] # (N,3,1) + + # Transpose R + if is_numpy: + # Compute the transpose of the rotation for NumPy + R_transposed = np.transpose(R, (0, 2, 1)) + # -R^T t for NumPy + top_right = -np.matmul(R_transposed, T) + inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) + else: + R_transposed = R.transpose(1, 2) # (N,3,3) + top_right = -torch.bmm(R_transposed, T) # (N,3,1) + inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) + inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) + + inverted_matrix[:, :3, :3] = R_transposed + inverted_matrix[:, :3, 3:] = top_right + + return inverted_matrix diff --git a/mapanything/models/external/vggt/utils/load_fn.py b/mapanything/models/external/vggt/utils/load_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..afa0a2d493a39759e01ee2babbd29ad19b19d215 --- /dev/null +++ b/mapanything/models/external/vggt/utils/load_fn.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from PIL import Image +from torchvision import transforms as TF + + +def load_and_preprocess_images(image_path_list, mode="crop"): + """ + A quick start function to load and preprocess images for model input. + This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes. + + Args: + image_path_list (list): List of paths to image files + mode (str, optional): Preprocessing mode, either "crop" or "pad". + - "crop" (default): Sets width to 518px and center crops height if needed. + - "pad": Preserves all pixels by making the largest dimension 518px + and padding the smaller dimension to reach a square shape. 
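+              (Illustrative example: a 1024x768 input is resized to 518x392 in both
+              modes; "pad" then white-pads it to 518x518, while "crop" keeps 518x392
+              since the height is already below 518.)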
+ + Returns: + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) + + Raises: + ValueError: If the input list is empty or if mode is invalid + + Notes: + - Images with different dimensions will be padded with white (value=1.0) + - A warning is printed when images have different shapes + - When mode="crop": The function ensures width=518px while maintaining aspect ratio + and height is center-cropped if larger than 518px + - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio + and the smaller dimension is padded to reach a square shape (518x518) + - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements + """ + # Check for empty list + if len(image_path_list) == 0: + raise ValueError("At least 1 image is required") + + # Validate mode + if mode not in ["crop", "pad"]: + raise ValueError("Mode must be either 'crop' or 'pad'") + + images = [] + shapes = set() + to_tensor = TF.ToTensor() + target_size = 518 + + # First process all images and collect their shapes + for image_path in image_path_list: + # Open image + img = Image.open(image_path) + + # If there's an alpha channel, blend onto white background: + if img.mode == "RGBA": + # Create white background + background = Image.new("RGBA", img.size, (255, 255, 255, 255)) + # Alpha composite onto the white background + img = Image.alpha_composite(background, img) + + # Now convert to "RGB" (this step assigns white for transparent areas) + img = img.convert("RGB") + + width, height = img.size + + if mode == "pad": + # Make the largest dimension 518px while maintaining aspect ratio + if width >= height: + new_width = target_size + new_height = ( + round(height * (new_width / width) / 14) * 14 + ) # Make divisible by 14 + else: + new_height = target_size + new_width = ( + round(width * (new_height / height) / 14) * 14 + ) # Make divisible by 14 + else: # mode == "crop" + # Original behavior: set width to 518px + new_width = target_size + # Calculate height maintaining aspect ratio, divisible by 14 + new_height = round(height * (new_width / width) / 14) * 14 + + # Resize with new dimensions (width, height) + img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) + img = to_tensor(img) # Convert to tensor (0, 1) + + # Center crop height if it's larger than 518 (only in crop mode) + if mode == "crop" and new_height > target_size: + start_y = (new_height - target_size) // 2 + img = img[:, start_y : start_y + target_size, :] + + # For pad mode, pad to make a square of target_size x target_size + if mode == "pad": + h_padding = target_size - img.shape[1] + w_padding = target_size - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + # Pad with white (value=1.0) + img = torch.nn.functional.pad( + img, + (pad_left, pad_right, pad_top, pad_bottom), + mode="constant", + value=1.0, + ) + + shapes.add((img.shape[1], img.shape[2])) + images.append(img) + + # Check if we have different shapes + # In theory our model can also work well with different shapes + if len(shapes) > 1: + print(f"Warning: Found images with different shapes: {shapes}") + # Find maximum dimensions + max_height = max(shape[0] for shape in shapes) + max_width = max(shape[1] for shape in shapes) + + # Pad images if necessary + padded_images = [] + for img in images: + h_padding = max_height - img.shape[1] + w_padding = max_width - img.shape[2] + 
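+            # Center-pad this image with white so it matches the largest height and
+            # width found in the batch before stacking.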
+ if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + img = torch.nn.functional.pad( + img, + (pad_left, pad_right, pad_top, pad_bottom), + mode="constant", + value=1.0, + ) + padded_images.append(img) + images = padded_images + + images = torch.stack(images) # concatenate images + + # Ensure correct shape when single image + if len(image_path_list) == 1: + # Verify shape is (1, C, H, W) + if images.dim() == 3: + images = images.unsqueeze(0) + + return images diff --git a/mapanything/models/external/vggt/utils/pose_enc.py b/mapanything/models/external/vggt/utils/pose_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..b83ad40498229f3bae2143a8fc6c742de27a2264 --- /dev/null +++ b/mapanything/models/external/vggt/utils/pose_enc.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from .rotation import mat_to_quat, quat_to_mat + + +def extri_intri_to_pose_encoding( + extrinsics, + intrinsics, + image_size_hw=None, # e.g., (256, 512) + pose_encoding_type="absT_quaR_FoV", +): + """Convert camera extrinsics and intrinsics to a compact pose encoding. + + This function transforms camera parameters into a unified pose encoding format, + which can be used for various downstream tasks like pose prediction or representation. + + Args: + extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, + where B is batch size and S is sequence length. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. + The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. + intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. + Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for computing field of view values. For example: (256, 512). + pose_encoding_type (str): Type of pose encoding to use. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + + Returns: + torch.Tensor: Encoded camera pose parameters with shape BxSx9. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + """ + + # extrinsics: BxSx3x4 + # intrinsics: BxSx3x3 + + if pose_encoding_type == "absT_quaR_FoV": + R = extrinsics[:, :, :3, :3] # BxSx3x3 + T = extrinsics[:, :, :3, 3] # BxSx3 + + quat = mat_to_quat(R) + # Note the order of h and w here + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat( + [T, quat, fov_h[..., None], fov_w[..., None]], dim=-1 + ).float() + else: + raise NotImplementedError + + return pose_encoding + + +def pose_encoding_to_extri_intri( + pose_encoding, + image_size_hw=None, # e.g., (256, 512) + pose_encoding_type="absT_quaR_FoV", + build_intrinsics=True, +): + """Convert a pose encoding back to camera extrinsics and intrinsics. 
+ + This function performs the inverse operation of extri_intri_to_pose_encoding, + reconstructing the full camera parameters from the compact encoding. + + Args: + pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, + where B is batch size and S is sequence length. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for reconstructing intrinsics from field of view values. + For example: (256, 512). + pose_encoding_type (str): Type of pose encoding used. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. + If False, only extrinsics are returned and intrinsics will be None. + + Returns: + tuple: (extrinsics, intrinsics) + - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world + transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is + a 3x1 translation vector. + - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, + or None if build_intrinsics is False. Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point, + assumed to be at the center of the image (W/2, H/2). + """ + + intrinsics = None + + if pose_encoding_type == "absT_quaR_FoV": + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + + if build_intrinsics: + H, W = image_size_hw + fy = (H / 2.0) / torch.tan(fov_h / 2.0) + fx = (W / 2.0) / torch.tan(fov_w / 2.0) + intrinsics = torch.zeros( + pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device + ) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 + else: + raise NotImplementedError + + return extrinsics, intrinsics diff --git a/mapanything/models/external/vggt/utils/rotation.py b/mapanything/models/external/vggt/utils/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..3f5e25dc36dfe7fe8472d536105fb449b7044971 --- /dev/null +++ b/mapanything/models/external/vggt/utils/rotation.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d + +import torch +import torch.nn.functional as F + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 
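+    # two_s = 2 / |q|^2 (exactly 2 for unit quaternions); the 3x3 matrix assembled
+    # below is the standard quaternion-to-rotation-matrix formula in the scalar-last
+    # (i, j, k, r) component order unpacked above.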
+ two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). 
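+
+    Example (illustrative): [0, 0, 0, -1] is mapped to [0, 0, 0, 1] (every component
+    negated), which encodes the same rotation with a non-negative real part; inputs
+    whose real part is already >= 0 are returned unchanged.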
+ """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) diff --git a/mapanything/models/external/vggt/utils/visual_track.py b/mapanything/models/external/vggt/utils/visual_track.py new file mode 100644 index 0000000000000000000000000000000000000000..0d4c314b016cc074e5d6640e4b4af07acc1a1699 --- /dev/null +++ b/mapanything/models/external/vggt/utils/visual_track.py @@ -0,0 +1,244 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import cv2 +import numpy as np +import torch + + +def color_from_xy(x, y, W, H, cmap_name="hsv"): + """ + Map (x, y) -> color in (R, G, B). + 1) Normalize x,y to [0,1]. + 2) Combine them into a single scalar c in [0,1]. + 3) Use matplotlib's colormap to convert c -> (R,G,B). + + You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). + """ + import matplotlib.cm + import matplotlib.colors + + x_norm = x / max(W - 1, 1) + y_norm = y / max(H - 1, 1) + # Simple combination: + c = (x_norm + y_norm) / 2.0 + + cmap = matplotlib.cm.get_cmap(cmap_name) + # cmap(c) -> (r,g,b,a) in [0,1] + rgba = cmap(c) + r, g, b = rgba[0], rgba[1], rgba[2] + return (r, g, b) # in [0,1], RGB order + + +def get_track_colors_by_position( + tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv" +): + """ + Given all tracks in one sample (b), compute a (N,3) array of RGB color values + in [0,255]. The color is determined by the (x,y) position in the first + visible frame for each track. + + Args: + tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. + vis_mask_b: (S, N) boolean mask; if None, assume all are visible. + image_width, image_height: used for normalizing (x, y). + cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). + + Returns: + track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. + """ + S, N, _ = tracks_b.shape + track_colors = np.zeros((N, 3), dtype=np.uint8) + + if vis_mask_b is None: + # treat all as visible + vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) + + for i in range(N): + # Find first visible frame for track i + visible_frames = torch.where(vis_mask_b[:, i])[0] + if len(visible_frames) == 0: + # track is never visible; just assign black or something + track_colors[i] = (0, 0, 0) + continue + + first_s = int(visible_frames[0].item()) + # use that frame's (x,y) + x, y = tracks_b[first_s, i].tolist() + + # map (x,y) -> (R,G,B) in [0,1] + r, g, b = color_from_xy( + x, y, W=image_width, H=image_height, cmap_name=cmap_name + ) + # scale to [0,255] + r, g, b = int(r * 255), int(g * 255), int(b * 255) + track_colors[i] = (r, g, b) + + return track_colors + + +def visualize_tracks_on_images( + images, + tracks, + track_vis_mask=None, + out_dir="track_visuals_concat_by_xy", + image_format="CHW", # "CHW" or "HWC" + normalize_mode="[0,1]", + cmap_name="hsv", # e.g. "hsv", "rainbow", "jet" + frames_per_row=4, # New parameter for grid layout + save_grid=True, # Flag to control whether to save the grid image +): + """ + Visualizes frames in a grid layout with specified frames per row. + Each track's color is determined by its (x,y) position + in the first visible frame (or frame 0 if always visible). + Finally convert the BGR result to RGB before saving. + Also saves each individual frame as a separate PNG file. 
+ + Args: + images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. + tracks: torch.Tensor (S, N, 2), last dim = (x, y). + track_vis_mask: torch.Tensor (S, N) or None. + out_dir: folder to save visualizations. + image_format: "CHW" or "HWC". + normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 + cmap_name: a matplotlib colormap name for color_from_xy. + frames_per_row: number of frames to display in each row of the grid. + save_grid: whether to save all frames in one grid image. + + Returns: + None (saves images in out_dir). + """ + + if len(tracks.shape) == 4: + tracks = tracks.squeeze(0) + images = images.squeeze(0) + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.squeeze(0) + + import matplotlib + + matplotlib.use("Agg") # for non-interactive (optional) + + os.makedirs(out_dir, exist_ok=True) + + S = images.shape[0] + _, N, _ = tracks.shape # (S, N, 2) + + # Move to CPU + images = images.cpu().clone() + tracks = tracks.cpu().clone() + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.cpu().clone() + + # Infer H, W from images shape + if image_format == "CHW": + # e.g. images[s].shape = (3, H, W) + H, W = images.shape[2], images.shape[3] + else: + # e.g. images[s].shape = (H, W, 3) + H, W = images.shape[1], images.shape[2] + + # Pre-compute the color for each track i based on first visible position + track_colors_rgb = get_track_colors_by_position( + tracks, # shape (S, N, 2) + vis_mask_b=track_vis_mask if track_vis_mask is not None else None, + image_width=W, + image_height=H, + cmap_name=cmap_name, + ) + + # We'll accumulate each frame's drawn image in a list + frame_images = [] + + for s in range(S): + # shape => either (3, H, W) or (H, W, 3) + img = images[s] + + # Convert to (H, W, 3) + if image_format == "CHW": + img = img.permute(1, 2, 0) # (H, W, 3) + # else "HWC", do nothing + + img = img.numpy().astype(np.float32) + + # Scale to [0,255] if needed + if normalize_mode == "[0,1]": + img = np.clip(img, 0, 1) * 255.0 + elif normalize_mode == "[-1,1]": + img = (img + 1.0) * 0.5 * 255.0 + img = np.clip(img, 0, 255.0) + # else no normalization + + # Convert to uint8 + img = img.astype(np.uint8) + + # For drawing in OpenCV, convert to BGR + img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + # Draw each visible track + cur_tracks = tracks[s] # shape (N, 2) + if track_vis_mask is not None: + valid_indices = torch.where(track_vis_mask[s])[0] + else: + valid_indices = range(N) + + cur_tracks_np = cur_tracks.numpy() + for i in valid_indices: + x, y = cur_tracks_np[i] + pt = (int(round(x)), int(round(y))) + + # track_colors_rgb[i] is (R,G,B). 
For OpenCV circle, we need BGR + R, G, B = track_colors_rgb[i] + color_bgr = (int(B), int(G), int(R)) + cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) + + # Convert back to RGB for consistent final saving: + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + + # Save individual frame + frame_path = os.path.join(out_dir, f"frame_{s:04d}.png") + # Convert to BGR for OpenCV imwrite + frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) + cv2.imwrite(frame_path, frame_bgr) + + frame_images.append(img_rgb) + + # Only create and save the grid image if save_grid is True + if save_grid: + # Calculate grid dimensions + num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division + + # Create a grid of images + grid_img = None + for row in range(num_rows): + start_idx = row * frames_per_row + end_idx = min(start_idx + frames_per_row, S) + + # Concatenate this row horizontally + row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1) + + # If this row has fewer than frames_per_row images, pad with black + if end_idx - start_idx < frames_per_row: + padding_width = (frames_per_row - (end_idx - start_idx)) * W + padding = np.zeros((H, padding_width, 3), dtype=np.uint8) + row_img = np.concatenate([row_img, padding], axis=1) + + # Add this row to the grid + if grid_img is None: + grid_img = row_img + else: + grid_img = np.concatenate([grid_img, row_img], axis=0) + + out_path = os.path.join(out_dir, "tracks_grid.png") + # Convert back to BGR for OpenCV imwrite + grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(out_path, grid_img_bgr) + print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}") + + print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png") diff --git a/mapanything/models/mapanything/__init__.py b/mapanything/models/mapanything/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc21520d217b751b6cb98a35674ca4af997c65f --- /dev/null +++ b/mapanything/models/mapanything/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from mapanything.models.mapanything.ablations import MapAnythingAblations +from mapanything.models.mapanything.model import MapAnything +from mapanything.models.mapanything.modular_dust3r import ModularDUSt3R + +__all__ = [ + "MapAnything", + "MapAnythingAblations", + "ModularDUSt3R", +] diff --git a/mapanything/models/mapanything/__pycache__/__init__.cpython-312.pyc b/mapanything/models/mapanything/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa9f43ba532f38ee0f1ca02c08a524bdae34d5f6 Binary files /dev/null and b/mapanything/models/mapanything/__pycache__/__init__.cpython-312.pyc differ diff --git a/mapanything/models/mapanything/__pycache__/ablations.cpython-312.pyc b/mapanything/models/mapanything/__pycache__/ablations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d82f6c1d7a1744a8d7828e79106fe0d5ed9de61d Binary files /dev/null and b/mapanything/models/mapanything/__pycache__/ablations.cpython-312.pyc differ diff --git a/mapanything/models/mapanything/__pycache__/model.cpython-312.pyc b/mapanything/models/mapanything/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14e749123a0eb4608b34dac7cbd11b3ba319f3e1 Binary files /dev/null and b/mapanything/models/mapanything/__pycache__/model.cpython-312.pyc differ diff --git a/mapanything/models/mapanything/__pycache__/modular_dust3r.cpython-312.pyc b/mapanything/models/mapanything/__pycache__/modular_dust3r.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfa468cd3d0f290b9ef7430de5cca69c48ba9899 Binary files /dev/null and b/mapanything/models/mapanything/__pycache__/modular_dust3r.cpython-312.pyc differ diff --git a/mapanything/models/mapanything/ablations.py b/mapanything/models/mapanything/ablations.py new file mode 100644 index 0000000000000000000000000000000000000000..5b23ffe1b62acaead4eb11ad8336e3d71f4ae102 --- /dev/null +++ b/mapanything/models/mapanything/ablations.py @@ -0,0 +1,1660 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +MapAnything Ablation model classes defined using UniCeption modules. 
+""" + +from functools import partial +from typing import Callable, Dict, Type, Union + +import torch +import torch.nn as nn + +from mapanything.utils.geometry import ( + apply_log_to_norm, + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + normalize_depth_using_non_zero_pixels, + normalize_pose_translations, + transform_pose_using_quats_and_trans_2_to_1, +) +from uniception.models.encoders import ( + encoder_factory, + EncoderGlobalRepInput, + ViTEncoderInput, + ViTEncoderNonImageInput, +) +from uniception.models.info_sharing.alternating_attention_transformer import ( + MultiViewAlternatingAttentionTransformer, + MultiViewAlternatingAttentionTransformerIFR, +) +from uniception.models.info_sharing.base import MultiViewTransformerInput +from uniception.models.info_sharing.cross_attention_transformer import ( + MultiViewCrossAttentionTransformer, + MultiViewCrossAttentionTransformerIFR, +) +from uniception.models.info_sharing.global_attention_transformer import ( + MultiViewGlobalAttentionTransformer, + MultiViewGlobalAttentionTransformerIFR, +) +from uniception.models.libs.croco.pos_embed import RoPE2D +from uniception.models.prediction_heads.adaptors import ( + CamTranslationPlusQuatsAdaptor, + PointMapAdaptor, + PointMapPlusRayDirectionsPlusDepthAdaptor, + PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor, + PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor, + PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor, + PointMapWithConfidenceAdaptor, + PointMapWithConfidenceAndMaskAdaptor, + PointMapWithMaskAdaptor, + RayDirectionsPlusDepthAdaptor, + RayDirectionsPlusDepthWithConfidenceAdaptor, + RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor, + RayDirectionsPlusDepthWithMaskAdaptor, + RayMapPlusDepthAdaptor, + RayMapPlusDepthWithConfidenceAdaptor, + RayMapPlusDepthWithConfidenceAndMaskAdaptor, + RayMapPlusDepthWithMaskAdaptor, +) +from uniception.models.prediction_heads.base import ( + AdaptorInput, + PredictionHeadInput, + PredictionHeadLayeredInput, +) +from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor +from uniception.models.prediction_heads.linear import LinearFeature +from uniception.models.prediction_heads.pose_head import PoseHead + +# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12) +if hasattr(torch.backends.cuda, "matmul") and hasattr( + torch.backends.cuda.matmul, "allow_tf32" +): + torch.backends.cuda.matmul.allow_tf32 = True + + +class MapAnythingAblations(nn.Module): + "Modular MapAnything Multi-View model class with no scale token." + + def __init__( + self, + name: str, + encoder_config: Dict, + info_sharing_config: Dict, + pred_head_config: Dict, + geometric_input_config: Dict, + fusion_norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial( + nn.LayerNorm, eps=1e-6 + ), + pretrained_checkpoint_path: str = None, + load_specific_pretrained_submodules: bool = False, + specific_pretrained_submodules: list = [], + torch_hub_force_reload: bool = False, + ): + """ + Multi-view model containing an image encoder followed by a multi-view attention transformer and respective downstream heads. + The goal is to output scene representation directly in view 0's frame. + + Args: + name (str): Name of the model. + encoder_config (Dict): Configuration for the encoder. + info_sharing_config (Dict): Configuration for the multi-view attention transformer. + pred_head_config (Dict): Configuration for the prediction heads. 
+ pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None) + load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False) + specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: []) + torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. (default: False) + """ + super().__init__() + + # Initialize the attributes + self.name = name + self.encoder_config = encoder_config + self.info_sharing_config = info_sharing_config + self.pred_head_config = pred_head_config + self.geometric_input_config = geometric_input_config + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.load_specific_pretrained_submodules = load_specific_pretrained_submodules + self.specific_pretrained_submodules = specific_pretrained_submodules + self.torch_hub_force_reload = torch_hub_force_reload + self.class_init_args = { + "name": self.name, + "encoder_config": self.encoder_config, + "info_sharing_config": self.info_sharing_config, + "pred_head_config": self.pred_head_config, + "geometric_input_config": self.geometric_input_config, + "pretrained_checkpoint_path": self.pretrained_checkpoint_path, + "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules, + "specific_pretrained_submodules": self.specific_pretrained_submodules, + "torch_hub_force_reload": self.torch_hub_force_reload, + } + + # Get relevant parameters from the configs + self.info_sharing_type = info_sharing_config["model_type"] + self.info_sharing_return_type = info_sharing_config["model_return_type"] + self.pred_head_type = pred_head_config["type"] + + # Initialize image encoder + if self.encoder_config["uses_torch_hub"]: + self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload + # Create a copy of the config before deleting the key to preserve it for serialization + encoder_config_copy = self.encoder_config.copy() + del encoder_config_copy["uses_torch_hub"] + self.encoder = encoder_factory(**encoder_config_copy) + + # Initialize the encoder for ray directions + ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"] + ray_dirs_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + ray_dirs_encoder_config["patch_size"] = self.encoder.patch_size + self.ray_dirs_encoder = encoder_factory(**ray_dirs_encoder_config) + + # Initialize the encoder for depth (normalized per view and values after normalization are scaled logarithmically) + depth_encoder_config = self.geometric_input_config["depth_encoder_config"] + depth_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + depth_encoder_config["patch_size"] = self.encoder.patch_size + self.depth_encoder = encoder_factory(**depth_encoder_config) + + # Initialize the encoder for log scale factor of depth + depth_scale_encoder_config = self.geometric_input_config["scale_encoder_config"] + depth_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.depth_scale_encoder = encoder_factory(**depth_scale_encoder_config) + + # Initialize the encoder for camera rotation + cam_rot_encoder_config = self.geometric_input_config["cam_rot_encoder_config"] + cam_rot_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.cam_rot_encoder = encoder_factory(**cam_rot_encoder_config) + + # Initialize the encoder for camera translation (normalized across all provided camera translations) + 
cam_trans_encoder_config = self.geometric_input_config[ + "cam_trans_encoder_config" + ] + cam_trans_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.cam_trans_encoder = encoder_factory(**cam_trans_encoder_config) + + # Initialize the encoder for log scale factor of camera translation + cam_trans_scale_encoder_config = self.geometric_input_config[ + "scale_encoder_config" + ] + cam_trans_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.cam_trans_scale_encoder = encoder_factory(**cam_trans_scale_encoder_config) + + # Initialize the fusion norm layer + self.fusion_norm_layer = fusion_norm_layer(self.encoder.enc_embed_dim) + + # Initialize the info sharing module (Multi-View Transformer) + self._initialize_info_sharing(info_sharing_config) + + # Initialize the prediction heads + self._initialize_prediction_heads(pred_head_config) + + # Initialize the final adaptors + self._initialize_adaptors(pred_head_config) + + # Load pretrained weights + self._load_pretrained_weights() + + def _initialize_info_sharing(self, info_sharing_config): + """ + Initialize the information sharing module based on the configuration. + + This method sets up the custom positional encoding if specified and initializes + the appropriate multi-view transformer based on the configuration type. + + Args: + info_sharing_config (Dict): Configuration for the multi-view attention transformer. + Should contain 'custom_positional_encoding', 'model_type', and 'model_return_type'. + + Returns: + None + + Raises: + ValueError: If invalid configuration options are provided. + """ + # Initialize Custom Positional Encoding if required + custom_positional_encoding = info_sharing_config["custom_positional_encoding"] + if custom_positional_encoding is not None: + if isinstance(custom_positional_encoding, str): + print( + f"Using custom positional encoding for multi-view attention transformer: {custom_positional_encoding}" + ) + if custom_positional_encoding.startswith("RoPE"): + rope_freq = float(custom_positional_encoding[len("RoPE") :]) + print(f"RoPE frequency: {rope_freq}") + self.custom_positional_encoding = RoPE2D(freq=rope_freq) + else: + raise ValueError( + f"Invalid custom_positional_encoding: {custom_positional_encoding}." + ) + elif isinstance(custom_positional_encoding, Callable): + print( + "Using callable function as custom positional encoding for multi-view attention transformer." 
+ ) + self.custom_positional_encoding = custom_positional_encoding + else: + self.custom_positional_encoding = None + + # Add dependencies to info_sharing_config + info_sharing_config["module_args"]["input_embed_dim"] = ( + self.encoder.enc_embed_dim + ) + info_sharing_config["module_args"]["custom_positional_encoding"] = ( + self.custom_positional_encoding + ) + + # Initialize Multi-View Transformer + if self.info_sharing_return_type == "no_intermediate_features": + # Returns only normalized last layer features + # Initialize multi-view transformer based on type + if self.info_sharing_type == "cross_attention": + self.info_sharing = MultiViewCrossAttentionTransformer( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "global_attention": + self.info_sharing = MultiViewGlobalAttentionTransformer( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "alternating_attention": + self.info_sharing = MultiViewAlternatingAttentionTransformer( + **info_sharing_config["module_args"] + ) + else: + raise ValueError( + f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']" + ) + elif self.info_sharing_return_type == "intermediate_features": + # Returns intermediate features and normalized last layer features + # Initialize mulit-view transformer based on type + if self.info_sharing_type == "cross_attention": + self.info_sharing = MultiViewCrossAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "global_attention": + self.info_sharing = MultiViewGlobalAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "alternating_attention": + self.info_sharing = MultiViewAlternatingAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + else: + raise ValueError( + f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']" + ) + # Assess if the DPT needs to use encoder features + if len(self.info_sharing.indices) == 2: + self.use_encoder_features_for_dpt = True + elif len(self.info_sharing.indices) == 3: + self.use_encoder_features_for_dpt = False + else: + raise ValueError( + "Invalid number of indices provided for info sharing feature returner. Please provide 2 or 3 indices." + ) + else: + raise ValueError( + f"Invalid info_sharing_return_type: {self.info_sharing_return_type}. Valid options: ['no_intermediate_features', 'intermediate_features']" + ) + + def _initialize_prediction_heads(self, pred_head_config): + """ + Initialize the prediction heads based on the prediction head configuration. + + This method configures and initializes the appropriate prediction heads based on the + specified prediction head type (linear, DPT, or DPT+pose). It sets up the necessary + dependencies and creates the required model components. + + Args: + pred_head_config (Dict): Configuration for the prediction heads. + + Returns: + None + + Raises: + ValueError: If an invalid pred_head_type is provided. 
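When the info-sharing module returns intermediate features, the number of configured feature-tap indices decides whether the DPT head also consumes raw image-encoder features: two taps mean the encoder features fill the first DPT slot, three taps mean all four DPT inputs come from the multi-view transformer. A compact restatement of that rule as a standalone sketch (`indices` is just a plain list here):

```python
def dpt_uses_encoder_features(indices: list) -> bool:
    # 2 intermediate feature taps -> encoder features provide the first DPT input;
    # 3 taps -> all four DPT inputs come from the multi-view transformer.
    if len(indices) == 2:
        return True
    elif len(indices) == 3:
        return False
    raise ValueError("Provide 2 or 3 indices for the intermediate feature returner.")

print(dpt_uses_encoder_features([5, 11]))     # True
print(dpt_uses_encoder_features([3, 7, 11]))  # False
```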
+ """ + # Add dependencies to prediction head config + pred_head_config["feature_head"]["patch_size"] = self.encoder.patch_size + if self.pred_head_type == "linear": + pred_head_config["feature_head"]["input_feature_dim"] = ( + self.info_sharing.dim + ) + elif "dpt" in self.pred_head_type: + # Add dependencies for DPT & Regressor head + if self.use_encoder_features_for_dpt: + pred_head_config["feature_head"]["input_feature_dims"] = [ + self.encoder.enc_embed_dim + ] + [self.info_sharing.dim] * 3 + else: + pred_head_config["feature_head"]["input_feature_dims"] = [ + self.info_sharing.dim + ] * 4 + pred_head_config["regressor_head"]["input_feature_dim"] = pred_head_config[ + "feature_head" + ]["feature_dim"] + # Add dependencies for Pose head if required + if "pose" in self.pred_head_type: + pred_head_config["pose_head"]["patch_size"] = self.encoder.patch_size + pred_head_config["pose_head"]["input_feature_dim"] = ( + self.info_sharing.dim + ) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + # Initialize Prediction Heads + if self.pred_head_type == "linear": + # Initialize Dense Prediction Head for all views + self.dense_head = LinearFeature(**pred_head_config["feature_head"]) + elif "dpt" in self.pred_head_type: + # Initialize Dense Prediction Head for all views + self.dpt_feature_head = DPTFeature(**pred_head_config["feature_head"]) + self.dpt_regressor_head = DPTRegressionProcessor( + **pred_head_config["regressor_head"] + ) + self.dense_head = nn.Sequential( + self.dpt_feature_head, self.dpt_regressor_head + ) + # Initialize Pose Head for all views if required + if "pose" in self.pred_head_type: + self.pose_head = PoseHead(**pred_head_config["pose_head"]) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + def _initialize_adaptors(self, pred_head_config): + """ + Initialize the adaptors based on the prediction head configuration. + + This method sets up the appropriate adaptors for different scene representation types, + such as pointmaps, ray maps with depth, or ray directions with depth and pose. + + Args: + pred_head_config (Dict): Configuration for the prediction heads including adaptor type. + + Returns: + None + + Raises: + ValueError: If an invalid adaptor_type is provided. + AssertionError: If ray directions + depth + pose is used with an incompatible head type. 
+ """ + if pred_head_config["adaptor_type"] == "pointmap": + self.dense_adaptor = PointMapAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "pointmap" + elif pred_head_config["adaptor_type"] == "pointmap+confidence": + self.dense_adaptor = PointMapWithConfidenceAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "pointmap+confidence" + elif pred_head_config["adaptor_type"] == "pointmap+mask": + self.dense_adaptor = PointMapWithMaskAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "pointmap+mask" + elif pred_head_config["adaptor_type"] == "pointmap+confidence+mask": + self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "pointmap+confidence+mask" + elif pred_head_config["adaptor_type"] == "raymap+depth": + self.dense_adaptor = RayMapPlusDepthAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "raymap+depth" + elif pred_head_config["adaptor_type"] == "raymap+depth+confidence": + self.dense_adaptor = RayMapPlusDepthWithConfidenceAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "raymap+depth+confidence" + elif pred_head_config["adaptor_type"] == "raymap+depth+mask": + self.dense_adaptor = RayMapPlusDepthWithMaskAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "raymap+depth+mask" + elif pred_head_config["adaptor_type"] == "raymap+depth+confidence+mask": + self.dense_adaptor = RayMapPlusDepthWithConfidenceAndMaskAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "raymap+depth+confidence+mask" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = RayDirectionsPlusDepthAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose+confidence" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+mask": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = RayDirectionsPlusDepthWithMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose+mask" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence+mask": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." 
+ ) + self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose+confidence+mask" + elif pred_head_config["adaptor_type"] == "campointmap+pose": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapAdaptor(**pred_head_config["dpt_adaptor"]) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose" + elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapWithConfidenceAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose+confidence" + elif pred_head_config["adaptor_type"] == "campointmap+pose+mask": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapWithMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose+mask" + elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence+mask": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose+confidence+mask" + elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose": + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose" + elif ( + pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+confidence" + ): + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = ( + PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor( + **pred_head_config["dpt_adaptor"] + ) + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence" + elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+mask": + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." 
+ ) + self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose+mask" + elif ( + pred_head_config["adaptor_type"] + == "pointmap+raydirs+depth+pose+confidence+mask" + ): + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = ( + PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence+mask" + else: + raise ValueError( + f"Invalid adaptor_type: {pred_head_config['adaptor_type']}. \ + Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \ + 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \ + 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \ + 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']" + ) + + def _load_pretrained_weights(self): + """ + Load pretrained weights from a checkpoint file. + + If load_specific_pretrained_submodules is True, only loads weights for the specified submodules. + Otherwise, loads all weights from the checkpoint. + + Returns: + None + """ + if self.pretrained_checkpoint_path is not None: + if not self.load_specific_pretrained_submodules: + print( + f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + else: + print( + f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} for specific submodules: {self.specific_pretrained_submodules} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + filtered_ckpt = {} + for ckpt_key, ckpt_value in ckpt["model"].items(): + for submodule in self.specific_pretrained_submodules: + if ckpt_key.startswith(submodule): + filtered_ckpt[ckpt_key] = ckpt_value + print(self.load_state_dict(filtered_ckpt, strict=False)) + + def _encode_n_views(self, views): + """ + Encode all the input views (batch of images) in a single forward pass. + Assumes all the input views have the same image shape, batch size, and data normalization type. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + + Returns: + List[torch.Tensor]: A list containing the encoded features for all N views. 
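`_load_pretrained_weights` above filters the checkpoint's state dict by key prefix when only specific submodules should be initialized, then loads with `strict=False`. A self-contained sketch of the same prefix filter; the module names in the example checkpoint are hypothetical:

```python
import torch

def filter_state_dict_by_prefix(state_dict: dict, submodules: list) -> dict:
    # Keep only parameters whose key starts with one of the requested submodule names.
    return {
        k: v
        for k, v in state_dict.items()
        if any(k.startswith(prefix) for prefix in submodules)
    }

# Hypothetical example: load only the image encoder and the info-sharing transformer.
ckpt = {"model": {
    "encoder.patch_embed.weight": torch.zeros(1),
    "info_sharing.blocks.0.attn.qkv.weight": torch.zeros(1),
    "dense_head.0.weight": torch.zeros(1),
}}
filtered = filter_state_dict_by_prefix(ckpt["model"], ["encoder", "info_sharing"])
print(sorted(filtered))  # ['encoder.patch_embed.weight', 'info_sharing.blocks.0.attn.qkv.weight']
```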
+ """ + num_views = len(views) + data_norm_type = views[0]["data_norm_type"][0] + imgs_list = [view["img"] for view in views] + all_imgs_across_views = torch.cat(imgs_list, dim=0) + encoder_input = ViTEncoderInput( + image=all_imgs_across_views, data_norm_type=data_norm_type + ) + encoder_output = self.encoder(encoder_input) + all_encoder_features_across_views = encoder_output.features.chunk( + num_views, dim=0 + ) + + return all_encoder_features_across_views + + def _compute_pose_quats_and_trans_for_across_views_in_ref_view( + self, + views, + num_views, + device, + dtype, + batch_size_per_view, + per_sample_cam_input_mask, + ): + """ + Compute the pose quats and trans for all the views in the frame of the reference view 0. + Returns identity pose for views where the camera input mask is False or the pose is not provided. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + device (torch.device): Device to use for the computation. + dtype (torch.dtype): Data type to use for the computation. + per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask. + + Returns: + torch.Tensor: A tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4) + torch.Tensor: A tensor containing the pose trans for all the views in the frame of the reference view 0. (batch_size_per_view * view, 3) + torch.Tensor: A tensor containing the per sample camera input mask. + """ + # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0 + pose_quats_non_ref_views = [] + pose_trans_non_ref_views = [] + pose_quats_ref_view_0 = [] + pose_trans_ref_view_0 = [] + for view_idx in range(num_views): + per_sample_cam_input_mask_for_curr_view = per_sample_cam_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view + ] + if ( + "camera_pose_quats" in views[view_idx] + and "camera_pose_trans" in views[view_idx] + and per_sample_cam_input_mask_for_curr_view.any() + ): + # Get the camera pose quats and trans for the current view + cam_pose_quats = views[view_idx]["camera_pose_quats"][ + per_sample_cam_input_mask_for_curr_view + ] + cam_pose_trans = views[view_idx]["camera_pose_trans"][ + per_sample_cam_input_mask_for_curr_view + ] + # Append to the list + pose_quats_non_ref_views.append(cam_pose_quats) + pose_trans_non_ref_views.append(cam_pose_trans) + # Get the camera pose quats and trans for the reference view 0 + cam_pose_quats = views[0]["camera_pose_quats"][ + per_sample_cam_input_mask_for_curr_view + ] + cam_pose_trans = views[0]["camera_pose_trans"][ + per_sample_cam_input_mask_for_curr_view + ] + # Append to the list + pose_quats_ref_view_0.append(cam_pose_quats) + pose_trans_ref_view_0.append(cam_pose_trans) + else: + per_sample_cam_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] = False + + # Initialize the pose quats and trans for all views as identity + pose_quats_across_views = torch.tensor( + [0.0, 0.0, 0.0, 1.0], dtype=dtype, device=device + ).repeat(batch_size_per_view * num_views, 1) # (q_x, q_y, q_z, q_w) + pose_trans_across_views = torch.zeros( + (batch_size_per_view * num_views, 3), dtype=dtype, device=device + ) + + # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0 + if len(pose_quats_non_ref_views) > 0: + # Stack the pose quats and trans for all the 
non-reference views and reference view 0 + pose_quats_non_ref_views = torch.cat(pose_quats_non_ref_views, dim=0) + pose_trans_non_ref_views = torch.cat(pose_trans_non_ref_views, dim=0) + pose_quats_ref_view_0 = torch.cat(pose_quats_ref_view_0, dim=0) + pose_trans_ref_view_0 = torch.cat(pose_trans_ref_view_0, dim=0) + + # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0 + ( + pose_quats_non_ref_views_in_ref_view_0, + pose_trans_non_ref_views_in_ref_view_0, + ) = transform_pose_using_quats_and_trans_2_to_1( + pose_quats_ref_view_0, + pose_trans_ref_view_0, + pose_quats_non_ref_views, + pose_trans_non_ref_views, + ) + + # Update the pose quats and trans for all the non-reference views + pose_quats_across_views[per_sample_cam_input_mask] = ( + pose_quats_non_ref_views_in_ref_view_0.to(dtype=dtype) + ) + pose_trans_across_views[per_sample_cam_input_mask] = ( + pose_trans_non_ref_views_in_ref_view_0.to(dtype=dtype) + ) + + return ( + pose_quats_across_views, + pose_trans_across_views, + per_sample_cam_input_mask, + ) + + def _encode_and_fuse_ray_dirs( + self, + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_ray_dirs_input_mask, + ): + """ + Encode the ray directions for all the views and fuse it with the other encoder features in a single forward pass. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + batch_size_per_view (int): Batch size per view. + all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views. + per_sample_ray_dirs_input_mask (torch.Tensor): Tensor containing the per sample ray direction input mask. + + Returns: + torch.Tensor: A tensor containing the encoded features for all the views. 
+ """ + # Get the height and width of the images + _, _, height, width = views[0]["img"].shape + + # Get the ray directions for all the views where info is provided and the ray direction input mask is True + ray_dirs_list = [] + for view_idx in range(num_views): + per_sample_ray_dirs_input_mask_for_curr_view = ( + per_sample_ray_dirs_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] + ) + ray_dirs_for_curr_view = torch.zeros( + (batch_size_per_view, height, width, 3), + dtype=all_encoder_features_across_views.dtype, + device=all_encoder_features_across_views.device, + ) + if ( + "ray_directions_cam" in views[view_idx] + and per_sample_ray_dirs_input_mask_for_curr_view.any() + ): + ray_dirs_for_curr_view[per_sample_ray_dirs_input_mask_for_curr_view] = ( + views[view_idx]["ray_directions_cam"][ + per_sample_ray_dirs_input_mask_for_curr_view + ] + ) + else: + per_sample_ray_dirs_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] = False + ray_dirs_list.append(ray_dirs_for_curr_view) + + # Stack the ray directions for all the views and permute to (B * V, C, H, W) + ray_dirs = torch.cat(ray_dirs_list, dim=0) # (B * V, H, W, 3) + ray_dirs = ray_dirs.permute(0, 3, 1, 2).contiguous() # (B * V, 3, H, W) + + # Encode the ray directions + ray_dirs_features_across_views = self.ray_dirs_encoder( + ViTEncoderNonImageInput(data=ray_dirs) + ).features + + # Fuse the ray direction features with the other encoder features (zero out the features where the ray direction input mask is False) + ray_dirs_features_across_views = ( + ray_dirs_features_across_views + * per_sample_ray_dirs_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + ) + all_encoder_features_across_views = ( + all_encoder_features_across_views + ray_dirs_features_across_views + ) + + return all_encoder_features_across_views + + def _encode_and_fuse_depths( + self, + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_depth_input_mask, + ): + """ + Encode the z depths for all the views and fuse it with the other encoder features in a single forward pass. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + batch_size_per_view (int): Batch size per view. + all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views. + per_sample_depth_input_mask (torch.Tensor): Tensor containing the per sample depth input mask. + + Returns: + torch.Tensor: A tensor containing the encoded features for all the views. 
+ """ + # Get the device and height and width of the images + device = all_encoder_features_across_views.device + _, _, height, width = views[0]["img"].shape + + # Decide to use randomly sampled sparse depth or dense depth + if torch.rand(1) < self.geometric_input_config["sparse_depth_prob"]: + use_sparse_depth = True + else: + use_sparse_depth = False + + # Get the depths for all the views + depth_list = [] + depth_norm_factors_list = [] + metric_scale_depth_mask_list = [] + for view_idx in range(num_views): + # Get the input mask for current view + per_sample_depth_input_mask_for_curr_view = per_sample_depth_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view + ] + depth_for_curr_view = torch.zeros( + (batch_size_per_view, height, width, 1), + dtype=all_encoder_features_across_views.dtype, + device=device, + ) + depth_norm_factor_for_curr_view = torch.zeros( + (batch_size_per_view), + dtype=all_encoder_features_across_views.dtype, + device=device, + ) + metric_scale_mask_for_curr_view = torch.zeros( + (batch_size_per_view), + dtype=torch.bool, + device=device, + ) + if ( + "depth_along_ray" in views[view_idx] + ) and per_sample_depth_input_mask_for_curr_view.any(): + # Get depth for current view + depth_for_curr_view_input = views[view_idx]["depth_along_ray"][ + per_sample_depth_input_mask_for_curr_view + ] + # Get the metric scale mask + if "is_metric_scale" in views[view_idx]: + metric_scale_mask = views[view_idx]["is_metric_scale"][ + per_sample_depth_input_mask_for_curr_view + ] + else: + metric_scale_mask = torch.zeros( + depth_for_curr_view_input.shape[0], + dtype=torch.bool, + device=device, + ) + # Turn off indication of metric scale samples based on the depth_scale_norm_all_prob + depth_scale_norm_all_mask = ( + torch.rand(metric_scale_mask.shape[0]) + < self.geometric_input_config["depth_scale_norm_all_prob"] + ) + if depth_scale_norm_all_mask.any(): + metric_scale_mask[depth_scale_norm_all_mask] = False + # Assign the metric scale mask to the respective indices + metric_scale_mask_for_curr_view[ + per_sample_depth_input_mask_for_curr_view + ] = metric_scale_mask + # Sparsely sample the depth if required + if use_sparse_depth: + # Create a mask of ones + sparsification_mask = torch.ones_like( + depth_for_curr_view_input, device=device + ) + # Create a mask for valid pixels (depth > 0) + valid_pixel_mask = depth_for_curr_view_input > 0 + # Calculate the number of valid pixels + num_valid_pixels = valid_pixel_mask.sum().item() + # Calculate the number of valid pixels to set to zero + num_to_zero = int( + num_valid_pixels + * self.geometric_input_config["sparsification_removal_percent"] + ) + if num_to_zero > 0: + # Get the indices of valid pixels + valid_indices = valid_pixel_mask.nonzero(as_tuple=True) + # Randomly select indices to zero out + indices_to_zero = torch.randperm(num_valid_pixels)[:num_to_zero] + # Set selected valid indices to zero in the mask + sparsification_mask[ + valid_indices[0][indices_to_zero], + valid_indices[1][indices_to_zero], + valid_indices[2][indices_to_zero], + valid_indices[3][indices_to_zero], + ] = 0 + # Apply the mask on the depth + depth_for_curr_view_input = ( + depth_for_curr_view_input * sparsification_mask + ) + # Normalize the depth + scaled_depth_for_curr_view_input, depth_norm_factor = ( + normalize_depth_using_non_zero_pixels( + depth_for_curr_view_input, return_norm_factor=True + ) + ) + # Assign the depth and depth norm factor to the respective indices + 
depth_for_curr_view[per_sample_depth_input_mask_for_curr_view] = ( + scaled_depth_for_curr_view_input + ) + depth_norm_factor_for_curr_view[ + per_sample_depth_input_mask_for_curr_view + ] = depth_norm_factor + else: + per_sample_depth_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] = False + # Append the depths, depth norm factor and metric scale mask for the current view + depth_list.append(depth_for_curr_view) + depth_norm_factors_list.append(depth_norm_factor_for_curr_view) + metric_scale_depth_mask_list.append(metric_scale_mask_for_curr_view) + + # Stack the depths for all the views and permute to (B * V, C, H, W) + depths = torch.cat(depth_list, dim=0) # (B * V, H, W, 1) + depths = apply_log_to_norm( + depths + ) # Scale logarithimically (norm is computed along last dim) + depths = depths.permute(0, 3, 1, 2).contiguous() # (B * V, 1, H, W) + # Encode the depths using the depth encoder + depth_features_across_views = self.depth_encoder( + ViTEncoderNonImageInput(data=depths) + ).features + # Zero out the depth features where the depth input mask is False + depth_features_across_views = ( + depth_features_across_views + * per_sample_depth_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + ) + + # Stack the depth norm factors for all the views + depth_norm_factors = torch.cat(depth_norm_factors_list, dim=0) # (B * V, ) + # Encode the depth norm factors using the log scale encoder for depth + log_depth_norm_factors = torch.log(depth_norm_factors + 1e-8) # (B * V, ) + depth_scale_features_across_views = self.depth_scale_encoder( + EncoderGlobalRepInput(data=log_depth_norm_factors.unsqueeze(-1)) + ).features + # Zero out the depth scale features where the depth input mask is False + depth_scale_features_across_views = ( + depth_scale_features_across_views + * per_sample_depth_input_mask.unsqueeze(-1) + ) + # Stack the metric scale mask for all the views + metric_scale_depth_mask = torch.cat( + metric_scale_depth_mask_list, dim=0 + ) # (B * V, ) + # Zero out the depth scale features where the metric scale mask is False + # Scale encoding is only provided for metric scale samples + depth_scale_features_across_views = ( + depth_scale_features_across_views * metric_scale_depth_mask.unsqueeze(-1) + ) + + # Fuse the depth features & depth scale features with the other encoder features + all_encoder_features_across_views = ( + all_encoder_features_across_views + + depth_features_across_views + + depth_scale_features_across_views.unsqueeze(-1).unsqueeze(-1) + ) + + return all_encoder_features_across_views + + def _encode_and_fuse_cam_quats_and_trans( + self, + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + pose_quats_across_views, + pose_trans_across_views, + per_sample_cam_input_mask, + ): + """ + Encode the camera quats and trans for all the views and fuse it with the other encoder features in a single forward pass. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + batch_size_per_view (int): Batch size per view. + all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views. + pose_quats_across_views (torch.Tensor): Tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4) + pose_trans_across_views (torch.Tensor): Tensor containing the pose trans for all the views in the frame of the reference view 0. 
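Each view's depth is normalized by a per-view factor, and only the log of that factor is handed to the scale encoder, gated so that metric-scale samples alone contribute scale information. A minimal sketch of the normalize-then-log idea, assuming the factor is the mean of the non-zero depths (the exact behaviour of `normalize_depth_using_non_zero_pixels` is not shown in this section, so that part is an assumption):

```python
import torch

def normalize_depth_sketch(depth: torch.Tensor):
    """Assumed behaviour: divide by the mean of the non-zero pixels and return that factor."""
    valid = depth > 0
    norm_factor = depth[valid].mean() if valid.any() else depth.new_tensor(1.0)
    return depth / (norm_factor + 1e-8), norm_factor

depth = torch.rand(1, 16, 16, 1) * 5.0
scaled_depth, norm_factor = normalize_depth_sketch(depth)
log_norm_factor = torch.log(norm_factor + 1e-8)  # this scalar is what the scale encoder sees
print(scaled_depth.mean(), norm_factor, log_norm_factor)
```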
(batch_size_per_view * view, 3) + per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask. + + Returns: + torch.Tensor: A tensor containing the encoded features for all the views. + """ + # Encode the pose quats + pose_quats_features_across_views = self.cam_rot_encoder( + EncoderGlobalRepInput(data=pose_quats_across_views) + ).features + # Zero out the pose quat features where the camera input mask is False + pose_quats_features_across_views = ( + pose_quats_features_across_views * per_sample_cam_input_mask.unsqueeze(-1) + ) + + # Get the metric scale mask for all samples + device = all_encoder_features_across_views.device + metric_scale_pose_trans_mask = torch.zeros( + (batch_size_per_view * num_views), dtype=torch.bool, device=device + ) + for view_idx in range(num_views): + if "is_metric_scale" in views[view_idx]: + # Get the metric scale mask for the input pose priors + metric_scale_mask = views[view_idx]["is_metric_scale"] + else: + metric_scale_mask = torch.zeros( + batch_size_per_view, dtype=torch.bool, device=device + ) + metric_scale_pose_trans_mask[ + view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view + ] = metric_scale_mask + + # Turn off indication of metric scale samples based on the pose_scale_norm_all_prob + pose_norm_all_mask = ( + torch.rand(batch_size_per_view * num_views) + < self.geometric_input_config["pose_scale_norm_all_prob"] + ) + if pose_norm_all_mask.any(): + metric_scale_pose_trans_mask[pose_norm_all_mask] = False + + # Get the scale norm factor for all the samples and scale the pose translations + pose_trans_across_views = torch.split( + pose_trans_across_views, batch_size_per_view, dim=0 + ) # Split into num_views chunks + pose_trans_across_views = torch.stack( + pose_trans_across_views, dim=1 + ) # Stack the views along a new dimension (batch_size_per_view, num_views, 3) + scaled_pose_trans_across_views, pose_trans_norm_factors = ( + normalize_pose_translations( + pose_trans_across_views, return_norm_factor=True + ) + ) + + # Resize the pose translation back to (batch_size_per_view * num_views, 3) and extend the norm factor to (batch_size_per_view * num_views, 1) + scaled_pose_trans_across_views = scaled_pose_trans_across_views.unbind( + dim=1 + ) # Convert back to list of views, where each view has batch_size_per_view tensor + scaled_pose_trans_across_views = torch.cat( + scaled_pose_trans_across_views, dim=0 + ) # Concatenate back to (batch_size_per_view * num_views, 3) + pose_trans_norm_factors_across_views = pose_trans_norm_factors.unsqueeze( + -1 + ).repeat(num_views, 1) # (B, ) -> (B * V, 1) + + # Encode the pose trans + pose_trans_features_across_views = self.cam_trans_encoder( + EncoderGlobalRepInput(data=scaled_pose_trans_across_views) + ).features + # Zero out the pose trans features where the camera input mask is False + pose_trans_features_across_views = ( + pose_trans_features_across_views * per_sample_cam_input_mask.unsqueeze(-1) + ) + + # Encode the pose translation norm factors using the log scale encoder for pose trans + log_pose_trans_norm_factors_across_views = torch.log( + pose_trans_norm_factors_across_views + 1e-8 + ) + pose_trans_scale_features_across_views = self.cam_trans_scale_encoder( + EncoderGlobalRepInput(data=log_pose_trans_norm_factors_across_views) + ).features + # Zero out the pose trans scale features where the camera input mask is False + pose_trans_scale_features_across_views = ( + pose_trans_scale_features_across_views + * per_sample_cam_input_mask.unsqueeze(-1) + ) 
+ # Zero out the pose trans scale features where the metric scale mask is False + # Scale encoding is only provided for metric scale samples + pose_trans_scale_features_across_views = ( + pose_trans_scale_features_across_views + * metric_scale_pose_trans_mask.unsqueeze(-1) + ) + + # Fuse the pose quat features, pose trans features, pose trans scale features and pose trans type PE features with the other encoder features + all_encoder_features_across_views = ( + all_encoder_features_across_views + + pose_quats_features_across_views.unsqueeze(-1).unsqueeze(-1) + + pose_trans_features_across_views.unsqueeze(-1).unsqueeze(-1) + + pose_trans_scale_features_across_views.unsqueeze(-1).unsqueeze(-1) + ) + + return all_encoder_features_across_views + + def _encode_and_fuse_optional_geometric_inputs( + self, views, all_encoder_features_across_views_list + ): + """ + Encode all the input optional geometric modalities and fuses it with the image encoder features in a single forward pass. + Assumes all the input views have the same shape and batch size. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + all_encoder_features_across_views (List[torch.Tensor]): List of tensors containing the encoded image features for all N views. + + Returns: + List[torch.Tensor]: A list containing the encoded features for all N views. + """ + num_views = len(views) + batch_size_per_view, _, _, _ = views[0]["img"].shape + device = all_encoder_features_across_views_list[0].device + dtype = all_encoder_features_across_views_list[0].dtype + all_encoder_features_across_views = torch.cat( + all_encoder_features_across_views_list, dim=0 + ) + + # Get the overall input mask for all the views + overall_geometric_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["overall_prob"] + ) + overall_geometric_input_mask = overall_geometric_input_mask.repeat(num_views) + + # Get the per sample input mask after dropout + # Per sample input mask is in view-major order so that index v*B + b in each mask corresponds to sample b of view v: (B * V) + per_sample_geometric_input_mask = torch.rand( + batch_size_per_view * num_views, device=device + ) < (1 - self.geometric_input_config["dropout_prob"]) + per_sample_geometric_input_mask = ( + per_sample_geometric_input_mask & overall_geometric_input_mask + ) + + # Get the ray direction input mask + per_sample_ray_dirs_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["ray_dirs_prob"] + ) + per_sample_ray_dirs_input_mask = per_sample_ray_dirs_input_mask.repeat( + num_views + ) + per_sample_ray_dirs_input_mask = ( + per_sample_ray_dirs_input_mask & per_sample_geometric_input_mask + ) + + # Get the depth input mask + per_sample_depth_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["depth_prob"] + ) + per_sample_depth_input_mask = per_sample_depth_input_mask.repeat(num_views) + per_sample_depth_input_mask = ( + per_sample_depth_input_mask & per_sample_geometric_input_mask + ) + + # Get the camera input mask + per_sample_cam_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["cam_prob"] + ) + per_sample_cam_input_mask = per_sample_cam_input_mask.repeat(num_views) + per_sample_cam_input_mask = ( + per_sample_cam_input_mask & per_sample_geometric_input_mask + ) + + # Compute the pose quats and trans for all the non-reference views in the frame of the reference 
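The modality masks above are sampled per batch element and repeated across views in view-major order, while an additional dropout mask is sampled independently per view-sample; everything is AND-combined before gating the encoded features. A condensed sketch of that mask construction (probability values are illustrative):

```python
import torch

B, V = 4, 3
cfg = {"overall_prob": 0.9, "dropout_prob": 0.1, "ray_dirs_prob": 0.5,
       "depth_prob": 0.5, "cam_prob": 0.5}  # illustrative values

overall = (torch.rand(B) < cfg["overall_prob"]).repeat(V)          # (B * V,), same per sample across views
keep = (torch.rand(B * V) < (1 - cfg["dropout_prob"])) & overall   # per view-sample dropout

ray_dirs_mask = (torch.rand(B) < cfg["ray_dirs_prob"]).repeat(V) & keep
depth_mask    = (torch.rand(B) < cfg["depth_prob"]).repeat(V) & keep
cam_mask      = (torch.rand(B) < cfg["cam_prob"]).repeat(V) & keep
print(ray_dirs_mask.shape)  # torch.Size([12]); index v * B + b addresses sample b of view v
```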
view 0 + # Returned pose quats and trans represent identity pose for views/samples where the camera input mask is False + pose_quats_across_views, pose_trans_across_views, per_sample_cam_input_mask = ( + self._compute_pose_quats_and_trans_for_across_views_in_ref_view( + views, + num_views, + device, + dtype, + batch_size_per_view, + per_sample_cam_input_mask, + ) + ) + + # Encode the ray directions and fuse with the image encoder features + all_encoder_features_across_views = self._encode_and_fuse_ray_dirs( + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_ray_dirs_input_mask, + ) + + # Encode the depths and fuse with the image encoder features + all_encoder_features_across_views = self._encode_and_fuse_depths( + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_depth_input_mask, + ) + + # Encode the cam quat and trans and fuse with the image encoder features + all_encoder_features_across_views = self._encode_and_fuse_cam_quats_and_trans( + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + pose_quats_across_views, + pose_trans_across_views, + per_sample_cam_input_mask, + ) + + # Normalize the fused features (permute -> normalize -> permute) + all_encoder_features_across_views = all_encoder_features_across_views.permute( + 0, 2, 3, 1 + ).contiguous() + all_encoder_features_across_views = self.fusion_norm_layer( + all_encoder_features_across_views + ) + all_encoder_features_across_views = all_encoder_features_across_views.permute( + 0, 3, 1, 2 + ).contiguous() + + # Split the batched views into individual views + fused_all_encoder_features_across_views = ( + all_encoder_features_across_views.chunk(num_views, dim=0) + ) + + return fused_all_encoder_features_across_views + + def forward(self, views): + """ + Forward pass performing the following operations: + 1. Encodes the N input views (images). + 2. Encodes the optional geometric inputs (ray directions, depths, camera rotations, camera translations). + 3. Fuses the encoded features from the N input views and the optional geometric inputs using addition and normalization. + 4. Information sharing between the encoded features using a multi-view attention transformer. + 5. Passes the final features through the prediction heads. + 6. Returns the processed final outputs for N views. + + Assumption: + - All the input views have the same image shape. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). Input images must be normalized based on the data norm type of image encoder. + "data_norm_type" (list): [model.encoder.data_norm_type] + + Returns: + List[dict]: A list containing the final outputs for all N views. 
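Given the `forward()` contract described in the docstring above, a minimal call might look as follows. The model construction and the concrete `data_norm_type` value are assumptions; only the `"img"` / `"data_norm_type"` keys and the list-of-view-dicts structure come from the docstring.

```python
import torch

# Hypothetical: `model` is an already-constructed MapAnything instance in eval mode.
num_views, batch_size = 2, 1
views = [
    {
        "img": torch.randn(batch_size, 3, 518, 518),        # must be normalized as the encoder expects
        "data_norm_type": [model.encoder.data_norm_type],   # e.g. ["dinov2"] (assumed value)
    }
    for _ in range(num_views)
]

with torch.no_grad():
    preds = model(views)      # list of num_views dicts
print(preds[0].keys())        # e.g. dict_keys(['pts3d', ...]) depending on scene_rep_type
```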
+ """ + # Get input shape of the images, number of views, and batch size per view + batch_size_per_view, _, height, width = views[0]["img"].shape + img_shape = (int(height), int(width)) + num_views = len(views) + + # Run the encoder on all the input views + all_encoder_features_across_views = self._encode_n_views(views) + + # Encode the optional geometric inputs and fuse with the encoded features from the N input views + # Use high precision to prevent NaN values after layer norm in dense representation encoder (due to high variance in last dim of features) + with torch.autocast("cuda", enabled=False): + all_encoder_features_across_views = ( + self._encode_and_fuse_optional_geometric_inputs( + views, all_encoder_features_across_views + ) + ) + + # Combine all images into view-centric representation + info_sharing_input = MultiViewTransformerInput( + features=all_encoder_features_across_views + ) + if self.info_sharing_return_type == "no_intermediate_features": + final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input) + elif self.info_sharing_return_type == "intermediate_features": + ( + final_info_sharing_multi_view_feat, + intermediate_info_sharing_multi_view_feat, + ) = self.info_sharing(info_sharing_input) + + if self.pred_head_type == "linear": + # Stack the features for all views + dense_head_inputs = torch.cat( + final_info_sharing_multi_view_feat.features, dim=0 + ) + elif self.pred_head_type in ["dpt", "dpt+pose"]: + # Get the list of features for all views + dense_head_inputs_list = [] + if self.use_encoder_features_for_dpt: + # Stack all the image encoder features for all views + stacked_encoder_features = torch.cat( + all_encoder_features_across_views, dim=0 + ) + dense_head_inputs_list.append(stacked_encoder_features) + # Stack the first intermediate features for all views + stacked_intermediate_features_1 = torch.cat( + intermediate_info_sharing_multi_view_feat[0].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_1) + # Stack the second intermediate features for all views + stacked_intermediate_features_2 = torch.cat( + intermediate_info_sharing_multi_view_feat[1].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_2) + # Stack the last layer features for all views + stacked_final_features = torch.cat( + final_info_sharing_multi_view_feat.features, dim=0 + ) + dense_head_inputs_list.append(stacked_final_features) + else: + # Stack the first intermediate features for all views + stacked_intermediate_features_1 = torch.cat( + intermediate_info_sharing_multi_view_feat[0].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_1) + # Stack the second intermediate features for all views + stacked_intermediate_features_2 = torch.cat( + intermediate_info_sharing_multi_view_feat[1].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_2) + # Stack the third intermediate features for all views + stacked_intermediate_features_3 = torch.cat( + intermediate_info_sharing_multi_view_feat[2].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_3) + # Stack the last layer + stacked_final_features = torch.cat( + final_info_sharing_multi_view_feat.features, dim=0 + ) + dense_head_inputs_list.append(stacked_final_features) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. 
Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + # Downstream task prediction + with torch.autocast("cuda", enabled=False): + # Run Prediction Heads & Post-Process Outputs + if self.pred_head_type == "linear": + dense_head_outputs = self.dense_head( + PredictionHeadInput(last_feature=dense_head_inputs) + ) + dense_final_outputs = self.dense_adaptor( + AdaptorInput( + adaptor_feature=dense_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + elif self.pred_head_type == "dpt": + dense_head_outputs = self.dense_head( + PredictionHeadLayeredInput( + list_features=dense_head_inputs_list, + target_output_shape=img_shape, + ) + ) + dense_final_outputs = self.dense_adaptor( + AdaptorInput( + adaptor_feature=dense_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + elif self.pred_head_type == "dpt+pose": + dense_head_outputs = self.dense_head( + PredictionHeadLayeredInput( + list_features=dense_head_inputs_list, + target_output_shape=img_shape, + ) + ) + dense_final_outputs = self.dense_adaptor( + AdaptorInput( + adaptor_feature=dense_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + pose_head_outputs = self.pose_head( + PredictionHeadInput(last_feature=dense_head_inputs_list[-1]) + ) + pose_final_outputs = self.pose_adaptor( + AdaptorInput( + adaptor_feature=pose_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + # Prepare the final scene representation for all views + if self.scene_rep_type in [ + "pointmap", + "pointmap+confidence", + "pointmap+mask", + "pointmap+confidence+mask", + ]: + output_pts3d = dense_final_outputs.value + # Reshape final scene representation to (B * V, H, W, C) + output_pts3d = output_pts3d.permute(0, 2, 3, 1).contiguous() + # Split the predicted pointmaps back to their respective views + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append({"pts3d": output_pts3d_per_view[i]}) + elif self.scene_rep_type in [ + "raymap+depth", + "raymap+depth+confidence", + "raymap+depth+mask", + "raymap+depth+confidence+mask", + ]: + # Reshape final scene representation to (B * V, H, W, C) + output_scene_rep = dense_final_outputs.value.permute( + 0, 2, 3, 1 + ).contiguous() + # Get the predicted ray origins, directions, and depths along rays + output_ray_origins, output_ray_directions, output_depth_along_ray = ( + output_scene_rep.split([3, 3, 1], dim=-1) + ) + # Get the predicted pointmaps + output_pts3d = ( + output_ray_origins + output_ray_directions * output_depth_along_ray + ) + # Split the predicted quantities back to their respective views + output_ray_origins_per_view = output_ray_origins.chunk(num_views, dim=0) + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i], + "ray_origins": output_ray_origins_per_view[i], + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i], + } + ) + elif self.scene_rep_type in [ + "raydirs+depth+pose", + "raydirs+depth+pose+confidence", + "raydirs+depth+pose+mask", 
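For the raymap+depth family, the dense output channels are split into ray origins (3), ray directions (3), and depth along ray (1), and the pointmap is recovered by the closed-form combination shown above. A tiny numeric check of that reconstruction:

```python
import torch

dense = torch.randn(1, 4, 4, 7)                        # (B*V, H, W, 6 + 1): raymap + depth
ray_origins, ray_dirs, depth_along_ray = dense.split([3, 3, 1], dim=-1)
pts3d = ray_origins + ray_dirs * depth_along_ray       # (B*V, H, W, 3)

# Sanity check: each point lies on its ray, depth_along_ray times the direction vector from the origin.
assert torch.allclose(pts3d - ray_origins, ray_dirs * depth_along_ray)
```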
+ "raydirs+depth+pose+confidence+mask", + ]: + # Reshape output dense rep to (B * V, H, W, C) + output_dense_rep = dense_final_outputs.value.permute( + 0, 2, 3, 1 + ).contiguous() + # Get the predicted ray directions and depths along rays + output_ray_directions, output_depth_along_ray = output_dense_rep.split( + [3, 1], dim=-1 + ) + # Get the predicted camera translations and quaternions + output_cam_translations, output_cam_quats = ( + pose_final_outputs.value.split([3, 4], dim=-1) + ) + # Get the predicted pointmaps in world frame and camera frame + output_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + output_ray_directions, + output_depth_along_ray, + output_cam_translations, + output_cam_quats, + ) + ) + output_pts3d_cam = output_ray_directions * output_depth_along_ray + # Split the predicted quantities back to their respective views + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_cam_translations_per_view = output_cam_translations.chunk( + num_views, dim=0 + ) + output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i], + "pts3d_cam": output_pts3d_cam_per_view[i], + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i], + "cam_trans": output_cam_translations_per_view[i], + "cam_quats": output_cam_quats_per_view[i], + } + ) + elif self.scene_rep_type in [ + "campointmap+pose", + "campointmap+pose+confidence", + "campointmap+pose+mask", + "campointmap+pose+confidence+mask", + ]: + # Get the predicted camera frame pointmaps + output_pts3d_cam = dense_final_outputs.value + # Reshape final scene representation to (B * V, H, W, C) + output_pts3d_cam = output_pts3d_cam.permute(0, 2, 3, 1).contiguous() + # Get the predicted camera translations and quaternions + output_cam_translations, output_cam_quats = ( + pose_final_outputs.value.split([3, 4], dim=-1) + ) + # Get the ray directions and depths along rays + output_depth_along_ray = torch.norm( + output_pts3d_cam, dim=-1, keepdim=True + ) + output_ray_directions = output_pts3d_cam / output_depth_along_ray + # Get the predicted pointmaps in world frame + output_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + output_ray_directions, + output_depth_along_ray, + output_cam_translations, + output_cam_quats, + ) + ) + # Split the predicted quantities back to their respective views + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_cam_translations_per_view = output_cam_translations.chunk( + num_views, dim=0 + ) + output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i], + "pts3d_cam": output_pts3d_cam_per_view[i], + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": 
output_depth_along_ray_per_view[i], + "cam_trans": output_cam_translations_per_view[i], + "cam_quats": output_cam_quats_per_view[i], + } + ) + elif self.scene_rep_type in [ + "pointmap+raydirs+depth+pose", + "pointmap+raydirs+depth+pose+confidence", + "pointmap+raydirs+depth+pose+mask", + "pointmap+raydirs+depth+pose+confidence+mask", + ]: + # Reshape final scene representation to (B * V, H, W, C) + output_dense_rep = dense_final_outputs.value.permute( + 0, 2, 3, 1 + ).contiguous() + # Get the predicted pointmaps, ray directions and depths along rays + output_pts3d, output_ray_directions, output_depth_along_ray = ( + output_dense_rep.split([3, 3, 1], dim=-1) + ) + # Get the predicted camera translations and quaternions + output_cam_translations, output_cam_quats = ( + pose_final_outputs.value.split([3, 4], dim=-1) + ) + # Get the predicted pointmaps in camera frame + output_pts3d_cam = output_ray_directions * output_depth_along_ray + # Replace the predicted world-frame pointmaps if required + if self.pred_head_config["adaptor_config"][ + "use_factored_predictions_for_global_pointmaps" + ]: + output_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + output_ray_directions, + output_depth_along_ray, + output_cam_translations, + output_cam_quats, + ) + ) + # Split the predicted quantities back to their respective views + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_cam_translations_per_view = output_cam_translations.chunk( + num_views, dim=0 + ) + output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i], + "pts3d_cam": output_pts3d_cam_per_view[i], + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i], + "cam_trans": output_cam_translations_per_view[i], + "cam_quats": output_cam_quats_per_view[i], + } + ) + else: + raise ValueError( + f"Invalid scene_rep_type: {self.scene_rep_type}. 
\ + Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \ + 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \ + 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \ + 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']" + ) + + # Get the output confidences for all views (if available) and add them to the result + if "confidence" in self.scene_rep_type: + output_confidences = dense_final_outputs.confidence + # Reshape confidences to (B * V, H, W) + output_confidences = ( + output_confidences.permute(0, 2, 3, 1).squeeze(-1).contiguous() + ) + # Split the predicted confidences back to their respective views + output_confidences_per_view = output_confidences.chunk(num_views, dim=0) + # Add the confidences to the result + for i in range(num_views): + res[i]["conf"] = output_confidences_per_view[i] + + # Get the output masks (and logits) for all views (if available) and add them to the result + if "mask" in self.scene_rep_type: + # Get the output masks + output_masks = dense_final_outputs.mask + # Reshape masks to (B * V, H, W) + output_masks = output_masks.permute(0, 2, 3, 1).squeeze(-1).contiguous() + # Threshold the masks at 0.5 to get binary masks (0: ambiguous/invalid, 1: non-ambiguous/valid) + output_masks = output_masks > 0.5 + # Split the predicted masks back to their respective views + output_masks_per_view = output_masks.chunk(num_views, dim=0) + # Get the output mask logits (for loss) + output_mask_logits = dense_final_outputs.logits + # Reshape mask logits to (B * V, H, W) + output_mask_logits = ( + output_mask_logits.permute(0, 2, 3, 1).squeeze(-1).contiguous() + ) + # Split the predicted mask logits back to their respective views + output_mask_logits_per_view = output_mask_logits.chunk(num_views, dim=0) + # Add the masks and logits to the result + for i in range(num_views): + res[i]["non_ambiguous_mask"] = output_masks_per_view[i] + res[i]["non_ambiguous_mask_logits"] = output_mask_logits_per_view[i] + + return res diff --git a/mapanything/models/mapanything/model.py b/mapanything/models/mapanything/model.py new file mode 100644 index 0000000000000000000000000000000000000000..696038ec50f1b107c3ebb4b91cba26662514cffe --- /dev/null +++ b/mapanything/models/mapanything/model.py @@ -0,0 +1,2112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +MapAnything model class defined using UniCeption modules. 
+""" + +import warnings +from functools import partial +from typing import Any, Callable, Dict, List, Tuple, Type, Union + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin + +from mapanything.utils.geometry import ( + apply_log_to_norm, + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + normalize_depth_using_non_zero_pixels, + normalize_pose_translations, + transform_pose_using_quats_and_trans_2_to_1, +) +from mapanything.utils.inference import ( + postprocess_model_outputs_for_inference, + preprocess_input_views_for_inference, + validate_input_views_for_inference, +) +from uniception.models.encoders import ( + encoder_factory, + EncoderGlobalRepInput, + ViTEncoderInput, + ViTEncoderNonImageInput, +) +from uniception.models.info_sharing.alternating_attention_transformer import ( + MultiViewAlternatingAttentionTransformer, + MultiViewAlternatingAttentionTransformerIFR, +) +from uniception.models.info_sharing.base import MultiViewTransformerInput +from uniception.models.info_sharing.cross_attention_transformer import ( + MultiViewCrossAttentionTransformer, + MultiViewCrossAttentionTransformerIFR, +) +from uniception.models.info_sharing.global_attention_transformer import ( + MultiViewGlobalAttentionTransformer, + MultiViewGlobalAttentionTransformerIFR, +) +from uniception.models.prediction_heads.adaptors import ( + CamTranslationPlusQuatsAdaptor, + PointMapAdaptor, + PointMapPlusRayDirectionsPlusDepthAdaptor, + PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor, + PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor, + PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor, + PointMapWithConfidenceAdaptor, + PointMapWithConfidenceAndMaskAdaptor, + PointMapWithMaskAdaptor, + RayDirectionsPlusDepthAdaptor, + RayDirectionsPlusDepthWithConfidenceAdaptor, + RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor, + RayDirectionsPlusDepthWithMaskAdaptor, + RayMapPlusDepthAdaptor, + RayMapPlusDepthWithConfidenceAdaptor, + RayMapPlusDepthWithConfidenceAndMaskAdaptor, + RayMapPlusDepthWithMaskAdaptor, + ScaleAdaptor, +) +from uniception.models.prediction_heads.base import ( + AdaptorInput, + PredictionHeadInput, + PredictionHeadLayeredInput, + PredictionHeadTokenInput, +) +from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor +from uniception.models.prediction_heads.linear import LinearFeature +from uniception.models.prediction_heads.mlp_head import MLPHead +from uniception.models.prediction_heads.pose_head import PoseHead + +# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12) +if hasattr(torch.backends.cuda, "matmul") and hasattr( + torch.backends.cuda.matmul, "allow_tf32" +): + torch.backends.cuda.matmul.allow_tf32 = True + + +class MapAnything(nn.Module, PyTorchModelHubMixin): + "Modular MapAnything model class that supports input of images & optional geometric modalities (multiple reconstruction tasks)." 
+ + def __init__( + self, + name: str, + encoder_config: Dict, + info_sharing_config: Dict, + pred_head_config: Dict, + geometric_input_config: Dict, + fusion_norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial( + nn.LayerNorm, eps=1e-6 + ), + pretrained_checkpoint_path: str = None, + load_specific_pretrained_submodules: bool = False, + specific_pretrained_submodules: list = None, + torch_hub_force_reload: bool = False, + ): + """ + Multi-view model containing an image encoder fused with optional geometric modalities followed by a multi-view attention transformer and respective downstream heads. + The goal is to output scene representation. + The multi-view attention transformer also takes as input a scale token to predict the metric scaling factor for the predicted scene representation. + + Args: + name (str): Name of the model. + encoder_config (Dict): Configuration for the encoder. + info_sharing_config (Dict): Configuration for the multi-view attention transformer. + pred_head_config (Dict): Configuration for the prediction heads. + geometric_input_config (Dict): Configuration for the input of optional geometric modalities. + fusion_norm_layer (Union[Type[nn.Module], Callable[..., nn.Module]]): Normalization layer to use after fusion (addition) of encoder and geometric modalities. (default: partial(nn.LayerNorm, eps=1e-6)) + pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None) + load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False) + specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: None) + torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. 
(default: False) + """ + super().__init__() + + # Initialize the attributes + self.name = name + self.encoder_config = encoder_config + self.info_sharing_config = info_sharing_config + self.pred_head_config = pred_head_config + self.geometric_input_config = geometric_input_config + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.load_specific_pretrained_submodules = load_specific_pretrained_submodules + self.specific_pretrained_submodules = specific_pretrained_submodules + self.torch_hub_force_reload = torch_hub_force_reload + self.class_init_args = { + "name": self.name, + "encoder_config": self.encoder_config, + "info_sharing_config": self.info_sharing_config, + "pred_head_config": self.pred_head_config, + "geometric_input_config": self.geometric_input_config, + "pretrained_checkpoint_path": self.pretrained_checkpoint_path, + "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules, + "specific_pretrained_submodules": self.specific_pretrained_submodules, + "torch_hub_force_reload": self.torch_hub_force_reload, + } + + # Get relevant parameters from the configs + self.info_sharing_type = info_sharing_config["model_type"] + self.info_sharing_return_type = info_sharing_config["model_return_type"] + self.pred_head_type = pred_head_config["type"] + + # Initialize image encoder + if self.encoder_config["uses_torch_hub"]: + self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload + # Create a copy of the config before deleting the key to preserve it for serialization + encoder_config_copy = self.encoder_config.copy() + del encoder_config_copy["uses_torch_hub"] + self.encoder = encoder_factory(**encoder_config_copy) + + # Initialize the encoder for ray directions + ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"] + ray_dirs_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + ray_dirs_encoder_config["patch_size"] = self.encoder.patch_size + self.ray_dirs_encoder = encoder_factory(**ray_dirs_encoder_config) + + # Initialize the encoder for depth (normalized per view and values after normalization are scaled logarithmically) + depth_encoder_config = self.geometric_input_config["depth_encoder_config"] + depth_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + depth_encoder_config["patch_size"] = self.encoder.patch_size + self.depth_encoder = encoder_factory(**depth_encoder_config) + + # Initialize the encoder for log scale factor of depth + depth_scale_encoder_config = self.geometric_input_config["scale_encoder_config"] + depth_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.depth_scale_encoder = encoder_factory(**depth_scale_encoder_config) + + # Initialize the encoder for camera rotation + cam_rot_encoder_config = self.geometric_input_config["cam_rot_encoder_config"] + cam_rot_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.cam_rot_encoder = encoder_factory(**cam_rot_encoder_config) + + # Initialize the encoder for camera translation (normalized across all provided camera translations) + cam_trans_encoder_config = self.geometric_input_config[ + "cam_trans_encoder_config" + ] + cam_trans_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.cam_trans_encoder = encoder_factory(**cam_trans_encoder_config) + + # Initialize the encoder for log scale factor of camera translation + cam_trans_scale_encoder_config = self.geometric_input_config[ + "scale_encoder_config" + ] + cam_trans_scale_encoder_config["enc_embed_dim"] = 
self.encoder.enc_embed_dim + self.cam_trans_scale_encoder = encoder_factory(**cam_trans_scale_encoder_config) + + # Initialize the fusion norm layer + self.fusion_norm_layer = fusion_norm_layer(self.encoder.enc_embed_dim) + + # Initialize the Scale Token + # Used to scale the final scene predictions to metric scale + # During inference extended to (B, C, T), where T is the number of tokens (i.e., 1) + self.scale_token = nn.Parameter(torch.zeros(self.encoder.enc_embed_dim)) + torch.nn.init.trunc_normal_(self.scale_token, std=0.02) + + # Initialize the info sharing module (multi-view transformer) + self._initialize_info_sharing(info_sharing_config) + + # Initialize the prediction heads + self._initialize_prediction_heads(pred_head_config) + + # Initialize the final adaptors + self._initialize_adaptors(pred_head_config) + + # Load pretrained weights + self._load_pretrained_weights() + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def _initialize_info_sharing(self, info_sharing_config): + """ + Initialize the information sharing module based on the configuration. + + This method sets up the custom positional encoding if specified and initializes + the appropriate multi-view transformer based on the configuration type. + + Args: + info_sharing_config (Dict): Configuration for the multi-view attention transformer. + Should contain 'custom_positional_encoding', 'model_type', and 'model_return_type'. + + Returns: + None + + Raises: + ValueError: If invalid configuration options are provided. + """ + # Initialize Custom Positional Encoding if required + custom_positional_encoding = info_sharing_config["custom_positional_encoding"] + if custom_positional_encoding is not None: + if isinstance(custom_positional_encoding, str): + print( + f"Using custom positional encoding for multi-view attention transformer: {custom_positional_encoding}" + ) + raise ValueError( + f"Invalid custom_positional_encoding: {custom_positional_encoding}. None implemented." + ) + elif isinstance(custom_positional_encoding, Callable): + print( + "Using callable function as custom positional encoding for multi-view attention transformer." + ) + self.custom_positional_encoding = custom_positional_encoding + else: + self.custom_positional_encoding = None + + # Add dependencies to info_sharing_config + info_sharing_config["module_args"]["input_embed_dim"] = ( + self.encoder.enc_embed_dim + ) + info_sharing_config["module_args"]["custom_positional_encoding"] = ( + self.custom_positional_encoding + ) + + # Initialize Multi-View Transformer + if self.info_sharing_return_type == "no_intermediate_features": + # Returns only normalized last layer features + # Initialize multi-view transformer based on type + if self.info_sharing_type == "cross_attention": + self.info_sharing = MultiViewCrossAttentionTransformer( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "global_attention": + self.info_sharing = MultiViewGlobalAttentionTransformer( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "alternating_attention": + self.info_sharing = MultiViewAlternatingAttentionTransformer( + **info_sharing_config["module_args"] + ) + else: + raise ValueError( + f"Invalid info_sharing_type: {self.info_sharing_type}. 
Valid options: ['cross_attention', 'global_attention', 'alternating_attention']" + ) + elif self.info_sharing_return_type == "intermediate_features": + # Returns intermediate features and normalized last layer features + # Initialize mulit-view transformer based on type + if self.info_sharing_type == "cross_attention": + self.info_sharing = MultiViewCrossAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "global_attention": + self.info_sharing = MultiViewGlobalAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "alternating_attention": + self.info_sharing = MultiViewAlternatingAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + else: + raise ValueError( + f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']" + ) + # Assess if the DPT needs to use encoder features + if len(self.info_sharing.indices) == 2: + self.use_encoder_features_for_dpt = True + elif len(self.info_sharing.indices) == 3: + self.use_encoder_features_for_dpt = False + else: + raise ValueError( + "Invalid number of indices provided for info sharing feature returner. Please provide 2 or 3 indices." + ) + else: + raise ValueError( + f"Invalid info_sharing_return_type: {self.info_sharing_return_type}. Valid options: ['no_intermediate_features', 'intermediate_features']" + ) + + def _initialize_prediction_heads(self, pred_head_config): + """ + Initialize the prediction heads based on the prediction head configuration. + + This method configures and initializes the appropriate prediction heads based on the + specified prediction head type (linear, DPT, or DPT+pose). It sets up the necessary + dependencies and creates the required model components. + + Args: + pred_head_config (Dict): Configuration for the prediction heads. + + Returns: + None + + Raises: + ValueError: If an invalid pred_head_type is provided. + """ + # Add dependencies to prediction head config + pred_head_config["feature_head"]["patch_size"] = self.encoder.patch_size + if self.pred_head_type == "linear": + pred_head_config["feature_head"]["input_feature_dim"] = ( + self.info_sharing.dim + ) + elif "dpt" in self.pred_head_type: + # Add dependencies for DPT & Regressor head + if self.use_encoder_features_for_dpt: + pred_head_config["feature_head"]["input_feature_dims"] = [ + self.encoder.enc_embed_dim + ] + [self.info_sharing.dim] * 3 + else: + pred_head_config["feature_head"]["input_feature_dims"] = [ + self.info_sharing.dim + ] * 4 + pred_head_config["regressor_head"]["input_feature_dim"] = pred_head_config[ + "feature_head" + ]["feature_dim"] + # Add dependencies for Pose head if required + if "pose" in self.pred_head_type: + pred_head_config["pose_head"]["patch_size"] = self.encoder.patch_size + pred_head_config["pose_head"]["input_feature_dim"] = ( + self.info_sharing.dim + ) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. 
Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + pred_head_config["scale_head"]["input_feature_dim"] = self.info_sharing.dim + + # Initialize Prediction Heads + if self.pred_head_type == "linear": + # Initialize Dense Prediction Head for all views + self.dense_head = LinearFeature(**pred_head_config["feature_head"]) + elif "dpt" in self.pred_head_type: + # Initialize Dense Prediction Head for all views + self.dpt_feature_head = DPTFeature(**pred_head_config["feature_head"]) + self.dpt_regressor_head = DPTRegressionProcessor( + **pred_head_config["regressor_head"] + ) + self.dense_head = nn.Sequential( + self.dpt_feature_head, self.dpt_regressor_head + ) + # Initialize Pose Head for all views if required + if "pose" in self.pred_head_type: + self.pose_head = PoseHead(**pred_head_config["pose_head"]) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + self.scale_head = MLPHead(**pred_head_config["scale_head"]) + + def _initialize_adaptors(self, pred_head_config): + """ + Initialize the adaptors based on the prediction head configuration. + + This method sets up the appropriate adaptors for different scene representation types, + such as pointmaps, ray maps with depth, or ray directions with depth and pose. + + Args: + pred_head_config (Dict): Configuration for the prediction heads including adaptor type. + + Returns: + None + + Raises: + ValueError: If an invalid adaptor_type is provided. + AssertionError: If ray directions + depth + pose is used with an incompatible head type. + """ + if pred_head_config["adaptor_type"] == "pointmap": + self.dense_adaptor = PointMapAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "pointmap" + elif pred_head_config["adaptor_type"] == "pointmap+confidence": + self.dense_adaptor = PointMapWithConfidenceAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "pointmap+confidence" + elif pred_head_config["adaptor_type"] == "pointmap+mask": + self.dense_adaptor = PointMapWithMaskAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "pointmap+mask" + elif pred_head_config["adaptor_type"] == "pointmap+confidence+mask": + self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "pointmap+confidence+mask" + elif pred_head_config["adaptor_type"] == "raymap+depth": + self.dense_adaptor = RayMapPlusDepthAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "raymap+depth" + elif pred_head_config["adaptor_type"] == "raymap+depth+confidence": + self.dense_adaptor = RayMapPlusDepthWithConfidenceAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "raymap+depth+confidence" + elif pred_head_config["adaptor_type"] == "raymap+depth+mask": + self.dense_adaptor = RayMapPlusDepthWithMaskAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "raymap+depth+mask" + elif pred_head_config["adaptor_type"] == "raymap+depth+confidence+mask": + self.dense_adaptor = RayMapPlusDepthWithConfidenceAndMaskAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "raymap+depth+confidence+mask" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." 
+ ) + self.dense_adaptor = RayDirectionsPlusDepthAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose+confidence" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+mask": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = RayDirectionsPlusDepthWithMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose+mask" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence+mask": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose+confidence+mask" + elif pred_head_config["adaptor_type"] == "campointmap+pose": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapAdaptor(**pred_head_config["dpt_adaptor"]) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose" + elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapWithConfidenceAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose+confidence" + elif pred_head_config["adaptor_type"] == "campointmap+pose+mask": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapWithMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose+mask" + elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence+mask": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." 
+ ) + self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose+confidence+mask" + elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose": + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose" + elif ( + pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+confidence" + ): + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = ( + PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor( + **pred_head_config["dpt_adaptor"] + ) + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence" + elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+mask": + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose+mask" + elif ( + pred_head_config["adaptor_type"] + == "pointmap+raydirs+depth+pose+confidence+mask" + ): + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = ( + PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence+mask" + else: + raise ValueError( + f"Invalid adaptor_type: {pred_head_config['adaptor_type']}. \ + Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \ + 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \ + 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \ + 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']" + ) + self.scale_adaptor = ScaleAdaptor(**pred_head_config["scale_adaptor"]) + + def _load_pretrained_weights(self): + """ + Load pretrained weights from a checkpoint file. + + If load_specific_pretrained_submodules is True, only loads weights for the specified submodules. + Otherwise, loads all weights from the checkpoint. 
+
+        Returns:
+            None
+        """
+        if self.pretrained_checkpoint_path is not None:
+            if not self.load_specific_pretrained_submodules:
+                print(
+                    f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} ..."
+                )
+                ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+                print(self.load_state_dict(ckpt["model"]))
+            else:
+                print(
+                    f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} for specific submodules: {self.specific_pretrained_submodules} ..."
+                )
+                assert self.specific_pretrained_submodules is not None, (
+                    "Specific submodules to load cannot be None."
+                )
+                ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+                filtered_ckpt = {}
+                for ckpt_key, ckpt_value in ckpt["model"].items():
+                    for submodule in self.specific_pretrained_submodules:
+                        if ckpt_key.startswith(submodule):
+                            filtered_ckpt[ckpt_key] = ckpt_value
+                print(self.load_state_dict(filtered_ckpt, strict=False))
+
+    def _encode_n_views(self, views):
+        """
+        Encode all the input views (batch of images) in a single forward pass.
+        Assumes all the input views have the same image shape, batch size, and data normalization type.
+
+        Args:
+            views (List[dict]): List of dictionaries containing the input views' images and instance information.
+
+        Returns:
+            List[torch.Tensor]: A list containing the encoded features for all N views.
+        """
+        num_views = len(views)
+        data_norm_type = views[0]["data_norm_type"][0]
+        imgs_list = [view["img"] for view in views]
+        all_imgs_across_views = torch.cat(imgs_list, dim=0)
+        encoder_input = ViTEncoderInput(
+            image=all_imgs_across_views, data_norm_type=data_norm_type
+        )
+        encoder_output = self.encoder(encoder_input)
+        all_encoder_features_across_views = encoder_output.features.chunk(
+            num_views, dim=0
+        )
+
+        return all_encoder_features_across_views
+
+    def _compute_pose_quats_and_trans_for_across_views_in_ref_view(
+        self,
+        views,
+        num_views,
+        device,
+        dtype,
+        batch_size_per_view,
+        per_sample_cam_input_mask,
+    ):
+        """
+        Compute the pose quats and trans for all the views in the frame of the reference view 0.
+        Returns identity pose for views where the camera input mask is False or the pose is not provided.
+
+        Args:
+            views (List[dict]): List of dictionaries containing the input views' images and instance information.
+            num_views (int): Number of views.
+            device (torch.device): Device to use for the computation.
+            dtype (torch.dtype): Data type to use for the computation.
+            batch_size_per_view (int): Batch size per view.
+            per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask.
+
+        Returns:
+            torch.Tensor: A tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4)
+            torch.Tensor: A tensor containing the pose trans for all the views in the frame of the reference view 0. (batch_size_per_view * view, 3)
+            torch.Tensor: A tensor containing the per sample camera input mask.
+ """ + # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0 + pose_quats_non_ref_views = [] + pose_trans_non_ref_views = [] + pose_quats_ref_view_0 = [] + pose_trans_ref_view_0 = [] + for view_idx in range(num_views): + per_sample_cam_input_mask_for_curr_view = per_sample_cam_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view + ] + if ( + "camera_pose_quats" in views[view_idx] + and "camera_pose_trans" in views[view_idx] + and per_sample_cam_input_mask_for_curr_view.any() + ): + # Get the camera pose quats and trans for the current view + cam_pose_quats = views[view_idx]["camera_pose_quats"][ + per_sample_cam_input_mask_for_curr_view + ] + cam_pose_trans = views[view_idx]["camera_pose_trans"][ + per_sample_cam_input_mask_for_curr_view + ] + # Append to the list + pose_quats_non_ref_views.append(cam_pose_quats) + pose_trans_non_ref_views.append(cam_pose_trans) + # Get the camera pose quats and trans for the reference view 0 + cam_pose_quats = views[0]["camera_pose_quats"][ + per_sample_cam_input_mask_for_curr_view + ] + cam_pose_trans = views[0]["camera_pose_trans"][ + per_sample_cam_input_mask_for_curr_view + ] + # Append to the list + pose_quats_ref_view_0.append(cam_pose_quats) + pose_trans_ref_view_0.append(cam_pose_trans) + else: + per_sample_cam_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] = False + + # Initialize the pose quats and trans for all views as identity + pose_quats_across_views = torch.tensor( + [0.0, 0.0, 0.0, 1.0], dtype=dtype, device=device + ).repeat(batch_size_per_view * num_views, 1) # (q_x, q_y, q_z, q_w) + pose_trans_across_views = torch.zeros( + (batch_size_per_view * num_views, 3), dtype=dtype, device=device + ) + + # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0 + if len(pose_quats_non_ref_views) > 0: + # Stack the pose quats and trans for all the non-reference views and reference view 0 + pose_quats_non_ref_views = torch.cat(pose_quats_non_ref_views, dim=0) + pose_trans_non_ref_views = torch.cat(pose_trans_non_ref_views, dim=0) + pose_quats_ref_view_0 = torch.cat(pose_quats_ref_view_0, dim=0) + pose_trans_ref_view_0 = torch.cat(pose_trans_ref_view_0, dim=0) + + # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0 + ( + pose_quats_non_ref_views_in_ref_view_0, + pose_trans_non_ref_views_in_ref_view_0, + ) = transform_pose_using_quats_and_trans_2_to_1( + pose_quats_ref_view_0, + pose_trans_ref_view_0, + pose_quats_non_ref_views, + pose_trans_non_ref_views, + ) + + # Update the pose quats and trans for all the non-reference views + pose_quats_across_views[per_sample_cam_input_mask] = ( + pose_quats_non_ref_views_in_ref_view_0.to(dtype=dtype) + ) + pose_trans_across_views[per_sample_cam_input_mask] = ( + pose_trans_non_ref_views_in_ref_view_0.to(dtype=dtype) + ) + + return ( + pose_quats_across_views, + pose_trans_across_views, + per_sample_cam_input_mask, + ) + + def _encode_and_fuse_ray_dirs( + self, + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_ray_dirs_input_mask, + ): + """ + Encode the ray directions for all the views and fuse it with the other encoder features in a single forward pass. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. 
+ batch_size_per_view (int): Batch size per view. + all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views. + per_sample_ray_dirs_input_mask (torch.Tensor): Tensor containing the per sample ray direction input mask. + + Returns: + torch.Tensor: A tensor containing the encoded features for all the views. + """ + # Get the height and width of the images + _, _, height, width = views[0]["img"].shape + + # Get the ray directions for all the views where info is provided and the ray direction input mask is True + ray_dirs_list = [] + for view_idx in range(num_views): + per_sample_ray_dirs_input_mask_for_curr_view = ( + per_sample_ray_dirs_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] + ) + ray_dirs_for_curr_view = torch.zeros( + (batch_size_per_view, height, width, 3), + dtype=all_encoder_features_across_views.dtype, + device=all_encoder_features_across_views.device, + ) + if ( + "ray_directions_cam" in views[view_idx] + and per_sample_ray_dirs_input_mask_for_curr_view.any() + ): + ray_dirs_for_curr_view[per_sample_ray_dirs_input_mask_for_curr_view] = ( + views[view_idx]["ray_directions_cam"][ + per_sample_ray_dirs_input_mask_for_curr_view + ] + ) + else: + per_sample_ray_dirs_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] = False + ray_dirs_list.append(ray_dirs_for_curr_view) + + # Stack the ray directions for all the views and permute to (B * V, C, H, W) + ray_dirs = torch.cat(ray_dirs_list, dim=0) # (B * V, H, W, 3) + ray_dirs = ray_dirs.permute(0, 3, 1, 2).contiguous() # (B * V, 3, H, W) + + # Encode the ray directions + ray_dirs_features_across_views = self.ray_dirs_encoder( + ViTEncoderNonImageInput(data=ray_dirs) + ).features + + # Fuse the ray direction features with the other encoder features (zero out the features where the ray direction input mask is False) + ray_dirs_features_across_views = ( + ray_dirs_features_across_views + * per_sample_ray_dirs_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + ) + all_encoder_features_across_views = ( + all_encoder_features_across_views + ray_dirs_features_across_views + ) + + return all_encoder_features_across_views + + def _encode_and_fuse_depths( + self, + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_depth_input_mask, + ): + """ + Encode the z depths for all the views and fuse it with the other encoder features in a single forward pass. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + batch_size_per_view (int): Batch size per view. + all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views. + per_sample_depth_input_mask (torch.Tensor): Tensor containing the per sample depth input mask. + + Returns: + torch.Tensor: A tensor containing the encoded features for all the views. 
+ """ + # Get the device and height and width of the images + device = all_encoder_features_across_views.device + _, _, height, width = views[0]["img"].shape + + # Decide to use randomly sampled sparse depth or dense depth + if torch.rand(1) < self.geometric_input_config["sparse_depth_prob"]: + use_sparse_depth = True + else: + use_sparse_depth = False + + # Get the depths for all the views + depth_list = [] + depth_norm_factors_list = [] + metric_scale_depth_mask_list = [] + for view_idx in range(num_views): + # Get the input mask for current view + per_sample_depth_input_mask_for_curr_view = per_sample_depth_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view + ] + depth_for_curr_view = torch.zeros( + (batch_size_per_view, height, width, 1), + dtype=all_encoder_features_across_views.dtype, + device=device, + ) + depth_norm_factor_for_curr_view = torch.zeros( + (batch_size_per_view), + dtype=all_encoder_features_across_views.dtype, + device=device, + ) + metric_scale_mask_for_curr_view = torch.zeros( + (batch_size_per_view), + dtype=torch.bool, + device=device, + ) + if ( + "depth_along_ray" in views[view_idx] + ) and per_sample_depth_input_mask_for_curr_view.any(): + # Get depth for current view + depth_for_curr_view_input = views[view_idx]["depth_along_ray"][ + per_sample_depth_input_mask_for_curr_view + ] + # Get the metric scale mask + if "is_metric_scale" in views[view_idx]: + metric_scale_mask = views[view_idx]["is_metric_scale"][ + per_sample_depth_input_mask_for_curr_view + ] + else: + metric_scale_mask = torch.zeros( + depth_for_curr_view_input.shape[0], + dtype=torch.bool, + device=device, + ) + # Turn off indication of metric scale samples based on the depth_scale_norm_all_prob + depth_scale_norm_all_mask = ( + torch.rand(metric_scale_mask.shape[0]) + < self.geometric_input_config["depth_scale_norm_all_prob"] + ) + if depth_scale_norm_all_mask.any(): + metric_scale_mask[depth_scale_norm_all_mask] = False + # Assign the metric scale mask to the respective indices + metric_scale_mask_for_curr_view[ + per_sample_depth_input_mask_for_curr_view + ] = metric_scale_mask + # Sparsely sample the depth if required + if use_sparse_depth: + # Create a mask of ones + sparsification_mask = torch.ones_like( + depth_for_curr_view_input, device=device + ) + # Create a mask for valid pixels (depth > 0) + valid_pixel_mask = depth_for_curr_view_input > 0 + # Calculate the number of valid pixels + num_valid_pixels = valid_pixel_mask.sum().item() + # Calculate the number of valid pixels to set to zero + num_to_zero = int( + num_valid_pixels + * self.geometric_input_config["sparsification_removal_percent"] + ) + if num_to_zero > 0: + # Get the indices of valid pixels + valid_indices = valid_pixel_mask.nonzero(as_tuple=True) + # Randomly select indices to zero out + indices_to_zero = torch.randperm(num_valid_pixels)[:num_to_zero] + # Set selected valid indices to zero in the mask + sparsification_mask[ + valid_indices[0][indices_to_zero], + valid_indices[1][indices_to_zero], + valid_indices[2][indices_to_zero], + valid_indices[3][indices_to_zero], + ] = 0 + # Apply the mask on the depth + depth_for_curr_view_input = ( + depth_for_curr_view_input * sparsification_mask + ) + # Normalize the depth + scaled_depth_for_curr_view_input, depth_norm_factor = ( + normalize_depth_using_non_zero_pixels( + depth_for_curr_view_input, return_norm_factor=True + ) + ) + # Assign the depth and depth norm factor to the respective indices + 
depth_for_curr_view[per_sample_depth_input_mask_for_curr_view] = ( + scaled_depth_for_curr_view_input + ) + depth_norm_factor_for_curr_view[ + per_sample_depth_input_mask_for_curr_view + ] = depth_norm_factor + else: + per_sample_depth_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] = False + # Append the depths, depth norm factor and metric scale mask for the current view + depth_list.append(depth_for_curr_view) + depth_norm_factors_list.append(depth_norm_factor_for_curr_view) + metric_scale_depth_mask_list.append(metric_scale_mask_for_curr_view) + + # Stack the depths for all the views and permute to (B * V, C, H, W) + depths = torch.cat(depth_list, dim=0) # (B * V, H, W, 1) + depths = apply_log_to_norm( + depths + ) # Scale logarithimically (norm is computed along last dim) + depths = depths.permute(0, 3, 1, 2).contiguous() # (B * V, 1, H, W) + # Encode the depths using the depth encoder + depth_features_across_views = self.depth_encoder( + ViTEncoderNonImageInput(data=depths) + ).features + # Zero out the depth features where the depth input mask is False + depth_features_across_views = ( + depth_features_across_views + * per_sample_depth_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + ) + + # Stack the depth norm factors for all the views + depth_norm_factors = torch.cat(depth_norm_factors_list, dim=0) # (B * V, ) + # Encode the depth norm factors using the log scale encoder for depth + log_depth_norm_factors = torch.log(depth_norm_factors + 1e-8) # (B * V, ) + depth_scale_features_across_views = self.depth_scale_encoder( + EncoderGlobalRepInput(data=log_depth_norm_factors.unsqueeze(-1)) + ).features + # Zero out the depth scale features where the depth input mask is False + depth_scale_features_across_views = ( + depth_scale_features_across_views + * per_sample_depth_input_mask.unsqueeze(-1) + ) + # Stack the metric scale mask for all the views + metric_scale_depth_mask = torch.cat( + metric_scale_depth_mask_list, dim=0 + ) # (B * V, ) + # Zero out the depth scale features where the metric scale mask is False + # Scale encoding is only provided for metric scale samples + depth_scale_features_across_views = ( + depth_scale_features_across_views * metric_scale_depth_mask.unsqueeze(-1) + ) + + # Fuse the depth features & depth scale features with the other encoder features + all_encoder_features_across_views = ( + all_encoder_features_across_views + + depth_features_across_views + + depth_scale_features_across_views.unsqueeze(-1).unsqueeze(-1) + ) + + return all_encoder_features_across_views + + def _encode_and_fuse_cam_quats_and_trans( + self, + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + pose_quats_across_views, + pose_trans_across_views, + per_sample_cam_input_mask, + ): + """ + Encode the camera quats and trans for all the views and fuse it with the other encoder features in a single forward pass. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + batch_size_per_view (int): Batch size per view. + all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views. + pose_quats_across_views (torch.Tensor): Tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4) + pose_trans_across_views (torch.Tensor): Tensor containing the pose trans for all the views in the frame of the reference view 0. 
(batch_size_per_view * view, 3) + per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask. + + Returns: + torch.Tensor: A tensor containing the encoded features for all the views. + """ + # Encode the pose quats + pose_quats_features_across_views = self.cam_rot_encoder( + EncoderGlobalRepInput(data=pose_quats_across_views) + ).features + # Zero out the pose quat features where the camera input mask is False + pose_quats_features_across_views = ( + pose_quats_features_across_views * per_sample_cam_input_mask.unsqueeze(-1) + ) + + # Get the metric scale mask for all samples + device = all_encoder_features_across_views.device + metric_scale_pose_trans_mask = torch.zeros( + (batch_size_per_view * num_views), dtype=torch.bool, device=device + ) + for view_idx in range(num_views): + if "is_metric_scale" in views[view_idx]: + # Get the metric scale mask for the input pose priors + metric_scale_mask = views[view_idx]["is_metric_scale"] + else: + metric_scale_mask = torch.zeros( + batch_size_per_view, dtype=torch.bool, device=device + ) + metric_scale_pose_trans_mask[ + view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view + ] = metric_scale_mask + + # Turn off indication of metric scale samples based on the pose_scale_norm_all_prob + pose_norm_all_mask = ( + torch.rand(batch_size_per_view * num_views) + < self.geometric_input_config["pose_scale_norm_all_prob"] + ) + if pose_norm_all_mask.any(): + metric_scale_pose_trans_mask[pose_norm_all_mask] = False + + # Get the scale norm factor for all the samples and scale the pose translations + pose_trans_across_views = torch.split( + pose_trans_across_views, batch_size_per_view, dim=0 + ) # Split into num_views chunks + pose_trans_across_views = torch.stack( + pose_trans_across_views, dim=1 + ) # Stack the views along a new dimension (batch_size_per_view, num_views, 3) + scaled_pose_trans_across_views, pose_trans_norm_factors = ( + normalize_pose_translations( + pose_trans_across_views, return_norm_factor=True + ) + ) + + # Resize the pose translation back to (batch_size_per_view * num_views, 3) and extend the norm factor to (batch_size_per_view * num_views, 1) + scaled_pose_trans_across_views = scaled_pose_trans_across_views.unbind( + dim=1 + ) # Convert back to list of views, where each view has batch_size_per_view tensor + scaled_pose_trans_across_views = torch.cat( + scaled_pose_trans_across_views, dim=0 + ) # Concatenate back to (batch_size_per_view * num_views, 3) + pose_trans_norm_factors_across_views = pose_trans_norm_factors.unsqueeze( + -1 + ).repeat(num_views, 1) # (B, ) -> (B * V, 1) + + # Encode the pose trans + pose_trans_features_across_views = self.cam_trans_encoder( + EncoderGlobalRepInput(data=scaled_pose_trans_across_views) + ).features + # Zero out the pose trans features where the camera input mask is False + pose_trans_features_across_views = ( + pose_trans_features_across_views * per_sample_cam_input_mask.unsqueeze(-1) + ) + + # Encode the pose translation norm factors using the log scale encoder for pose trans + log_pose_trans_norm_factors_across_views = torch.log( + pose_trans_norm_factors_across_views + 1e-8 + ) + pose_trans_scale_features_across_views = self.cam_trans_scale_encoder( + EncoderGlobalRepInput(data=log_pose_trans_norm_factors_across_views) + ).features + # Zero out the pose trans scale features where the camera input mask is False + pose_trans_scale_features_across_views = ( + pose_trans_scale_features_across_views + * per_sample_cam_input_mask.unsqueeze(-1) + ) 
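+        # Note: the translations of all views belonging to a sample are
+        # normalized with a single shared norm factor per sample, so the
+        # relative translation magnitudes between the input cameras are
+        # preserved while the global scale is factored out into the norm
+        # factor that is log-encoded above.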
+ # Zero out the pose trans scale features where the metric scale mask is False + # Scale encoding is only provided for metric scale samples + pose_trans_scale_features_across_views = ( + pose_trans_scale_features_across_views + * metric_scale_pose_trans_mask.unsqueeze(-1) + ) + + # Fuse the pose quat features, pose trans features, pose trans scale features and pose trans type PE features with the other encoder features + all_encoder_features_across_views = ( + all_encoder_features_across_views + + pose_quats_features_across_views.unsqueeze(-1).unsqueeze(-1) + + pose_trans_features_across_views.unsqueeze(-1).unsqueeze(-1) + + pose_trans_scale_features_across_views.unsqueeze(-1).unsqueeze(-1) + ) + + return all_encoder_features_across_views + + def _encode_and_fuse_optional_geometric_inputs( + self, views, all_encoder_features_across_views_list + ): + """ + Encode all the input optional geometric modalities and fuses it with the image encoder features in a single forward pass. + Assumes all the input views have the same shape and batch size. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + all_encoder_features_across_views (List[torch.Tensor]): List of tensors containing the encoded image features for all N views. + + Returns: + List[torch.Tensor]: A list containing the encoded features for all N views. + """ + num_views = len(views) + batch_size_per_view, _, _, _ = views[0]["img"].shape + device = all_encoder_features_across_views_list[0].device + dtype = all_encoder_features_across_views_list[0].dtype + all_encoder_features_across_views = torch.cat( + all_encoder_features_across_views_list, dim=0 + ) + + # Get the overall input mask for all the views + overall_geometric_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["overall_prob"] + ) + overall_geometric_input_mask = overall_geometric_input_mask.repeat(num_views) + + # Get the per sample input mask after dropout + # Per sample input mask is in view-major order so that index v*B + b in each mask corresponds to sample b of view v: (B * V) + per_sample_geometric_input_mask = torch.rand( + batch_size_per_view * num_views, device=device + ) < (1 - self.geometric_input_config["dropout_prob"]) + per_sample_geometric_input_mask = ( + per_sample_geometric_input_mask & overall_geometric_input_mask + ) + + # Get the ray direction input mask + per_sample_ray_dirs_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["ray_dirs_prob"] + ) + per_sample_ray_dirs_input_mask = per_sample_ray_dirs_input_mask.repeat( + num_views + ) + per_sample_ray_dirs_input_mask = ( + per_sample_ray_dirs_input_mask & per_sample_geometric_input_mask + ) + + # Get the depth input mask + per_sample_depth_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["depth_prob"] + ) + per_sample_depth_input_mask = per_sample_depth_input_mask.repeat(num_views) + per_sample_depth_input_mask = ( + per_sample_depth_input_mask & per_sample_geometric_input_mask + ) + + # Get the camera input mask + per_sample_cam_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["cam_prob"] + ) + per_sample_cam_input_mask = per_sample_cam_input_mask.repeat(num_views) + per_sample_cam_input_mask = ( + per_sample_cam_input_mask & per_sample_geometric_input_mask + ) + + # Compute the pose quats and trans for all the non-reference views in the frame of the reference 
view 0 + # Returned pose quats and trans represent identity pose for views/samples where the camera input mask is False + pose_quats_across_views, pose_trans_across_views, per_sample_cam_input_mask = ( + self._compute_pose_quats_and_trans_for_across_views_in_ref_view( + views, + num_views, + device, + dtype, + batch_size_per_view, + per_sample_cam_input_mask, + ) + ) + + # Encode the ray directions and fuse with the image encoder features + all_encoder_features_across_views = self._encode_and_fuse_ray_dirs( + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_ray_dirs_input_mask, + ) + + # Encode the depths and fuse with the image encoder features + all_encoder_features_across_views = self._encode_and_fuse_depths( + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_depth_input_mask, + ) + + # Encode the cam quat and trans and fuse with the image encoder features + all_encoder_features_across_views = self._encode_and_fuse_cam_quats_and_trans( + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + pose_quats_across_views, + pose_trans_across_views, + per_sample_cam_input_mask, + ) + + # Normalize the fused features (permute -> normalize -> permute) + all_encoder_features_across_views = all_encoder_features_across_views.permute( + 0, 2, 3, 1 + ).contiguous() + all_encoder_features_across_views = self.fusion_norm_layer( + all_encoder_features_across_views + ) + all_encoder_features_across_views = all_encoder_features_across_views.permute( + 0, 3, 1, 2 + ).contiguous() + + # Split the batched views into individual views + fused_all_encoder_features_across_views = ( + all_encoder_features_across_views.chunk(num_views, dim=0) + ) + + return fused_all_encoder_features_across_views + + def _compute_adaptive_minibatch_size( + self, + memory_safety_factor: float = 0.95, + ) -> int: + """ + Compute adaptive minibatch size based on available PyTorch memory. + + Args: + memory_safety_factor: Safety factor to avoid OOM (0.95 = use 95% of available memory) + + Returns: + Computed minibatch size + """ + device = self.device + + if device.type == "cuda": + # Get available GPU memory + torch.cuda.empty_cache() + available_memory = torch.cuda.mem_get_info()[0] # Free memory in bytes + usable_memory = ( + available_memory * memory_safety_factor + ) # Use safety factor to avoid OOM + else: + # For non-CUDA devices, use conservative default + print( + "Non-CUDA device detected. Using conservative default minibatch size of 1 for memory efficient dense prediction head inference." 
+ ) + return 1 + + # Determine minibatch size based on available memory + max_estimated_memory_per_sample = ( + 680 * 1024 * 1024 + ) # 680 MB per sample (upper bound profiling using a 518 x 518 input) + computed_minibatch_size = int(usable_memory / max_estimated_memory_per_sample) + if computed_minibatch_size < 1: + computed_minibatch_size = 1 + + return computed_minibatch_size + + def downstream_dense_head( + self, + dense_head_inputs: Union[torch.Tensor, List[torch.Tensor]], + img_shape: Tuple[int, int], + ): + """ + Run the downstream dense prediction head + """ + if self.pred_head_type == "linear": + dense_head_outputs = self.dense_head( + PredictionHeadInput(last_feature=dense_head_inputs) + ) + dense_final_outputs = self.dense_adaptor( + AdaptorInput( + adaptor_feature=dense_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + elif self.pred_head_type in ["dpt", "dpt+pose"]: + dense_head_outputs = self.dense_head( + PredictionHeadLayeredInput( + list_features=dense_head_inputs, + target_output_shape=img_shape, + ) + ) + dense_final_outputs = self.dense_adaptor( + AdaptorInput( + adaptor_feature=dense_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + return dense_final_outputs + + def downstream_head( + self, + dense_head_inputs: Union[torch.Tensor, List[torch.Tensor]], + scale_head_inputs: torch.Tensor, + img_shape: Tuple[int, int], + memory_efficient_inference: bool = False, + ): + """ + Run Prediction Heads & Post-Process Outputs + """ + # Get device + device = self.device + + # Use mini-batch inference to run the dense prediction head (the memory bottleneck) + # This saves memory and is slower than running the dense prediction head in one go + if memory_efficient_inference: + # Obtain the batch size of the dense head inputs + if self.pred_head_type == "linear": + batch_size = dense_head_inputs.shape[0] + elif self.pred_head_type in ["dpt", "dpt+pose"]: + batch_size = dense_head_inputs[0].shape[0] + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + # Compute the mini batch size and number of mini batches adaptively based on available memory + minibatch = self._compute_adaptive_minibatch_size() + num_batches = (batch_size + minibatch - 1) // minibatch + + # Run prediction for each mini-batch + dense_final_outputs_list = [] + pose_final_outputs_list = [] if self.pred_head_type == "dpt+pose" else None + for batch_idx in range(num_batches): + start_idx = batch_idx * minibatch + end_idx = min((batch_idx + 1) * minibatch, batch_size) + + # Get the inputs for the current mini-batch + if self.pred_head_type == "linear": + dense_head_inputs_batch = dense_head_inputs[start_idx:end_idx] + elif self.pred_head_type in ["dpt", "dpt+pose"]: + dense_head_inputs_batch = [ + x[start_idx:end_idx] for x in dense_head_inputs + ] + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. 
Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + # Dense prediction (mini-batched) + dense_final_outputs_batch = self.downstream_dense_head( + dense_head_inputs_batch, img_shape + ) + dense_final_outputs_list.append(dense_final_outputs_batch) + + # Pose prediction (mini-batched) + if self.pred_head_type == "dpt+pose": + pose_head_inputs_batch = dense_head_inputs[-1][start_idx:end_idx] + pose_head_outputs_batch = self.pose_head( + PredictionHeadInput(last_feature=pose_head_inputs_batch) + ) + pose_final_outputs_batch = self.pose_adaptor( + AdaptorInput( + adaptor_feature=pose_head_outputs_batch.decoded_channels, + output_shape_hw=img_shape, + ) + ) + pose_final_outputs_list.append(pose_final_outputs_batch) + + # Concatenate the dense prediction head outputs from all mini-batches + available_keys = dense_final_outputs_batch.__dict__.keys() + dense_pred_data_dict = { + key: torch.cat( + [getattr(output, key) for output in dense_final_outputs_list], dim=0 + ) + for key in available_keys + } + dense_final_outputs = dense_final_outputs_batch.__class__( + **dense_pred_data_dict + ) + + # Concatenate the pose prediction head outputs from all mini-batches + pose_final_outputs = None + if self.pred_head_type == "dpt+pose": + available_keys = pose_final_outputs_batch.__dict__.keys() + pose_pred_data_dict = { + key: torch.cat( + [getattr(output, key) for output in pose_final_outputs_list], + dim=0, + ) + for key in available_keys + } + pose_final_outputs = pose_final_outputs_batch.__class__( + **pose_pred_data_dict + ) + + # Clear CUDA cache for better memory efficiency + if device.type == "cuda": + torch.cuda.empty_cache() + else: + # Run prediction for all (batch_size * num_views) in one go + # Dense prediction + dense_final_outputs = self.downstream_dense_head( + dense_head_inputs, img_shape + ) + + # Pose prediction + pose_final_outputs = None + if self.pred_head_type == "dpt+pose": + pose_head_outputs = self.pose_head( + PredictionHeadInput(last_feature=dense_head_inputs[-1]) + ) + pose_final_outputs = self.pose_adaptor( + AdaptorInput( + adaptor_feature=pose_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + + # Scale prediction is lightweight, so we can run it in one go + scale_head_output = self.scale_head( + PredictionHeadTokenInput(last_feature=scale_head_inputs) + ) + scale_final_output = self.scale_adaptor( + AdaptorInput( + adaptor_feature=scale_head_output.decoded_channels, + output_shape_hw=img_shape, + ) + ) + scale_final_output = scale_final_output.value.squeeze(-1) # (B, 1, 1) -> (B, 1) + + # Clear CUDA cache for better memory efficiency + if memory_efficient_inference and device.type == "cuda": + torch.cuda.empty_cache() + + return dense_final_outputs, pose_final_outputs, scale_final_output + + def forward(self, views, memory_efficient_inference=False): + """ + Forward pass performing the following operations: + 1. Encodes the N input views (images). + 2. Encodes the optional geometric inputs (ray directions, depths, camera rotations, camera translations). + 3. Fuses the encoded features from the N input views and the optional geometric inputs using addition and normalization. + 4. Information sharing across the encoded features and a scale token using a multi-view attention transformer. + 5. Passes the final features from transformer through the prediction heads. + 6. Returns the processed final outputs for N views. + + Assumption: + - All the input views and dense geometric inputs have the same image shape. 
+ + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). Input images must be normalized based on the data norm type of image encoder. + "data_norm_type" (list): [model.encoder.data_norm_type] + Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs: + "ray_directions_cam" (tensor): Ray directions in the local camera frame. Tensor of shape (B, H, W, 3). + "depth_along_ray" (tensor): Depth along the ray. Tensor of shape (B, H, W, 1). + "camera_pose_quats" (tensor): Camera pose quaternions. Tensor of shape (B, 4). Camera pose is opencv (RDF) cam2world transformation. + "camera_pose_trans" (tensor): Camera pose translations. Tensor of shape (B, 3). Camera pose is opencv (RDF) cam2world transformation. + "is_metric_scale" (tensor): Boolean tensor indicating whether the geometric inputs are in metric scale or not. Tensor of shape (B, 1). + memory_efficient_inference (bool): Whether to use memory efficient inference or not. This runs the dense prediction head (the memory bottleneck) in a memory efficient manner. Default is False. + + Returns: + List[dict]: A list containing the final outputs for all N views. + """ + # Get input shape of the images, number of views, and batch size per view + batch_size_per_view, _, height, width = views[0]["img"].shape + img_shape = (int(height), int(width)) + num_views = len(views) + + # Run the image encoder on all the input views + all_encoder_features_across_views = self._encode_n_views(views) + + # Encode the optional geometric inputs and fuse with the encoded features from the N input views + # Use high precision to prevent NaN values after layer norm in dense representation encoder (due to high variance in last dim of features) + with torch.autocast("cuda", enabled=False): + all_encoder_features_across_views = ( + self._encode_and_fuse_optional_geometric_inputs( + views, all_encoder_features_across_views + ) + ) + + # Expand the scale token to match the batch size + input_scale_token = ( + self.scale_token.unsqueeze(0) + .unsqueeze(-1) + .repeat(batch_size_per_view, 1, 1) + ) # (B, C, 1) + + # Combine all images into view-centric representation + # Output is a list containing the encoded features for all N views after information sharing. 
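+        # The scale token is forwarded through the multi-view transformer as an
+        # additional (non-spatial) token alongside the per-view features; its
+        # transformed feature is read back below via additional_token_features
+        # and consumed by the scale head to regress the metric scaling factor.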
+ info_sharing_input = MultiViewTransformerInput( + features=all_encoder_features_across_views, + additional_input_tokens=input_scale_token, + ) + if self.info_sharing_return_type == "no_intermediate_features": + final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input) + elif self.info_sharing_return_type == "intermediate_features": + ( + final_info_sharing_multi_view_feat, + intermediate_info_sharing_multi_view_feat, + ) = self.info_sharing(info_sharing_input) + + if self.pred_head_type == "linear": + # Stack the features for all views + dense_head_inputs = torch.cat( + final_info_sharing_multi_view_feat.features, dim=0 + ) + elif self.pred_head_type in ["dpt", "dpt+pose"]: + # Get the list of features for all views + dense_head_inputs_list = [] + if self.use_encoder_features_for_dpt: + # Stack all the image encoder features for all views + stacked_encoder_features = torch.cat( + all_encoder_features_across_views, dim=0 + ) + dense_head_inputs_list.append(stacked_encoder_features) + # Stack the first intermediate features for all views + stacked_intermediate_features_1 = torch.cat( + intermediate_info_sharing_multi_view_feat[0].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_1) + # Stack the second intermediate features for all views + stacked_intermediate_features_2 = torch.cat( + intermediate_info_sharing_multi_view_feat[1].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_2) + # Stack the last layer features for all views + stacked_final_features = torch.cat( + final_info_sharing_multi_view_feat.features, dim=0 + ) + dense_head_inputs_list.append(stacked_final_features) + else: + # Stack the first intermediate features for all views + stacked_intermediate_features_1 = torch.cat( + intermediate_info_sharing_multi_view_feat[0].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_1) + # Stack the second intermediate features for all views + stacked_intermediate_features_2 = torch.cat( + intermediate_info_sharing_multi_view_feat[1].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_2) + # Stack the third intermediate features for all views + stacked_intermediate_features_3 = torch.cat( + intermediate_info_sharing_multi_view_feat[2].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_3) + # Stack the last layer + stacked_final_features = torch.cat( + final_info_sharing_multi_view_feat.features, dim=0 + ) + dense_head_inputs_list.append(stacked_final_features) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. 
Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + with torch.autocast("cuda", enabled=False): + # Prepare inputs for the downstream heads + if self.pred_head_type == "linear": + dense_head_inputs = dense_head_inputs + elif self.pred_head_type in ["dpt", "dpt+pose"]: + dense_head_inputs = dense_head_inputs_list + scale_head_inputs = ( + final_info_sharing_multi_view_feat.additional_token_features + ) + + # Run the downstream heads + dense_final_outputs, pose_final_outputs, scale_final_output = ( + self.downstream_head( + dense_head_inputs=dense_head_inputs, + scale_head_inputs=scale_head_inputs, + img_shape=img_shape, + memory_efficient_inference=memory_efficient_inference, + ) + ) + + # Prepare the final scene representation for all views + if self.scene_rep_type in [ + "pointmap", + "pointmap+confidence", + "pointmap+mask", + "pointmap+confidence+mask", + ]: + output_pts3d = dense_final_outputs.value + # Reshape final scene representation to (B * V, H, W, C) + output_pts3d = output_pts3d.permute(0, 2, 3, 1).contiguous() + # Split the predicted pointmaps back to their respective views + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "metric_scaling_factor": scale_final_output, + } + ) + elif self.scene_rep_type in [ + "raymap+depth", + "raymap+depth+confidence", + "raymap+depth+mask", + "raymap+depth+confidence+mask", + ]: + # Reshape final scene representation to (B * V, H, W, C) + output_scene_rep = dense_final_outputs.value.permute( + 0, 2, 3, 1 + ).contiguous() + # Get the predicted ray origins, directions, and depths along rays + output_ray_origins, output_ray_directions, output_depth_along_ray = ( + output_scene_rep.split([3, 3, 1], dim=-1) + ) + # Get the predicted pointmaps + output_pts3d = ( + output_ray_origins + output_ray_directions * output_depth_along_ray + ) + # Split the predicted quantities back to their respective views + output_ray_origins_per_view = output_ray_origins.chunk(num_views, dim=0) + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "ray_origins": output_ray_origins_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "metric_scaling_factor": scale_final_output, + } + ) + elif self.scene_rep_type in [ + "raydirs+depth+pose", + "raydirs+depth+pose+confidence", + "raydirs+depth+pose+mask", + "raydirs+depth+pose+confidence+mask", + ]: + # Reshape output dense rep to (B * V, H, W, C) + output_dense_rep = dense_final_outputs.value.permute( + 0, 2, 3, 1 + ).contiguous() + # Get the predicted ray directions and depths along rays + output_ray_directions, output_depth_along_ray = output_dense_rep.split( + [3, 1], dim=-1 + ) + # Get the predicted camera translations and quaternions + output_cam_translations, output_cam_quats = ( + pose_final_outputs.value.split([3, 4], dim=-1) + ) + # Get the predicted pointmaps in 
world frame and camera frame + output_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + output_ray_directions, + output_depth_along_ray, + output_cam_translations, + output_cam_quats, + ) + ) + output_pts3d_cam = output_ray_directions * output_depth_along_ray + # Split the predicted quantities back to their respective views + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_cam_translations_per_view = output_cam_translations.chunk( + num_views, dim=0 + ) + output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "pts3d_cam": output_pts3d_cam_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "cam_trans": output_cam_translations_per_view[i] + * scale_final_output, + "cam_quats": output_cam_quats_per_view[i], + "metric_scaling_factor": scale_final_output, + } + ) + elif self.scene_rep_type in [ + "campointmap+pose", + "campointmap+pose+confidence", + "campointmap+pose+mask", + "campointmap+pose+confidence+mask", + ]: + # Get the predicted camera frame pointmaps + output_pts3d_cam = dense_final_outputs.value + # Reshape final scene representation to (B * V, H, W, C) + output_pts3d_cam = output_pts3d_cam.permute(0, 2, 3, 1).contiguous() + # Get the predicted camera translations and quaternions + output_cam_translations, output_cam_quats = ( + pose_final_outputs.value.split([3, 4], dim=-1) + ) + # Get the ray directions and depths along rays + output_depth_along_ray = torch.norm( + output_pts3d_cam, dim=-1, keepdim=True + ) + output_ray_directions = output_pts3d_cam / output_depth_along_ray + # Get the predicted pointmaps in world frame + output_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + output_ray_directions, + output_depth_along_ray, + output_cam_translations, + output_cam_quats, + ) + ) + # Split the predicted quantities back to their respective views + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_cam_translations_per_view = output_cam_translations.chunk( + num_views, dim=0 + ) + output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "pts3d_cam": output_pts3d_cam_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "cam_trans": output_cam_translations_per_view[i] + * scale_final_output, + "cam_quats": output_cam_quats_per_view[i], + 
"metric_scaling_factor": scale_final_output, + } + ) + elif self.scene_rep_type in [ + "pointmap+raydirs+depth+pose", + "pointmap+raydirs+depth+pose+confidence", + "pointmap+raydirs+depth+pose+mask", + "pointmap+raydirs+depth+pose+confidence+mask", + ]: + # Reshape final scene representation to (B * V, H, W, C) + output_dense_rep = dense_final_outputs.value.permute( + 0, 2, 3, 1 + ).contiguous() + # Get the predicted pointmaps, ray directions and depths along rays + output_pts3d, output_ray_directions, output_depth_along_ray = ( + output_dense_rep.split([3, 3, 1], dim=-1) + ) + # Get the predicted camera translations and quaternions + output_cam_translations, output_cam_quats = ( + pose_final_outputs.value.split([3, 4], dim=-1) + ) + # Get the predicted pointmaps in camera frame + output_pts3d_cam = output_ray_directions * output_depth_along_ray + # Replace the predicted world-frame pointmaps if required + if self.pred_head_config["adaptor_config"][ + "use_factored_predictions_for_global_pointmaps" + ]: + output_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + output_ray_directions, + output_depth_along_ray, + output_cam_translations, + output_cam_quats, + ) + ) + # Split the predicted quantities back to their respective views + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_cam_translations_per_view = output_cam_translations.chunk( + num_views, dim=0 + ) + output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "pts3d_cam": output_pts3d_cam_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "cam_trans": output_cam_translations_per_view[i] + * scale_final_output, + "cam_quats": output_cam_quats_per_view[i], + "metric_scaling_factor": scale_final_output, + } + ) + else: + raise ValueError( + f"Invalid scene_rep_type: {self.scene_rep_type}. 
\ + Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \ + 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \ + 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \ + 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']" + ) + + # Get the output confidences for all views (if available) and add them to the result + if "confidence" in self.scene_rep_type: + output_confidences = dense_final_outputs.confidence + # Reshape confidences to (B * V, H, W) + output_confidences = ( + output_confidences.permute(0, 2, 3, 1).squeeze(-1).contiguous() + ) + # Split the predicted confidences back to their respective views + output_confidences_per_view = output_confidences.chunk(num_views, dim=0) + # Add the confidences to the result + for i in range(num_views): + res[i]["conf"] = output_confidences_per_view[i] + + # Get the output masks (and logits) for all views (if available) and add them to the result + if "mask" in self.scene_rep_type: + # Get the output masks + output_masks = dense_final_outputs.mask + # Reshape masks to (B * V, H, W) + output_masks = output_masks.permute(0, 2, 3, 1).squeeze(-1).contiguous() + # Threshold the masks at 0.5 to get binary masks (0: ambiguous, 1: non-ambiguous) + output_masks = output_masks > 0.5 + # Split the predicted masks back to their respective views + output_masks_per_view = output_masks.chunk(num_views, dim=0) + # Get the output mask logits (for loss) + output_mask_logits = dense_final_outputs.logits + # Reshape mask logits to (B * V, H, W) + output_mask_logits = ( + output_mask_logits.permute(0, 2, 3, 1).squeeze(-1).contiguous() + ) + # Split the predicted mask logits back to their respective views + output_mask_logits_per_view = output_mask_logits.chunk(num_views, dim=0) + # Add the masks and logits to the result + for i in range(num_views): + res[i]["non_ambiguous_mask"] = output_masks_per_view[i] + res[i]["non_ambiguous_mask_logits"] = output_mask_logits_per_view[i] + + return res + + def _configure_geometric_input_config( + self, + use_calibration: bool, + use_depth: bool, + use_pose: bool, + use_depth_scale: bool, + use_pose_scale: bool, + ): + """ + Configure the geometric input configuration + """ + # Store original config for restoration + if not hasattr(self, "_original_geometric_config"): + self._original_geometric_config = dict(self.geometric_input_config) + + # Set the geometric input configuration + if not (use_calibration or use_depth or use_pose): + # No geometric inputs (images-only mode) + self.geometric_input_config.update( + { + "overall_prob": 0.0, + "dropout_prob": 1.0, + "ray_dirs_prob": 0.0, + "depth_prob": 0.0, + "cam_prob": 0.0, + "sparse_depth_prob": 0.0, + "depth_scale_norm_all_prob": 0.0, + "pose_scale_norm_all_prob": 0.0, + } + ) + else: + # Enable geometric inputs with deterministic behavior + self.geometric_input_config.update( + { + "overall_prob": 1.0, + "dropout_prob": 0.0, + "ray_dirs_prob": 1.0 if use_calibration else 0.0, + "depth_prob": 1.0 if use_depth else 0.0, + "cam_prob": 1.0 if use_pose else 0.0, + "sparse_depth_prob": 0.0, + "depth_scale_norm_all_prob": 0.0 if use_depth_scale else 1.0, + "pose_scale_norm_all_prob": 0.0 if use_pose_scale else 1.0, + } + 
) + + def _restore_original_geometric_input_config(self): + """ + Restore original geometric input configuration + """ + if hasattr(self, "_original_geometric_config"): + self.geometric_input_config.update(self._original_geometric_config) + + @torch.inference_mode() + def infer( + self, + views: List[Dict[str, Any]], + memory_efficient_inference: bool = False, + use_amp: bool = True, + amp_dtype: str = "bf16", + apply_mask: bool = True, + mask_edges: bool = True, + edge_normal_threshold: float = 5.0, + edge_depth_threshold: float = 0.03, + apply_confidence_mask: bool = False, + confidence_percentile: float = 10, + ignore_calibration_inputs: bool = False, + ignore_depth_inputs: bool = False, + ignore_pose_inputs: bool = False, + ignore_depth_scale_inputs: bool = False, + ignore_pose_scale_inputs: bool = False, + ) -> List[Dict[str, torch.Tensor]]: + """ + User-friendly inference with strict input validation and automatic conversion. + + Args: + views: List of view dictionaries. Each dict can contain: + Required: + - 'img': torch.Tensor of shape (B, 3, H, W) - normalized RGB images + - 'data_norm_type': str - normalization type used to normalize the images (must be equal to self.model.encoder.data_norm_type) + + Optional Geometric Inputs (only one of intrinsics OR ray_directions): + - 'intrinsics': torch.Tensor of shape (B, 3, 3) - will be converted to ray directions + - 'ray_directions': torch.Tensor of shape (B, H, W, 3) - ray directions in camera frame + - 'depth_z': torch.Tensor of shape (B, H, W, 1) - Z depth in camera frame (intrinsics or ray_directions must be provided) + - 'camera_poses': torch.Tensor of shape (B, 4, 4) or tuple of (quats - (B, 4), trans - (B, 3)) - can be any world frame + - 'is_metric_scale': bool or torch.Tensor of shape (B,) - if not provided, defaults to True + + Optional Additional Info: + - 'instance': List[str] where length of list is B - instance info for each view + - 'idx': List[int] where length of list is B - index info for each view + - 'true_shape': List[tuple] where length of list is B - true shape info (H, W) for each view + + memory_efficient_inference: Whether to use memory-efficient inference for dense prediction heads (trades off speed). Defaults to False. + use_amp: Whether to use automatic mixed precision for faster inference. Defaults to True. + amp_dtype: The dtype to use for mixed precision. Defaults to "bf16" (bfloat16). Options: "fp16", "bf16", "fp32". + apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True. + mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True. + edge_normal_threshold: Tolerance threshold for normals-based edge detection. Defaults to 5.0. + edge_depth_threshold: Relative tolerance threshold for depth-based edge detection. Defaults to 0.03. + apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False. + confidence_percentile: The percentile to use for the confidence threshold. Defaults to 10. + ignore_calibration_inputs: Whether to ignore the calibration inputs (intrinsics and ray_directions). Defaults to False. + ignore_depth_inputs: Whether to ignore the depth inputs. Defaults to False. + ignore_pose_inputs: Whether to ignore the pose inputs. Defaults to False. + ignore_depth_scale_inputs: Whether to ignore the depth scale inputs. Defaults to False. + ignore_pose_scale_inputs: Whether to ignore the pose scale inputs. Defaults to False. 
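+
+        Example (a minimal sketch, not from the repo; `model`, `img1`, `img2` and the
+        "dinov2" norm type are placeholders, use the value of `model.encoder.data_norm_type`):
+
+            views = [
+                {"img": img1, "data_norm_type": "dinov2"},
+                {"img": img2, "data_norm_type": "dinov2"},
+            ]
+            preds = model.infer(views, memory_efficient_inference=True)
+            pts3d_view0 = preds[0]["pts3d"]  # (B, H, W, 3) points in world frame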
+ + IMPORTANT CONSTRAINTS: + - Cannot provide both 'intrinsics' and 'ray_directions' (they represent the same information) + - If 'depth' is provided, then 'intrinsics' or 'ray_directions' must also be provided + - If ANY view has 'camera_poses', then view 0 (first view) MUST also have 'camera_poses' + + Returns: + List of prediction dictionaries, one per view. Each dict contains: + - 'img_no_norm': torch.Tensor of shape (B, H, W, 3) - denormalized rgb images + - 'pts3d': torch.Tensor of shape (B, H, W, 3) - predicted points in world frame + - 'pts3d_cam': torch.Tensor of shape (B, H, W, 3) - predicted points in camera frame + - 'ray_directions': torch.Tensor of shape (B, H, W, 3) - ray directions in camera frame + - 'intrinsics': torch.Tensor of shape (B, 3, 3) - pinhole camera intrinsics recovered from ray directions + - 'depth_along_ray': torch.Tensor of shape (B, H, W, 1) - depth along ray in camera frame + - 'depth_z': torch.Tensor of shape (B, H, W, 1) - Z depth in camera frame + - 'cam_trans': torch.Tensor of shape (B, 3) - camera translation in world frame + - 'cam_quats': torch.Tensor of shape (B, 4) - camera quaternion in world frame + - 'camera_poses': torch.Tensor of shape (B, 4, 4) - camera pose in world frame + - 'metric_scaling_factor': torch.Tensor of shape (B,) - applied metric scaling factor + - 'mask': torch.Tensor of shape (B, H, W, 1) - combo of non-ambiguous mask, edge mask and confidence-based mask if used + - 'non_ambiguous_mask': torch.Tensor of shape (B, H, W) - non-ambiguous mask + - 'non_ambiguous_mask_logits': torch.Tensor of shape (B, H, W) - non-ambiguous mask logits + - 'conf': torch.Tensor of shape (B, H, W) - confidence + + Raises: + ValueError: For invalid inputs, missing required keys, conflicting modalities, or constraint violations + """ + # Determine the mixed precision floating point type + if use_amp: + if amp_dtype == "fp16": + amp_dtype = torch.float16 + elif amp_dtype == "bf16": + if torch.cuda.is_bf16_supported(): + amp_dtype = torch.bfloat16 + else: + warnings.warn( + "bf16 is not supported on this device. Using fp16 instead." 
+ ) + amp_dtype = torch.float16 + elif amp_dtype == "fp32": + amp_dtype = torch.float32 + else: + amp_dtype = torch.float32 + + # Validate the input views + validated_views = validate_input_views_for_inference(views) + + # Transfer the views to the same device as the model + ignore_keys = set( + [ + "instance", + "idx", + "true_shape", + "data_norm_type", + ] + ) + for view in validated_views: + for name in view.keys(): + if name in ignore_keys: + continue + view[name] = view[name].to(self.device, non_blocking=True) + + # Pre-process the input views + processed_views = preprocess_input_views_for_inference(validated_views) + + # Set the model input probabilities based on input args for ignoring inputs + self._configure_geometric_input_config( + use_calibration=not ignore_calibration_inputs, + use_depth=not ignore_depth_inputs, + use_pose=not ignore_pose_inputs, + use_depth_scale=not ignore_depth_scale_inputs, + use_pose_scale=not ignore_pose_scale_inputs, + ) + + # Run the model + with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype): + preds = self.forward( + processed_views, memory_efficient_inference=memory_efficient_inference + ) + + # Post-process the model outputs + preds = postprocess_model_outputs_for_inference( + raw_outputs=preds, + input_views=processed_views, + apply_mask=apply_mask, + mask_edges=mask_edges, + edge_normal_threshold=edge_normal_threshold, + edge_depth_threshold=edge_depth_threshold, + apply_confidence_mask=apply_confidence_mask, + confidence_percentile=confidence_percentile, + ) + + # Restore the original configuration + self._restore_original_geometric_input_config() + + return preds diff --git a/mapanything/models/mapanything/modular_dust3r.py b/mapanything/models/mapanything/modular_dust3r.py new file mode 100644 index 0000000000000000000000000000000000000000..672ac1b9368545c2b714c94abbbe1e35824cb6f7 --- /dev/null +++ b/mapanything/models/mapanything/modular_dust3r.py @@ -0,0 +1,475 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Modular DUSt3R class defined using UniCeption modules. 
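+
+The model follows the DUSt3R recipe: a siamese image encoder, a two-view
+information-sharing transformer, and per-view prediction heads that regress
+pointmaps with confidence in the frame of view 1.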
+""" + +from typing import Callable, Dict + +import torch +import torch.nn as nn + +from uniception.models.encoders import encoder_factory, ViTEncoderInput +from uniception.models.info_sharing.alternating_attention_transformer import ( + MultiViewAlternatingAttentionTransformer, + MultiViewAlternatingAttentionTransformerIFR, +) +from uniception.models.info_sharing.base import MultiViewTransformerInput +from uniception.models.info_sharing.cross_attention_transformer import ( + MultiViewCrossAttentionTransformer, + MultiViewCrossAttentionTransformerIFR, +) +from uniception.models.info_sharing.global_attention_transformer import ( + MultiViewGlobalAttentionTransformer, + MultiViewGlobalAttentionTransformerIFR, +) +from uniception.models.libs.croco.pos_embed import RoPE2D +from uniception.models.prediction_heads.adaptors import PointMapWithConfidenceAdaptor +from uniception.models.prediction_heads.base import ( + AdaptorInput, + PredictionHeadInput, + PredictionHeadLayeredInput, +) +from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor +from uniception.models.prediction_heads.linear import LinearFeature + +# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12) +if hasattr(torch.backends.cuda, "matmul") and hasattr( + torch.backends.cuda.matmul, "allow_tf32" +): + torch.backends.cuda.matmul.allow_tf32 = True + + +class ModularDUSt3R(nn.Module): + "Modular DUSt3R model class." + + def __init__( + self, + name: str, + encoder_config: Dict, + info_sharing_config: Dict, + pred_head_config: Dict, + pretrained_checkpoint_path: str = None, + load_specific_pretrained_submodules: bool = False, + specific_pretrained_submodules: list = [], + torch_hub_force_reload: bool = False, + *args, + **kwargs, + ): + """ + Two-view model containing siamese encoders followed by a two-view attention transformer and respective downstream heads. + The goal is to output scene representation directly, both outputs in view1's frame (hence the asymmetry). + + Args: + name (str): Name of the model. + encoder_config (Dict): Configuration for the encoder. + info_sharing_config (Dict): Configuration for the two-view attention transformer. + pred_head_config (Dict): Configuration for the prediction heads. + pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None) + load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False) + specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: []) + torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. 
(default: False) + """ + super().__init__(*args, **kwargs) + + # Initialize the attributes + self.name = name + self.encoder_config = encoder_config + self.info_sharing_config = info_sharing_config + self.pred_head_config = pred_head_config + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.load_specific_pretrained_submodules = load_specific_pretrained_submodules + self.specific_pretrained_submodules = specific_pretrained_submodules + self.torch_hub_force_reload = torch_hub_force_reload + self.class_init_args = { + "name": self.name, + "encoder_config": self.encoder_config, + "info_sharing_config": self.info_sharing_config, + "pred_head_config": self.pred_head_config, + "pretrained_checkpoint_path": self.pretrained_checkpoint_path, + "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules, + "specific_pretrained_submodules": self.specific_pretrained_submodules, + "torch_hub_force_reload": self.torch_hub_force_reload, + } + + # Get relevant parameters from the configs + custom_positional_encoding = info_sharing_config["custom_positional_encoding"] + self.info_sharing_type = info_sharing_config["model_type"] + self.info_sharing_return_type = info_sharing_config["model_return_type"] + self.pred_head_type = pred_head_config["type"] + + # Initialize Encoder + if self.encoder_config["uses_torch_hub"]: + self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload + # Create a copy of the config before deleting the key to preserve it for serialization + encoder_config_copy = self.encoder_config.copy() + del encoder_config_copy["uses_torch_hub"] + self.encoder = encoder_factory(**encoder_config_copy) + + # Initialize Custom Positional Encoding if required + if custom_positional_encoding is not None: + if isinstance(custom_positional_encoding, str): + print( + f"Using custom positional encoding for multi-view cross attention transformer: {custom_positional_encoding}" + ) + if custom_positional_encoding.startswith("RoPE"): + rope_freq = float(custom_positional_encoding[len("RoPE") :]) + print(f"RoPE frequency: {rope_freq}") + self.custom_positional_encoding = RoPE2D(freq=rope_freq) + else: + raise ValueError( + f"Invalid custom_positional_encoding: {custom_positional_encoding}." + ) + elif isinstance(custom_positional_encoding, Callable): + print( + "Using callable function as custom positional encoding for multi-view cross attention transformer." 
+ ) + self.custom_positional_encoding = custom_positional_encoding + else: + self.custom_positional_encoding = None + + # Add dependencies to info_sharing_config + info_sharing_config["module_args"]["input_embed_dim"] = ( + self.encoder.enc_embed_dim + ) + info_sharing_config["module_args"]["custom_positional_encoding"] = ( + self.custom_positional_encoding + ) + + # Initialize Multi-View Transformer + if self.info_sharing_return_type == "no_intermediate_features": + # Returns only normalized last layer features + # Initialize multi-view transformer based on type + if self.info_sharing_type == "cross_attention": + self.info_sharing = MultiViewCrossAttentionTransformer( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "global_attention": + self.info_sharing = MultiViewGlobalAttentionTransformer( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "alternating_attention": + self.info_sharing = MultiViewAlternatingAttentionTransformer( + **info_sharing_config["module_args"] + ) + else: + raise ValueError( + f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']" + ) + elif self.info_sharing_return_type == "intermediate_features": + # Returns intermediate features and normalized last layer features + # Initialize mulit-view transformer based on type + if self.info_sharing_type == "cross_attention": + self.info_sharing = MultiViewCrossAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "global_attention": + self.info_sharing = MultiViewGlobalAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "alternating_attention": + self.info_sharing = MultiViewAlternatingAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + else: + raise ValueError( + f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']" + ) + # Assess if the DPT needs to use encoder features + if len(self.info_sharing.indices) == 2: + self.use_encoder_features_for_dpt = True + elif len(self.info_sharing.indices) == 3: + self.use_encoder_features_for_dpt = False + else: + raise ValueError( + "Invalid number of indices provided for info sharing feature returner. Please provide 2 or 3 indices." + ) + else: + raise ValueError( + f"Invalid info_sharing_return_type: {self.info_sharing_return_type}. Valid options: ['no_intermediate_features', 'intermediate_features']" + ) + + # Add dependencies to prediction head config + pred_head_config["feature_head"]["patch_size"] = self.encoder.patch_size + if self.pred_head_type == "linear": + pred_head_config["feature_head"]["input_feature_dim"] = ( + self.info_sharing.dim + ) + elif self.pred_head_type == "dpt": + if self.use_encoder_features_for_dpt: + pred_head_config["feature_head"]["input_feature_dims"] = [ + self.encoder.enc_embed_dim + ] + [self.info_sharing.dim] * 3 + else: + pred_head_config["feature_head"]["input_feature_dims"] = [ + self.info_sharing.dim + ] * 4 + pred_head_config["regressor_head"]["input_feature_dim"] = pred_head_config[ + "feature_head" + ]["feature_dim"] + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. 
Valid options: ['linear', 'dpt']" + ) + + # Initialize Prediction Heads + if self.pred_head_type == "linear": + # Initialize Prediction Head 1 + self.head1 = LinearFeature(**pred_head_config["feature_head"]) + # Initialize Prediction Head 2 + self.head2 = LinearFeature(**pred_head_config["feature_head"]) + elif self.pred_head_type == "dpt": + # Initialize Prediction Head 1 + self.dpt_feature_head1 = DPTFeature(**pred_head_config["feature_head"]) + self.dpt_regressor_head1 = DPTRegressionProcessor( + **pred_head_config["regressor_head"] + ) + self.head1 = nn.Sequential(self.dpt_feature_head1, self.dpt_regressor_head1) + # Initialize Prediction Head 2 + self.dpt_feature_head2 = DPTFeature(**pred_head_config["feature_head"]) + self.dpt_regressor_head2 = DPTRegressionProcessor( + **pred_head_config["regressor_head"] + ) + self.head2 = nn.Sequential(self.dpt_feature_head2, self.dpt_regressor_head2) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt']" + ) + + # Initialize Final Output Adaptor + if pred_head_config["adaptor_type"] == "pointmap+confidence": + self.adaptor = PointMapWithConfidenceAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "pointmap" + else: + raise ValueError( + f"Invalid adaptor_type: {pred_head_config['adaptor_type']}. Valid options: ['pointmap+confidence']" + ) + + # Load pretrained weights + if self.pretrained_checkpoint_path is not None: + if not self.load_specific_pretrained_submodules: + print( + f"Loading pretrained weights from {self.pretrained_checkpoint_path} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + else: + print( + f"Loading pretrained weights from {self.pretrained_checkpoint_path} for specific submodules: {specific_pretrained_submodules} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + filtered_ckpt = {} + for ckpt_key, ckpt_value in ckpt["model"].items(): + for submodule in specific_pretrained_submodules: + if ckpt_key.startswith(submodule): + filtered_ckpt[ckpt_key] = ckpt_value + print(self.load_state_dict(filtered_ckpt, strict=False)) + + def _encode_image_pairs(self, img1, img2, data_norm_type): + "Encode two different batches of images (each batch can have different image shape)" + if img1.shape[-2:] == img2.shape[-2:]: + encoder_input = ViTEncoderInput( + image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type + ) + encoder_output = self.encoder(encoder_input) + out, out2 = encoder_output.features.chunk(2, dim=0) + else: + encoder_input = ViTEncoderInput(image=img1, data_norm_type=data_norm_type) + out = self.encoder(encoder_input) + out = out.features + encoder_input2 = ViTEncoderInput(image=img2, data_norm_type=data_norm_type) + out2 = self.encoder(encoder_input2) + out2 = out2.features + + return out, out2 + + def _encode_symmetrized(self, view1, view2): + "Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input" + img1 = view1["img"] + img2 = view2["img"] + if isinstance(view1["data_norm_type"], list): + assert all( + [x == view1["data_norm_type"][0] for x in view1["data_norm_type"]] + ), "All data_norm_type values should be the same in the list." + data_norm_type = view1["data_norm_type"][0] + elif isinstance(view1["data_norm_type"], str): + data_norm_type = view1["data_norm_type"] + else: + raise ValueError( + f"Invalid data_norm_type: {view1['data_norm_type']}. 
Should be either a list with all same values or a string." + ) + feat1, feat2 = self._encode_image_pairs( + img1, img2, data_norm_type=data_norm_type + ) + + return feat1, feat2 + + def _downstream_head(self, head_num, decout, img_shape): + "Run the respective prediction heads" + head = getattr(self, f"head{head_num}") + if self.pred_head_type == "linear": + head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"]) + elif self.pred_head_type == "dpt": + head_input = PredictionHeadLayeredInput( + list_features=decout[f"{head_num}"], target_output_shape=img_shape + ) + + return head(head_input) + + def forward(self, views): + """ + Forward pass performing the following operations: + 1. Encodes the two input views (images). + 2. Combines the encoded features using a two-view attention transformer. + 3. Passes the combined features through the respective prediction heads. + 4. Returns the processed final outputs for both views. + + Args: + views (List(dict)): A list of size two whose elements are: + view1 (dict): Dictionary containing the first view's images and instance information. + "img" is a required key and value is a tensor of shape (B, C, H, W). + view2 (dict): Dictionary containing the second view's images and instance information. + "img" is a required key and value is a tensor of shape (B, C, H, W). + + Returns: + List[dict, dict]: A list containing the final outputs for both views. + """ + # Get input shapes + view1 = views[0] + view2 = views[1] + _, _, height1, width1 = view1["img"].shape + _, _, height2, width2 = view2["img"].shape + shape1 = (int(height1), int(width1)) + shape2 = (int(height2), int(width2)) + + if "img_encoder_feats" in view1 and "img_encoder_feats" in view2: + # Reuse the pre-computed image features for the two views + feat1 = view1["img_encoder_feats"] + feat2 = view2["img_encoder_feats"] + else: + # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width) + feat1, feat2 = self._encode_symmetrized(view1, view2) + + # Combine all images into view-centric representation + info_sharing_input = MultiViewTransformerInput(features=[feat1, feat2]) + if self.info_sharing_return_type == "no_intermediate_features": + final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input) + elif self.info_sharing_return_type == "intermediate_features": + ( + final_info_sharing_multi_view_feat, + intermediate_info_sharing_multi_view_feat, + ) = self.info_sharing(info_sharing_input) + + if self.pred_head_type == "linear": + # Define feature dictionary for linear head + info_sharing_outputs = { + "1": final_info_sharing_multi_view_feat.features[0].float(), + "2": final_info_sharing_multi_view_feat.features[1].float(), + } + elif self.pred_head_type == "dpt": + # Define feature dictionary for DPT head + if self.use_encoder_features_for_dpt: + info_sharing_outputs = { + "1": [ + feat1.float(), + intermediate_info_sharing_multi_view_feat[0] + .features[0] + .float(), + intermediate_info_sharing_multi_view_feat[1] + .features[0] + .float(), + final_info_sharing_multi_view_feat.features[0].float(), + ], + "2": [ + feat2.float(), + intermediate_info_sharing_multi_view_feat[0] + .features[1] + .float(), + intermediate_info_sharing_multi_view_feat[1] + .features[1] + .float(), + final_info_sharing_multi_view_feat.features[1].float(), + ], + } + else: + info_sharing_outputs = { + "1": [ + intermediate_info_sharing_multi_view_feat[0] + .features[0] + .float(), + intermediate_info_sharing_multi_view_feat[1] + 
.features[0] + .float(), + intermediate_info_sharing_multi_view_feat[2] + .features[0] + .float(), + final_info_sharing_multi_view_feat.features[0].float(), + ], + "2": [ + intermediate_info_sharing_multi_view_feat[0] + .features[1] + .float(), + intermediate_info_sharing_multi_view_feat[1] + .features[1] + .float(), + intermediate_info_sharing_multi_view_feat[2] + .features[1] + .float(), + final_info_sharing_multi_view_feat.features[1].float(), + ], + } + + # Downstream task prediction + with torch.autocast("cuda", enabled=False): + # Prediction heads + head_output1 = self._downstream_head(1, info_sharing_outputs, shape1) + head_output2 = self._downstream_head(2, info_sharing_outputs, shape2) + + # Post-process outputs + final_output1 = self.adaptor( + AdaptorInput( + adaptor_feature=head_output1.decoded_channels, + output_shape_hw=shape1, + ) + ) + final_output2 = self.adaptor( + AdaptorInput( + adaptor_feature=head_output2.decoded_channels, + output_shape_hw=shape2, + ) + ) + + # Reshape final scene representation to (B, H, W, C) + final_scene_rep1 = final_output1.value.permute(0, 2, 3, 1).contiguous() + final_scene_rep2 = final_output2.value.permute(0, 2, 3, 1).contiguous() + + # Convert output scene representation to pointmaps + if self.scene_rep_type == "pointmap": + output_pts3d1 = final_scene_rep1 + output_pts3d2 = final_scene_rep2 + else: + raise ValueError(f"Invalid scene_rep_type: {self.scene_rep_type}.") + + # Reshape confidence to (B, H, W, 1) + output_conf1 = ( + final_output1.confidence.permute(0, 2, 3, 1).squeeze(-1).contiguous() + ) + output_conf2 = ( + final_output2.confidence.permute(0, 2, 3, 1).squeeze(-1).contiguous() + ) + + # Convert outputs to dictionary + res1 = { + "pts3d": output_pts3d1, + "conf": output_conf1, + } + res2 = { + "pts3d": output_pts3d2, + "conf": output_conf2, + } + res = [res1, res2] + + return res diff --git a/mapanything/third_party/README.md b/mapanything/third_party/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6c1bbc96bbd47274068555bef2fa31335a88468d --- /dev/null +++ b/mapanything/third_party/README.md @@ -0,0 +1,3 @@ +# Third Party Code + +This folder contains third party code from VGGSfM & VGGT to support the COLMAP demo. diff --git a/mapanything/third_party/__init__.py b/mapanything/third_party/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/third_party/distortion.py b/mapanything/third_party/distortion.py new file mode 100644 index 0000000000000000000000000000000000000000..98d42e6674e351ae055959f8a0931ef1806b621f --- /dev/null +++ b/mapanything/third_party/distortion.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +# Modified from https://github.com/facebookresearch/vggt + +from typing import Union + +import numpy as np +import torch + +ArrayLike = Union[np.ndarray, torch.Tensor] + + +def _is_numpy(x: ArrayLike) -> bool: + return isinstance(x, np.ndarray) + + +def _is_torch(x: ArrayLike) -> bool: + return isinstance(x, torch.Tensor) + + +def _ensure_torch(x: ArrayLike) -> torch.Tensor: + """Convert input to torch tensor if it's not already one.""" + if _is_numpy(x): + return torch.from_numpy(x) + elif _is_torch(x): + return x + else: + return torch.tensor(x) + + +def single_undistortion(params, tracks_normalized): + """ + Apply undistortion to the normalized tracks using the given distortion parameters once. + + Args: + params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. + tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. + + Returns: + torch.Tensor: Undistorted normalized tracks tensor. + """ + params = _ensure_torch(params) + tracks_normalized = _ensure_torch(tracks_normalized) + + u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() + u_undist, v_undist = apply_distortion(params, u, v) + return torch.stack([u_undist, v_undist], dim=-1) + + +def iterative_undistortion( + params, + tracks_normalized, + max_iterations=100, + max_step_norm=1e-10, + rel_step_size=1e-6, +): + """ + Iteratively undistort the normalized tracks using the given distortion parameters. + + Args: + params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. + tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. + max_iterations (int): Maximum number of iterations for the undistortion process. + max_step_norm (float): Maximum step norm for convergence. + rel_step_size (float): Relative step size for numerical differentiation. + + Returns: + torch.Tensor: Undistorted normalized tracks tensor. + """ + params = _ensure_torch(params) + tracks_normalized = _ensure_torch(tracks_normalized) + + B, N, _ = tracks_normalized.shape + u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() + original_u, original_v = u.clone(), v.clone() + + eps = torch.finfo(u.dtype).eps + for idx in range(max_iterations): + u_undist, v_undist = apply_distortion(params, u, v) + dx = original_u - u_undist + dy = original_v - v_undist + + step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps) + step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps) + + J_00 = ( + apply_distortion(params, u + step_u, v)[0] + - apply_distortion(params, u - step_u, v)[0] + ) / (2 * step_u) + J_01 = ( + apply_distortion(params, u, v + step_v)[0] + - apply_distortion(params, u, v - step_v)[0] + ) / (2 * step_v) + J_10 = ( + apply_distortion(params, u + step_u, v)[1] + - apply_distortion(params, u - step_u, v)[1] + ) / (2 * step_u) + J_11 = ( + apply_distortion(params, u, v + step_v)[1] + - apply_distortion(params, u, v - step_v)[1] + ) / (2 * step_v) + + J = torch.stack( + [ + torch.stack([J_00 + 1, J_01], dim=-1), + torch.stack([J_10, J_11 + 1], dim=-1), + ], + dim=-2, + ) + + delta = torch.linalg.solve(J, torch.stack([dx, dy], dim=-1)) + + u += delta[..., 0] + v += delta[..., 1] + + if torch.max((delta**2).sum(dim=-1)) < max_step_norm: + break + + return torch.stack([u, v], dim=-1) + + +def apply_distortion(extra_params, u, v): + """ + Applies radial or OpenCV distortion to the given 2D points. 
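+
+    The number of columns in extra_params selects the distortion model: 1 parameter for
+    simple radial (k), 2 for the radial model (k1, k2), and 4 for the OpenCV model
+    (k1, k2, p1, p2).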
+ + Args: + extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks. + v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks. + + Returns: + points2D (torch.Tensor): Distorted 2D points of shape BxNx2. + """ + extra_params = _ensure_torch(extra_params) + u = _ensure_torch(u) + v = _ensure_torch(v) + + num_params = extra_params.shape[1] + + if num_params == 1: + # Simple radial distortion + k = extra_params[:, 0] + u2 = u * u + v2 = v * v + r2 = u2 + v2 + radial = k[:, None] * r2 + du = u * radial + dv = v * radial + + elif num_params == 2: + # RadialCameraModel distortion + k1, k2 = extra_params[:, 0], extra_params[:, 1] + u2 = u * u + v2 = v * v + r2 = u2 + v2 + radial = k1[:, None] * r2 + k2[:, None] * r2 * r2 + du = u * radial + dv = v * radial + + elif num_params == 4: + # OpenCVCameraModel distortion + k1, k2, p1, p2 = ( + extra_params[:, 0], + extra_params[:, 1], + extra_params[:, 2], + extra_params[:, 3], + ) + u2 = u * u + v2 = v * v + uv = u * v + r2 = u2 + v2 + radial = k1[:, None] * r2 + k2[:, None] * r2 * r2 + du = u * radial + 2 * p1[:, None] * uv + p2[:, None] * (r2 + 2 * u2) + dv = v * radial + 2 * p2[:, None] * uv + p1[:, None] * (r2 + 2 * v2) + else: + raise ValueError("Unsupported number of distortion parameters") + + u = u.clone() + du + v = v.clone() + dv + + return u, v + + +if __name__ == "__main__": + import random + + import pycolmap + + max_diff = 0 + for i in range(1000): + # Define distortion parameters (assuming 1 parameter for simplicity) + B = random.randint(1, 500) + track_num = random.randint(100, 1000) + params = torch.rand((B, 1), dtype=torch.float32) # Batch size 1, 4 parameters + tracks_normalized = torch.rand( + (B, track_num, 2), dtype=torch.float32 + ) # Batch size 1, 5 points + + # Undistort the tracks + undistorted_tracks = iterative_undistortion(params, tracks_normalized) + + for b in range(B): + pycolmap_intri = np.array([1, 0, 0, params[b].item()]) + pycam = pycolmap.Camera( + model="SIMPLE_RADIAL", + width=1, + height=1, + params=pycolmap_intri, + camera_id=0, + ) + + undistorted_tracks_pycolmap = pycam.cam_from_img( + tracks_normalized[b].numpy() + ) + diff = (undistorted_tracks[b] - undistorted_tracks_pycolmap).abs().median() + max_diff = max(max_diff, diff) + print(f"diff: {diff}, max_diff: {max_diff}") + + import pdb + + pdb.set_trace() diff --git a/mapanything/third_party/np_to_pycolmap.py b/mapanything/third_party/np_to_pycolmap.py new file mode 100644 index 0000000000000000000000000000000000000000..5349f965f6437e245a0c911ce7334e493fd6aa8c --- /dev/null +++ b/mapanything/third_party/np_to_pycolmap.py @@ -0,0 +1,357 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +# Modified from https://github.com/facebookresearch/vggt + +import numpy as np +import pycolmap + +from mapanything.third_party.projection import project_3D_points_np + + +def batch_np_matrix_to_pycolmap( + points3d, + extrinsics, + intrinsics, + tracks, + image_size, + masks=None, + max_reproj_error=None, + max_points3D_val=3000, + shared_camera=False, + camera_type="SIMPLE_PINHOLE", + extra_params=None, + min_inlier_per_frame=64, + points_rgb=None, +): + """ + Convert Batched NumPy Arrays to PyCOLMAP + + Check https://github.com/colmap/pycolmap for more details about its format + + NOTE that colmap expects images/cameras/points3D to be 1-indexed + so there is a +1 offset between colmap index and batch index + + + NOTE: different from VGGSfM, this function: + 1. Use np instead of torch + 2. Frame index and camera id starts from 1 rather than 0 (to fit the format of PyCOLMAP) + """ + # points3d: Px3 + # extrinsics: Nx3x4 + # intrinsics: Nx3x3 + # tracks: NxPx2 + # masks: NxP + # image_size: 2, assume all the frames have been padded to the same size + # where N is the number of frames and P is the number of tracks + + N, P, _ = tracks.shape + assert len(extrinsics) == N + assert len(intrinsics) == N + assert len(points3d) == P + assert image_size.shape[0] == 2 + + reproj_mask = None + + if max_reproj_error is not None: + projected_points_2d, projected_points_cam = project_3D_points_np( + points3d, extrinsics, intrinsics + ) + projected_diff = np.linalg.norm(projected_points_2d - tracks, axis=-1) + projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6 + reproj_mask = projected_diff < max_reproj_error + + if masks is not None and reproj_mask is not None: + masks = np.logical_and(masks, reproj_mask) + elif masks is not None: + masks = masks + else: + masks = reproj_mask + + assert masks is not None + + if masks.sum(1).min() < min_inlier_per_frame: + print("Not enough inliers per frame, skip BA.") + return None, None + + # Reconstruction object, following the format of PyCOLMAP/COLMAP + reconstruction = pycolmap.Reconstruction() + + inlier_num = masks.sum(0) + valid_mask = inlier_num >= 2 # a track is invalid if without two inliers + valid_idx = np.nonzero(valid_mask)[0] + + # Only add 3D points that have sufficient 2D points + for vidx in valid_idx: + # Use RGB colors if provided, otherwise use zeros + rgb = points_rgb[vidx] if points_rgb is not None else np.zeros(3) + reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), rgb) + + num_points3D = len(valid_idx) + camera = None + # frame idx + for fidx in range(N): + # set camera + if camera is None or (not shared_camera): + pycolmap_intri = _build_pycolmap_intri( + fidx, intrinsics, camera_type, extra_params + ) + + camera = pycolmap.Camera( + model=camera_type, + width=image_size[0], + height=image_size[1], + params=pycolmap_intri, + camera_id=fidx + 1, + ) + + # add camera + reconstruction.add_camera(camera) + + # set image + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] + ) # Rot and Trans + + image = pycolmap.Image( + id=fidx + 1, + name=f"image_{fidx + 1}", + camera_id=camera.camera_id, + cam_from_world=cam_from_world, + ) + + points2D_list = [] + + point2D_idx = 0 + + # NOTE point3D_id start by 1 + for point3D_id in range(1, num_points3D + 1): + original_track_idx = valid_idx[point3D_id - 1] + + if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all(): + if masks[fidx][original_track_idx]: + # It seems we don't need +0.5 for BA + point2D_xy = 
tracks[fidx][original_track_idx] + # Please note when adding the Point2D object + # It not only requires the 2D xy location, but also the id to 3D point + points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) + + # add element + track = reconstruction.points3D[point3D_id].track + track.add_element(fidx + 1, point2D_idx) + point2D_idx += 1 + + assert point2D_idx == len(points2D_list) + + try: + image.points2D = pycolmap.ListPoint2D(points2D_list) + image.registered = True + except: # noqa + print(f"frame {fidx + 1} is out of BA") + image.registered = False + + # add image + reconstruction.add_image(image) + + return reconstruction, valid_mask + + +def pycolmap_to_batch_np_matrix( + reconstruction, device="cpu", camera_type="SIMPLE_PINHOLE" +): + """ + Convert a PyCOLMAP Reconstruction Object to batched NumPy arrays. + + Args: + reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP. + device (str): Ignored in NumPy version (kept for API compatibility). + camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE"). + + Returns: + tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params. + """ + + num_images = len(reconstruction.images) + max_points3D_id = max(reconstruction.point3D_ids()) + points3D = np.zeros((max_points3D_id, 3)) + + for point3D_id in reconstruction.points3D: + points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz + + extrinsics = [] + intrinsics = [] + + extra_params = [] if camera_type == "SIMPLE_RADIAL" else None + + for i in range(num_images): + # Extract and append extrinsics + pyimg = reconstruction.images[i + 1] + pycam = reconstruction.cameras[pyimg.camera_id] + matrix = pyimg.cam_from_world.matrix() + extrinsics.append(matrix) + + # Extract and append intrinsics + calibration_matrix = pycam.calibration_matrix() + intrinsics.append(calibration_matrix) + + if camera_type == "SIMPLE_RADIAL": + extra_params.append(pycam.params[-1]) + + # Convert lists to NumPy arrays instead of torch tensors + extrinsics = np.stack(extrinsics) + intrinsics = np.stack(intrinsics) + + if camera_type == "SIMPLE_RADIAL": + extra_params = np.stack(extra_params) + extra_params = extra_params[:, None] + + return points3D, extrinsics, intrinsics, extra_params + + +######################################################## + + +def batch_np_matrix_to_pycolmap_wo_track( + points3d, + points_xyf, + points_rgb, + extrinsics, + intrinsics, + image_size, + shared_camera=False, + camera_type="SIMPLE_PINHOLE", +): + """ + Convert Batched NumPy Arrays to PyCOLMAP + + Different from batch_np_matrix_to_pycolmap, this function does not use tracks. + + It saves points3d to colmap reconstruction format only to serve as init for Gaussians or other nvs methods. + + Do NOT use this for BA. 
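+
+    2D observations are attached to frames using the frame index stored in points_xyf[:, 2].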
+ """ + # points3d: Px3 + # points_xyf: Px3, with x, y coordinates and frame indices + # points_rgb: Px3, rgb colors + # extrinsics: Nx3x4 + # intrinsics: Nx3x3 + # image_size: 2, assume all the frames have been padded to the same size + # where N is the number of frames and P is the number of tracks + + N = len(extrinsics) + P = len(points3d) + + # Reconstruction object, following the format of PyCOLMAP/COLMAP + reconstruction = pycolmap.Reconstruction() + + for vidx in range(P): + reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), points_rgb[vidx]) + + camera = None + # frame idx + for fidx in range(N): + # set camera + if camera is None or (not shared_camera): + pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type) + + camera = pycolmap.Camera( + model=camera_type, + width=image_size[0], + height=image_size[1], + params=pycolmap_intri, + camera_id=fidx + 1, + ) + + # add camera + reconstruction.add_camera(camera) + + # set image + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] + ) # Rot and Trans + + image = pycolmap.Image( + id=fidx + 1, + name=f"image_{fidx + 1}", + camera_id=camera.camera_id, + cam_from_world=cam_from_world, + ) + + points2D_list = [] + + point2D_idx = 0 + + points_belong_to_fidx = points_xyf[:, 2].astype(np.int32) == fidx + points_belong_to_fidx = np.nonzero(points_belong_to_fidx)[0] + + for point3D_batch_idx in points_belong_to_fidx: + point3D_id = point3D_batch_idx + 1 + point2D_xyf = points_xyf[point3D_batch_idx] + point2D_xy = point2D_xyf[:2] + points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) + + # add element + track = reconstruction.points3D[point3D_id].track + track.add_element(fidx + 1, point2D_idx) + point2D_idx += 1 + + assert point2D_idx == len(points2D_list) + + try: + image.points2D = pycolmap.ListPoint2D(points2D_list) + image.registered = True + except: # noqa + print(f"frame {fidx + 1} does not have any points") + image.registered = False + + # add image + reconstruction.add_image(image) + + return reconstruction + + +def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=None): + """ + Helper function to get camera parameters based on camera type. 
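+
+    Supported models: PINHOLE (fx, fy, cx, cy) and SIMPLE_PINHOLE (f, cx, cy);
+    SIMPLE_RADIAL is declared but currently raises NotImplementedError.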
+ + Args: + fidx: Frame index + intrinsics: Camera intrinsic parameters + camera_type: Type of camera model + extra_params: Additional parameters for certain camera types + + Returns: + pycolmap_intri: NumPy array of camera parameters + """ + if camera_type == "PINHOLE": + pycolmap_intri = np.array( + [ + intrinsics[fidx][0, 0], + intrinsics[fidx][1, 1], + intrinsics[fidx][0, 2], + intrinsics[fidx][1, 2], + ] + ) + elif camera_type == "SIMPLE_PINHOLE": + focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 + pycolmap_intri = np.array( + [focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]] + ) + elif camera_type == "SIMPLE_RADIAL": + raise NotImplementedError("SIMPLE_RADIAL is not supported yet") + focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 + pycolmap_intri = np.array( + [ + focal, + intrinsics[fidx][0, 2], + intrinsics[fidx][1, 2], + extra_params[fidx][0], + ] + ) + else: + raise ValueError(f"Camera type {camera_type} is not supported yet") + + return pycolmap_intri diff --git a/mapanything/third_party/projection.py b/mapanything/third_party/projection.py new file mode 100644 index 0000000000000000000000000000000000000000..dc7a6657205505c51d9824487cc3070e95c665f7 --- /dev/null +++ b/mapanything/third_party/projection.py @@ -0,0 +1,250 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Modified from https://github.com/facebookresearch/vggt + +import numpy as np +import torch + +from .distortion import apply_distortion + + +def img_from_cam_np( + intrinsics: np.ndarray, + points_cam: np.ndarray, + extra_params: np.ndarray | None = None, + default: float = 0.0, +) -> np.ndarray: + """ + Apply intrinsics (and optional radial distortion) to camera-space points. + + Args + ---- + intrinsics : (B,3,3) camera matrix K. + points_cam : (B,3,N) homogeneous camera coords (x, y, z)ᵀ. + extra_params: (B, N) or (B, k) distortion params (k = 1,2,4) or None. + default : value used for np.nan replacement. + + Returns + ------- + points2D : (B,N,2) pixel coordinates. + """ + # 1. perspective divide ─────────────────────────────────────── + z = points_cam[:, 2:3, :] # (B,1,N) + points_cam_norm = points_cam / z # (B,3,N) + uv = points_cam_norm[:, :2, :] # (B,2,N) + + # 2. optional distortion ────────────────────────────────────── + if extra_params is not None: + uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) + uv = np.stack([uu, vv], axis=1) # (B,2,N) + + # 3. homogeneous coords then K multiplication ───────────────── + ones = np.ones_like(uv[:, :1, :]) # (B,1,N) + points_cam_h = np.concatenate([uv, ones], axis=1) # (B,3,N) + + # batched mat-mul: K · [u v 1]ᵀ + points2D_h = np.einsum("bij,bjk->bik", intrinsics, points_cam_h) # (B,3,N) + points2D = np.nan_to_num(points2D_h[:, :2, :], nan=default) # (B,2,N) + + return points2D.transpose(0, 2, 1) # (B,N,2) + + +def project_3D_points_np( + points3D: np.ndarray, + extrinsics: np.ndarray, + intrinsics: np.ndarray | None = None, + extra_params: np.ndarray | None = None, + *, + default: float = 0.0, + only_points_cam: bool = False, +): + """ + NumPy clone of ``project_3D_points``. + + Parameters + ---------- + points3D : (N,3) world-space points. + extrinsics : (B,3,4) [R|t] matrix for each of B cameras. + intrinsics : (B,3,3) K matrix (optional if you only need cam-space). + extra_params : (B,k) or (B,N) distortion parameters (k ∈ {1,2,4}) or None. 
+ default : value used to replace NaNs. + only_points_cam : if True, skip the projection and return points_cam with points2D as None. + + Returns + ------- + (points2D, points_cam) : A tuple where points2D is (B,N,2) pixel coords or None if only_points_cam=True, + and points_cam is (B,3,N) camera-space coordinates. + """ + # ----- 0. prep sizes ----------------------------------------------------- + N = points3D.shape[0] # #points + B = extrinsics.shape[0] # #cameras + + # ----- 1. world → homogeneous ------------------------------------------- + w_h = np.ones((N, 1), dtype=points3D.dtype) + points3D_h = np.concatenate([points3D, w_h], axis=1) # (N,4) + + # broadcast to every camera (no actual copying with np.broadcast_to) ------ + points3D_h_B = np.broadcast_to(points3D_h, (B, N, 4)) # (B,N,4) + + # ----- 2. apply extrinsics (camera frame) ------------------------------ + # X_cam = E · X_hom + # einsum: E_(b i j) · X_(b n j) → (b n i) + points_cam = np.einsum("bij,bnj->bni", extrinsics, points3D_h_B) # (B,N,3) + points_cam = points_cam.transpose(0, 2, 1) # (B,3,N) + + if only_points_cam: + return None, points_cam + + # ----- 3. intrinsics + distortion --------------------------------------- + if intrinsics is None: + raise ValueError("`intrinsics` must be provided unless only_points_cam=True") + + points2D = img_from_cam_np( + intrinsics, points_cam, extra_params=extra_params, default=default + ) + + return points2D, points_cam + + +def project_3D_points( + points3D, + extrinsics, + intrinsics=None, + extra_params=None, + default=0, + only_points_cam=False, +): + """ + Transforms 3D points to 2D using extrinsic and intrinsic parameters. + Args: + points3D (torch.Tensor): 3D points of shape Px3. + extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. + intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. + extra_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion. + default (float): Default value to replace NaNs. + only_points_cam (bool): If True, skip the projection and return points2D as None. + + Returns: + tuple: (points2D, points_cam) where points2D is of shape BxNx2 or None if only_points_cam=True, + and points_cam is of shape Bx3xN. + """ + with torch.cuda.amp.autocast(dtype=torch.double): + B = extrinsics.shape[0] # Batch size, i.e., number of cameras + points3D_homogeneous = torch.cat( + [points3D, torch.ones_like(points3D[..., 0:1])], dim=1 + ) # Nx4 + # Reshape for batch processing + points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand( + B, -1, -1 + ) # BxNx4 + + # Step 1: Apply extrinsic parameters + # Transform 3D points to camera coordinate system for all cameras + points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2)) + + if only_points_cam: + return None, points_cam + + # Step 2: Apply intrinsic parameters and (optional) distortion + points2D = img_from_cam(intrinsics, points_cam, extra_params, default) + + return points2D, points_cam + + +def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0): + """ + Applies intrinsic parameters and optional distortion to the given 3D points. + + Args: + intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. + points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. + extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + default (float, optional): Default value to replace NaNs in the output. 
+ + Returns: + points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. + """ + + # Normalize by the third coordinate (homogeneous division) + points_cam = points_cam / points_cam[:, 2:3, :] + # Extract uv + uv = points_cam[:, :2, :] + + # Apply distortion if extra_params are provided + if extra_params is not None: + uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) + uv = torch.stack([uu, vv], dim=1) + + # Prepare points_cam for batch matrix multiplication + points_cam_homo = torch.cat((uv, torch.ones_like(uv[:, :1, :])), dim=1) # Bx3xN + # Apply intrinsic parameters using batch matrix multiplication + points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN + + # Extract x and y coordinates + points2D = points2D_homo[:, :2, :] # Bx2xN + + # Replace NaNs with default value + points2D = torch.nan_to_num(points2D, nan=default) + + return points2D.transpose(1, 2) # BxNx2 + + +if __name__ == "__main__": + # Set up example input + B, N = 24, 10240 + + for _ in range(100): + points3D = np.random.rand(N, 3).astype(np.float64) + extrinsics = np.random.rand(B, 3, 4).astype(np.float64) + intrinsics = np.random.rand(B, 3, 3).astype(np.float64) + + # Convert to torch tensors + points3D_torch = torch.tensor(points3D) + extrinsics_torch = torch.tensor(extrinsics) + intrinsics_torch = torch.tensor(intrinsics) + + # Run NumPy implementation + points2D_np, points_cam_np = project_3D_points_np( + points3D, extrinsics, intrinsics + ) + + # Run torch implementation + points2D_torch, points_cam_torch = project_3D_points( + points3D_torch, extrinsics_torch, intrinsics_torch + ) + + # Convert torch output to numpy + points2D_torch_np = points2D_torch.detach().numpy() + points_cam_torch_np = points_cam_torch.detach().numpy() + + # Compute difference + diff = np.abs(points2D_np - points2D_torch_np) + print("Difference between NumPy and PyTorch implementations:") + print(diff) + + # Check max error + max_diff = np.max(diff) + print(f"Maximum difference: {max_diff}") + + if np.allclose(points2D_np, points2D_torch_np, atol=1e-6): + print("Implementations match closely.") + else: + print("Significant differences detected.") + + if points_cam_np is not None: + points_cam_diff = np.abs(points_cam_np - points_cam_torch_np) + print("Difference between NumPy and PyTorch camera-space coordinates:") + print(points_cam_diff) + + # Check max error + max_cam_diff = np.max(points_cam_diff) + print(f"Maximum camera-space coordinate difference: {max_cam_diff}") + + if np.allclose(points_cam_np, points_cam_torch_np, atol=1e-6): + print("Camera-space coordinates match closely.") + else: + print("Significant differences detected in camera-space coordinates.") diff --git a/mapanything/third_party/track_modules/__init__.py b/mapanything/third_party/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/third_party/track_modules/base_track_predictor.py b/mapanything/third_party/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..4133fc702175d0c4cbcbf996744a54c1ae8c21c5 --- /dev/null +++ b/mapanything/third_party/track_modules/base_track_predictor.py @@ -0,0 +1,212 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +# Modified from https://github.com/facebookresearch/vggt + +import torch +import torch.nn as nn +from einops import rearrange + +from .blocks import CorrBlock, EfficientUpdateFormer +from .utils import get_2d_embedding, get_2d_sincos_pos_embed, sample_features4d + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=4, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + fine=False, + ): + super(BaseTrackerPredictor, self).__init__() + """ + The base template to create a track predictor + + Modified from https://github.com/facebookresearch/co-tracker/ + """ + + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.fine = fine + + self.flows_emb_dim = latent_dim // 2 + self.transformer_dim = ( + self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2 + ) + + if self.fine: + # TODO this is the old dummy code, will remove this when we train next model + self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5 + else: + self.transformer_dim += (4 - self.transformer_dim % 4) % 4 + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential( + nn.Linear(self.latent_dim, self.latent_dim), nn.GELU() + ) + + if not self.fine: + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward( + self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1 + ): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. 
+ note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2 + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + # Construct the correlation block + + fcorr_fn = CorrBlock( + fmaps, num_levels=self.corr_levels, radius=self.corr_radius + ) + + coord_preds = [] + + # Iterative Refinement + for itr in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + # Compute the correlation (check the implementation of CorrBlock) + + fcorr_fn.corr(track_feats) + fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim + + corrdim = fcorrs.shape[3] + + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = torch.cat([flows_emb, flows], dim=-1) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape( + B * N, S, self.latent_dim + ) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + if transformer_input.shape[2] < self.transformer_dim: + # pad the features to match the dimension + pad_dim = self.transformer_dim - transformer_input.shape[2] + pad = torch.zeros_like(flows_emb[..., 0:pad_dim]) + transformer_input = torch.cat([transformer_input, pad], dim=2) + + # 2D positional embed + # TODO: this can be much simplified + pos_embed = get_2d_sincos_pos_embed( + self.transformer_dim, grid_size=(HH, WW) + ).to(query_points.device) + sampled_pos_emb = sample_features4d( + pos_embed.expand(B, -1, -1, -1), coords[:, 0] + ) + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze( + 1 + ) + + x = transformer_input + sampled_pos_emb + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta = self.updateformer(x) + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_ + track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute( + 0, 2, 1, 3 + ) # BxSxNxC + + # B x S x N x 2 + coords = coords + 
delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + if not self.fine: + vis_e = self.vis_predictor( + track_feats.reshape(B * S * N, self.latent_dim) + ).reshape(B, S, N) + vis_e = torch.sigmoid(vis_e) + else: + vis_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat + else: + return coord_preds, vis_e diff --git a/mapanything/third_party/track_modules/blocks.py b/mapanything/third_party/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..72a87809db07a5228b0837377ba8d991625ebcf5 --- /dev/null +++ b/mapanything/third_party/track_modules/blocks.py @@ -0,0 +1,389 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Modified from https://github.com/facebookresearch/vggt + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modules import AttnBlock, CrossAttnBlock, ResidualBlock +from .utils import bilinear_sampler + + +class BasicEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=128, stride=4): + super(BasicEncoder, self).__init__() + + self.stride = stride + self.norm_fn = "instance" + self.in_planes = output_dim // 2 + + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + + self.conv1 = nn.Conv2d( + input_dim, + self.in_planes, + kernel_size=7, + stride=2, + padding=3, + padding_mode="zeros", + ) + self.relu1 = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(output_dim // 2, stride=1) + self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2) + self.layer3 = self._make_layer(output_dim, stride=2) + self.layer4 = self._make_layer(output_dim, stride=2) + + self.conv2 = nn.Conv2d( + output_dim * 3 + output_dim // 4, + output_dim * 2, + kernel_size=3, + padding=1, + padding_mode="zeros", + ) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.InstanceNorm2d)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + a = self.layer1(x) + b = self.layer2(a) + c = self.layer3(b) + d = self.layer4(c) + + a = _bilinear_intepolate(a, self.stride, H, W) + b = _bilinear_intepolate(b, self.stride, H, W) + c = _bilinear_intepolate(c, self.stride, H, W) + d = _bilinear_intepolate(d, self.stride, H, W) + + x = self.conv2(torch.cat([a, b, c, d], dim=1)) + x = self.norm2(x) + x = self.relu2(x) + x = self.conv3(x) + return x + + +class ShallowEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=32, stride=1, 
norm_fn="instance"): + super(ShallowEncoder, self).__init__() + self.stride = stride + self.norm_fn = norm_fn + self.in_planes = output_dim + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) + self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(self.in_planes) + self.norm2 = nn.BatchNorm2d(output_dim * 2) + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d( + input_dim, + self.in_planes, + kernel_size=3, + stride=2, + padding=1, + padding_mode="zeros", + ) + self.relu1 = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(output_dim, stride=2) + + self.layer2 = self._make_layer(output_dim, stride=2) + self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + self.in_planes = dim + + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + return layer1 + + def forward(self, x): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + tmp = self.layer1(x) + x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) + tmp = self.layer2(tmp) + x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) + tmp = None + x = self.conv2(x) + x + + x = F.interpolate( + x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True + ) + + return x + + +def _bilinear_intepolate(x, stride, H, W): + return F.interpolate( + x, (H // stride, W // stride), mode="bilinear", align_corners=True + ) + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. 
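+
+    Expects tokens of shape (B, N, T, input_dim), where B is the batch size,
+    N the number of tracks and T the number of frames, and returns per-token
+    updates of shape (B, N, T, output_dim).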
+ """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter( + torch.randn(1, num_virtual_tracks, 1, hidden_size) + ) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_class=nn.MultiheadAttention, + ) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_class=nn.MultiheadAttention, + ) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [ + CrossAttnBlock( + hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio + ) + for _ in range(space_depth) + ] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [ + CrossAttnBlock( + hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio + ) + for _ in range(space_depth) + ] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + torch.nn.init.trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward(self, input_tensor, mask=None): + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and ( + i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0 + ): + space_tokens = ( + tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) + ) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j]( + virtual_tokens, point_tokens, mask=mask + ) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j]( + point_tokens, virtual_tokens, mask=mask + ) + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute( + 0, 2, 1, 3 + ) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = tokens + init_tokens + + flow = self.flow_head(tokens) + return flow + + 
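+
+# ---------------------------------------------------------------------------
+# Illustrative note (added for clarity, not part of the upstream vggt code):
+# a minimal usage sketch of EfficientUpdateFormer with its default constructor
+# arguments. The tensor sizes below are arbitrary example values.
+#
+#   import torch
+#   model = EfficientUpdateFormer(input_dim=320, hidden_size=384, output_dim=130)
+#   tokens = torch.randn(2, 256, 8, 320)  # B=2, N=256 tracks, T=8 frames
+#   delta = model(tokens)                 # (2, 256, 8, 130): 2 coordinate deltas
+#                                         # plus 128 latent feature deltas per token
+# ---------------------------------------------------------------------------
+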
+class CorrBlock: + def __init__( + self, + fmaps, + num_levels=4, + radius=4, + multiple_track_feats=False, + padding_mode="zeros", + ): + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.padding_mode = padding_mode + self.num_levels = num_levels + self.radius = radius + self.fmaps_pyramid = [] + self.multiple_track_feats = multiple_track_feats + + self.fmaps_pyramid.append(fmaps) + for i in range(self.num_levels - 1): + fmaps_ = fmaps.reshape(B * S, C, H, W) + fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) + _, _, H, W = fmaps_.shape + fmaps = fmaps_.reshape(B, S, C, H, W) + self.fmaps_pyramid.append(fmaps) + + def sample(self, coords): + r = self.radius + B, S, N, D = coords.shape + assert D == 2 + + H, W = self.H, self.W + out_pyramid = [] + for i in range(self.num_levels): + corrs = self.corrs_pyramid[i] # B, S, N, H, W + *_, H, W = corrs.shape + + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( + coords.device + ) + + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corrs = bilinear_sampler( + corrs.reshape(B * S * N, 1, H, W), + coords_lvl, + padding_mode=self.padding_mode, + ) + corrs = corrs.view(B, S, N, -1) + + out_pyramid.append(corrs) + + out = torch.cat(out_pyramid, dim=-1).contiguous() # B, S, N, LRR*2 + return out + + def corr(self, targets): + B, S, N, C = targets.shape + if self.multiple_track_feats: + targets_split = targets.split(C // self.num_levels, dim=-1) + B, S, N, C = targets_split[0].shape + + assert C == self.C + assert S == self.S + + fmap1 = targets + + self.corrs_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + *_, H, W = fmaps.shape + fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W) + if self.multiple_track_feats: + fmap1 = targets_split[i] + corrs = torch.matmul(fmap1, fmap2s) + corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W + corrs = corrs / torch.sqrt(torch.tensor(C).float()) + self.corrs_pyramid.append(corrs) diff --git a/mapanything/third_party/track_modules/modules.py b/mapanything/third_party/track_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..ee04b9695c407ed915c20c34f0e9480cb0aa0a1e --- /dev/null +++ b/mapanything/third_party/track_modules/modules.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +# Modified from https://github.com/facebookresearch/vggt + + +import collections +from functools import partial +from itertools import repeat +from typing import Callable + +import torch.nn as nn + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, + planes, + kernel_size=kernel_size, + padding=1, + stride=stride, + padding_mode="zeros", + ) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros" + ) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs, + ): + """ + Self attention block + """ + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, 
eps=1e-6) + + self.attn = attn_class( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # Prepare the mask for PyTorch's attention (it expects a different format) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__( + self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs + ): + """ + Cross attention block + """ + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git a/mapanything/third_party/track_modules/track_refine.py b/mapanything/third_party/track_modules/track_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..fab54edb622030ee31314e9a8c9017790b6dd177 --- /dev/null +++ b/mapanything/third_party/track_modules/track_refine.py @@ -0,0 +1,486 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Modified from https://github.com/facebookresearch/vggt + +from typing import Tuple + +import torch +from einops import rearrange + + +def refine_track( + images, + fine_fnet, + fine_tracker, + coarse_pred, + compute_score=False, + pradius=15, + sradius=2, + fine_iters=6, + chunk=40960, +): + """ + Refines the tracking of images using a fine track predictor and a fine feature network. + Check https://arxiv.org/abs/2312.04563 for more details. + + Args: + images (torch.Tensor): The images to be tracked. + fine_fnet (nn.Module): The fine feature network. + fine_tracker (nn.Module): The fine track predictor. + coarse_pred (torch.Tensor): The coarse predictions of tracks. + compute_score (bool, optional): Whether to compute the score. Defaults to False. + pradius (int, optional): The radius of a patch. Defaults to 15. + sradius (int, optional): The search radius. Defaults to 2. + + Returns: + torch.Tensor: The refined tracks. + torch.Tensor, optional: The score. 
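+
+    Note:
+        The patch extraction below clamps the patch top-left corner with a single
+        bound (H - psize), so square inputs (H == W) are assumed.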
+ """ + + # coarse_pred shape: BxSxNx2, + # where B is the batch, S is the video/images length, and N is the number of tracks + # now we are going to extract patches with the center at coarse_pred + # Please note that the last dimension indicates x and y, and hence has a dim number of 2 + B, S, N, _ = coarse_pred.shape + _, _, _, H, W = images.shape + + # Given the raidus of a patch, compute the patch size + psize = pradius * 2 + 1 + + # Note that we assume the first frame is the query frame + # so the 2D locations of the first frame are the query points + query_points = coarse_pred[:, 0] + + # Given 2D positions, we can use grid_sample to extract patches + # but it takes too much memory. + # Instead, we use the floored track xy to sample patches. + + # For example, if the query point xy is (128.16, 252.78), + # and the patch size is (31, 31), + # our goal is to extract the content of a rectangle + # with left top: (113.16, 237.78) + # and right bottom: (143.16, 267.78). + # However, we record the floored left top: (113, 237) + # and the offset (0.16, 0.78) + # Then what we need is just unfolding the images like in CNN, + # picking the content at [(113, 237), (143, 267)]. + # Such operations are highly optimized at pytorch + # (well if you really want to use interpolation, check the function extract_glimpse() below) + + with torch.no_grad(): + content_to_extract = images.reshape(B * S, 3, H, W) + C_in = content_to_extract.shape[1] + + # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html + # for the detailed explanation of unfold() + # Here it runs sliding windows (psize x psize) to build patches + # The shape changes from + # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize + # where Psize is the size of patch + content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) + + # Floor the coarse predictions to get integers and save the fractional/decimal + track_int = coarse_pred.floor().int() + track_frac = coarse_pred - track_int + + # Note the points represent the center of patches + # now we get the location of the top left corner of patches + # because the ouput of pytorch unfold are indexed by top left corner + topleft = track_int - pradius + topleft_BSN = topleft.clone() + + # clamp the values so that we will not go out of indexes + # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). 
+ # You need to seperately clamp x and y if H!=W + topleft = topleft.clamp(0, H - psize) + + # Reshape from BxSxNx2 -> (B*S)xNx2 + topleft = topleft.reshape(B * S, N, 2) + + # Prepare batches for indexing, shape: (B*S)xN + batch_indices = ( + torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) + ) + + # extracted_patches: (B*S) x N x C_in x Psize x Psize + extracted_patches = content_to_extract[ + batch_indices, :, topleft[..., 1], topleft[..., 0] + ] + + if chunk < 0: + # Extract image patches based on top left corners + # Feed patches to fine fent for features + patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) + else: + patches = extracted_patches.reshape(B * S * N, C_in, psize, psize) + + patch_feat_list = [] + for p in torch.split(patches, chunk): + patch_feat_list += [fine_fnet(p)] + patch_feat = torch.cat(patch_feat_list, 0) + + C_out = patch_feat.shape[1] + + # Refine the coarse tracks by fine_tracker + # reshape back to B x S x N x C_out x Psize x Psize + patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) + patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") + + # Prepare for the query points for fine tracker + # They are relative to the patch left top corner, + # instead of the image top left corner now + # patch_query_points: N x 1 x 2 + # only 1 here because for each patch we only have 1 query point + patch_query_points = track_frac[:, 0] + pradius + patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) + + # Feed the PATCH query points and tracks into fine tracker + fine_pred_track_lists, _, _, query_point_feat = fine_tracker( + query_points=patch_query_points, + fmaps=patch_feat, + iters=fine_iters, + return_feat=True, + ) + + # relative the patch top left + fine_pred_track = fine_pred_track_lists[-1].clone() + + # From (relative to the patch top left) to (relative to the image top left) + for idx in range(len(fine_pred_track_lists)): + fine_level = rearrange( + fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N + ) + fine_level = fine_level.squeeze(-2) + fine_level = fine_level + topleft_BSN + fine_pred_track_lists[idx] = fine_level + + # relative to the image top left + refined_tracks = fine_pred_track_lists[-1].clone() + refined_tracks[:, 0] = query_points + + score = None + + if compute_score: + score = compute_score_fn( + query_point_feat, + patch_feat, + fine_pred_track, + sradius, + psize, + B, + N, + S, + C_out, + ) + + return refined_tracks, score + + +def refine_track_v0( + images, + fine_fnet, + fine_tracker, + coarse_pred, + compute_score=False, + pradius=15, + sradius=2, + fine_iters=6, +): + """ + COPIED FROM VGGSfM + + Refines the tracking of images using a fine track predictor and a fine feature network. + Check https://arxiv.org/abs/2312.04563 for more details. + + Args: + images (torch.Tensor): The images to be tracked. + fine_fnet (nn.Module): The fine feature network. + fine_tracker (nn.Module): The fine track predictor. + coarse_pred (torch.Tensor): The coarse predictions of tracks. + compute_score (bool, optional): Whether to compute the score. Defaults to False. + pradius (int, optional): The radius of a patch. Defaults to 15. + sradius (int, optional): The search radius. Defaults to 2. + + Returns: + torch.Tensor: The refined tracks. + torch.Tensor, optional: The score. 
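+
+    Note:
+        As in refine_track, square inputs (H == W) are assumed by the patch
+        top-left clamping.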
+ """ + + # coarse_pred shape: BxSxNx2, + # where B is the batch, S is the video/images length, and N is the number of tracks + # now we are going to extract patches with the center at coarse_pred + # Please note that the last dimension indicates x and y, and hence has a dim number of 2 + B, S, N, _ = coarse_pred.shape + _, _, _, H, W = images.shape + + # Given the raidus of a patch, compute the patch size + psize = pradius * 2 + 1 + + # Note that we assume the first frame is the query frame + # so the 2D locations of the first frame are the query points + query_points = coarse_pred[:, 0] + + # Given 2D positions, we can use grid_sample to extract patches + # but it takes too much memory. + # Instead, we use the floored track xy to sample patches. + + # For example, if the query point xy is (128.16, 252.78), + # and the patch size is (31, 31), + # our goal is to extract the content of a rectangle + # with left top: (113.16, 237.78) + # and right bottom: (143.16, 267.78). + # However, we record the floored left top: (113, 237) + # and the offset (0.16, 0.78) + # Then what we need is just unfolding the images like in CNN, + # picking the content at [(113, 237), (143, 267)]. + # Such operations are highly optimized at pytorch + # (well if you really want to use interpolation, check the function extract_glimpse() below) + + with torch.no_grad(): + content_to_extract = images.reshape(B * S, 3, H, W) + C_in = content_to_extract.shape[1] + + # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html + # for the detailed explanation of unfold() + # Here it runs sliding windows (psize x psize) to build patches + # The shape changes from + # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize + # where Psize is the size of patch + content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) + + # Floor the coarse predictions to get integers and save the fractional/decimal + track_int = coarse_pred.floor().int() + track_frac = coarse_pred - track_int + + # Note the points represent the center of patches + # now we get the location of the top left corner of patches + # because the ouput of pytorch unfold are indexed by top left corner + topleft = track_int - pradius + topleft_BSN = topleft.clone() + + # clamp the values so that we will not go out of indexes + # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). 
+ # You need to seperately clamp x and y if H!=W + topleft = topleft.clamp(0, H - psize) + + # Reshape from BxSxNx2 -> (B*S)xNx2 + topleft = topleft.reshape(B * S, N, 2) + + # Prepare batches for indexing, shape: (B*S)xN + batch_indices = ( + torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) + ) + + # Extract image patches based on top left corners + # extracted_patches: (B*S) x N x C_in x Psize x Psize + extracted_patches = content_to_extract[ + batch_indices, :, topleft[..., 1], topleft[..., 0] + ] + + # Feed patches to fine fent for features + patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) + + C_out = patch_feat.shape[1] + + # Refine the coarse tracks by fine_tracker + + # reshape back to B x S x N x C_out x Psize x Psize + patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) + patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") + + # Prepare for the query points for fine tracker + # They are relative to the patch left top corner, + # instead of the image top left corner now + # patch_query_points: N x 1 x 2 + # only 1 here because for each patch we only have 1 query point + patch_query_points = track_frac[:, 0] + pradius + patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) + + # Feed the PATCH query points and tracks into fine tracker + fine_pred_track_lists, _, _, query_point_feat = fine_tracker( + query_points=patch_query_points, + fmaps=patch_feat, + iters=fine_iters, + return_feat=True, + ) + + # relative the patch top left + fine_pred_track = fine_pred_track_lists[-1].clone() + + # From (relative to the patch top left) to (relative to the image top left) + for idx in range(len(fine_pred_track_lists)): + fine_level = rearrange( + fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N + ) + fine_level = fine_level.squeeze(-2) + fine_level = fine_level + topleft_BSN + fine_pred_track_lists[idx] = fine_level + + # relative to the image top left + refined_tracks = fine_pred_track_lists[-1].clone() + refined_tracks[:, 0] = query_points + + score = None + + if compute_score: + score = compute_score_fn( + query_point_feat, + patch_feat, + fine_pred_track, + sradius, + psize, + B, + N, + S, + C_out, + ) + + return refined_tracks, score + + +################################## NOTE: NOT USED ################################## + + +def compute_score_fn( + query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out +): + """ + Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps, + given the query point features and reference frame feature maps + """ + + from kornia.geometry.subpix import dsnt + from kornia.utils.grid import create_meshgrid + + # query_point_feat initial shape: B x N x C_out, + # query_point_feat indicates the feat at the coorponsing query points + # Therefore we don't have S dimension here + query_point_feat = query_point_feat.reshape(B, N, C_out) + # reshape and expand to B x (S-1) x N x C_out + query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1) + # and reshape to (B*(S-1)*N) x C_out + query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out) + + # Radius and size for computing the score + ssize = sradius * 2 + 1 + + # Reshape, you know it, so many reshaping operations + patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N) + + # Again, we unfold the patches to smaller patches + # so that we can then focus on smaller patches + # patch_feat_unfold shape: + # B x S x N 
x C_out x (psize - 2*sradius) x (psize - 2*sradius) x ssize x ssize + # well a bit scary, but actually not + patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1) + + # Do the same stuffs above, i.e., the same as extracting patches + fine_prediction_floor = fine_pred_track.floor().int() + fine_level_floor_topleft = fine_prediction_floor - sradius + + # Clamp to ensure the smaller patch is valid + fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize) + fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2) + + # Prepare the batch indices and xy locations + + batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN + batch_indices_score = batch_indices_score.reshape(-1).to( + patch_feat_unfold.device + ) # B*S*N + y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices + x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices + + reference_frame_feat = patch_feat_unfold.reshape( + B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize + ) + + # Note again, according to pytorch convention + # x_indices cooresponds to [..., 1] and y_indices cooresponds to [..., 0] + reference_frame_feat = reference_frame_feat[ + batch_indices_score, :, x_indices, y_indices + ] + reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize) + # pick the frames other than the first one, so we have S-1 frames here + reference_frame_feat = reference_frame_feat[:, 1:].reshape( + B * (S - 1) * N, C_out, ssize * ssize + ) + + # Compute similarity + sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat) + softmax_temp = 1.0 / C_out**0.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1) + # 2D heatmaps + heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize + + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] + grid_normalized = create_meshgrid( + ssize, ssize, normalized_coordinates=True, device=heatmap.device + ).reshape(1, -1, 2) + + var = ( + torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1) + - coords_normalized**2 + ) + std = torch.sum( + torch.sqrt(torch.clamp(var, min=1e-10)), -1 + ) # clamp needed for numerical stability + + score = std.reshape(B, S - 1, N) + # set score as 1 for the query frame + score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1) + + return score + + +def extract_glimpse( + tensor: torch.Tensor, + size: Tuple[int, int], + offsets, + mode="bilinear", + padding_mode="zeros", + debug=False, + orib=None, +): + B, C, W, H = tensor.shape + + h, w = size + xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0 + ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0 + + vy, vx = torch.meshgrid(ys, xs) + grid = torch.stack([vx, vy], dim=-1) # h, w, 2 + grid = grid[None] + + B, N, _ = offsets.shape + + offsets = offsets.reshape((B * N), 1, 1, 2) + offsets_grid = offsets + grid + + # normalised grid to [-1, 1] + offsets_grid = ( + offsets_grid - offsets_grid.new_tensor([W / 2, H / 2]) + ) / offsets_grid.new_tensor([W / 2, H / 2]) + + # BxCxHxW -> Bx1xCxHxW + tensor = tensor[:, None] + + # Bx1xCxHxW -> BxNxCxHxW + tensor = tensor.expand(-1, N, -1, -1, -1) + + # BxNxCxHxW -> (B*N)xCxHxW + tensor = tensor.reshape((B * N), C, W, H) + + sampled = torch.nn.functional.grid_sample( + tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode + ) + + # NOTE: I am not sure it 
should be h, w or w, h here + # but okay for sqaures + sampled = sampled.reshape(B, N, C, h, w) + + return sampled diff --git a/mapanything/third_party/track_modules/utils.py b/mapanything/third_party/track_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a69bdd27fabb9279c26ee6bd95fa2b6aa1137390 --- /dev/null +++ b/mapanything/third_party/track_modules/utils.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Modified from https://github.com/facebookresearch/vggt + + +from typing import Tuple, Union + +import torch +import torch.nn.functional as F + + +def get_2d_sincos_pos_embed( + embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False +) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return ( + pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), + grid, + ) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: torch.Tensor +) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: torch.Tensor +) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. 
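+
+    The returned embedding has size 2*C along the last dimension (sin/cos for x
+    and y), or 2*C + 2 when cat_coords is True and the raw xy coordinates are
+    prepended.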
+ + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. + """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = ( + torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) + ).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. 
+ """ + coords = coords.detach().clone() + ############################################################ + # IMPORTANT: + coords = coords.to(input.device).to(input.dtype) + ############################################################ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + scale = torch.tensor( + [2 / max(size - 1, 1) for size in reversed(sizes)], + device=coords.device, + dtype=coords.dtype, + ) + else: + scale = torch.tensor( + [2 / size for size in reversed(sizes)], + device=coords.device, + dtype=coords.dtype, + ) + + coords.mul_(scale) # coords = coords * scale + coords.sub_(1) # coords = coords - 1 + + return F.grid_sample( + input, coords, align_corners=align_corners, padding_mode=padding_mode + ) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view( + B, -1, feats.shape[1] * feats.shape[3] + ) # B C R 1 -> B R C diff --git a/mapanything/third_party/track_predict.py b/mapanything/third_party/track_predict.py new file mode 100644 index 0000000000000000000000000000000000000000..0054ec0a143452877af498e543966840241d7c58 --- /dev/null +++ b/mapanything/third_party/track_predict.py @@ -0,0 +1,353 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Modified from https://github.com/facebookresearch/vggt + +import numpy as np +import torch + +from .vggsfm_utils import ( + build_vggsfm_tracker, + calculate_index_mappings, + extract_keypoints, + generate_rank_by_dino, + initialize_feature_extractors, + predict_tracks_in_chunks, + switch_tensor_order, +) + + +def predict_tracks( + images, + conf=None, + points_3d=None, + max_query_pts=2048, + query_frame_num=5, + keypoint_extractor="aliked+sp", + max_points_num=163840, + fine_tracking=True, + complete_non_vis=True, +): + """ + Predict tracks for the given images and masks. + + TODO: support non-square images + TODO: support masks + + + This function predicts the tracks for the given images and masks using the specified query method + and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames. + + Args: + images: Tensor of shape [S, 3, H, W] containing the input images. + conf: Tensor of shape [S, 1, H, W] containing the confidence scores. Default is None. + points_3d: Tensor containing 3D points. Default is None. + max_query_pts: Maximum number of query points. Default is 2048. + query_frame_num: Number of query frames to use. Default is 5. 
+ keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp". + max_points_num: Maximum number of points to process at once. Default is 163840. + fine_tracking: Whether to use fine tracking. Default is True. + complete_non_vis: Whether to augment non-visible frames. Default is True. + + Returns: + pred_tracks: Numpy array containing the predicted tracks. + pred_vis_scores: Numpy array containing the visibility scores for the tracks. + pred_confs: Numpy array containing the confidence scores for the tracks. + pred_points_3d: Numpy array containing the 3D points for the tracks. + pred_colors: Numpy array containing the point colors for the tracks. (0, 255) + """ + + device = images.device + dtype = images.dtype + tracker = build_vggsfm_tracker().to(device, dtype) + + # Find query frames + query_frame_indexes = generate_rank_by_dino( + images, query_frame_num=query_frame_num, device=device + ) + + # Add the first image to the front if not already present + if 0 in query_frame_indexes: + query_frame_indexes.remove(0) + query_frame_indexes = [0, *query_frame_indexes] + + # TODO: add the functionality to handle the masks + keypoint_extractors = initialize_feature_extractors( + max_query_pts, extractor_method=keypoint_extractor, device=device + ) + + pred_tracks = [] + pred_vis_scores = [] + pred_confs = [] + pred_points_3d = [] + pred_colors = [] + + fmaps_for_tracker = tracker.process_images_to_fmaps(images) + + if fine_tracking: + print("For faster inference, consider disabling fine_tracking") + + for query_index in query_frame_indexes: + print(f"Predicting tracks for query frame {query_index}") + pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + device, + ) + + pred_tracks.append(pred_track) + pred_vis_scores.append(pred_vis) + pred_confs.append(pred_conf) + pred_points_3d.append(pred_point_3d) + pred_colors.append(pred_color) + + if complete_non_vis: + pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors = ( + _augment_non_visible_frames( + pred_tracks, + pred_vis_scores, + pred_confs, + pred_points_3d, + pred_colors, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + min_vis=500, + non_vis_thresh=0.1, + device=device, + ) + ) + + pred_tracks = np.concatenate(pred_tracks, axis=1) + pred_vis_scores = np.concatenate(pred_vis_scores, axis=1) + pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs else None + pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d else None + pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None + + # from vggt.utils.visual_track import visualize_tracks_on_images + # visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals") + + return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors + + +def _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + device, +): + """ + Process a single query frame for track prediction. 
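+
+    Keypoints are extracted on the query frame, optionally filtered using the
+    provided confidence / 3D point maps, and then tracked across all frames after
+    reordering so that the query frame sits at index 0; query points are chunked
+    when needed to bound peak GPU memory.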
+ + Args: + query_index: Index of the query frame + images: Tensor of shape [S, 3, H, W] containing the input images + conf: Confidence tensor + points_3d: 3D points tensor + fmaps_for_tracker: Feature maps for the tracker + keypoint_extractors: Initialized feature extractors + tracker: VGG-SFM tracker + max_points_num: Maximum number of points to process at once + fine_tracking: Whether to use fine tracking + device: Device to use for computation + + Returns: + pred_track: Predicted tracks + pred_vis: Visibility scores for the tracks + pred_conf: Confidence scores for the tracks + pred_point_3d: 3D points for the tracks + pred_color: Point colors for the tracks (0, 255) + """ + frame_num, _, height, width = images.shape + + query_image = images[query_index] + query_points = extract_keypoints( + query_image, keypoint_extractors, round_keypoints=False + ) + query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)] + + # Extract the color at the keypoint locations + query_points_long = query_points.squeeze(0).round().long() + pred_color = images[query_index][ + :, query_points_long[:, 1], query_points_long[:, 0] + ] + pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8) + + # Query the confidence and points_3d at the keypoint locations + if (conf is not None) and (points_3d is not None): + assert height == width + assert conf.shape[-2] == conf.shape[-1] + assert conf.shape[:3] == points_3d.shape[:3] + scale = conf.shape[-1] / width + + query_points_scaled = (query_points.squeeze(0) * scale).round().long() + query_points_scaled = query_points_scaled.cpu().numpy() + + pred_conf = conf[query_index][ + query_points_scaled[:, 1], query_points_scaled[:, 0] + ] + pred_point_3d = points_3d[query_index][ + query_points_scaled[:, 1], query_points_scaled[:, 0] + ] + + # heuristic to remove low confidence points + # should I export this as an input parameter? 
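+        # NOTE: the confidence cutoff (1.2) and the minimum number of surviving
+        # points (512) below are hard-coded heuristics; when too few points pass
+        # the cutoff, the per-track confidences and 3D points are simply not returned.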
+ valid_mask = pred_conf > 1.2 + if valid_mask.sum() > 512: + query_points = query_points[:, valid_mask] # Make sure shape is compatible + pred_conf = pred_conf[valid_mask] + pred_point_3d = pred_point_3d[valid_mask] + pred_color = pred_color[valid_mask] + else: + pred_conf = None + pred_point_3d = None + + reorder_index = calculate_index_mappings(query_index, frame_num, device=device) + + images_feed, fmaps_feed = switch_tensor_order( + [images, fmaps_for_tracker], reorder_index, dim=0 + ) + images_feed = images_feed[None] # add batch dimension + fmaps_feed = fmaps_feed[None] # add batch dimension + + all_points_num = images_feed.shape[1] * query_points.shape[1] + + # Don't need to be scared, this is just chunking to make GPU happy + if all_points_num > max_points_num: + num_splits = (all_points_num + max_points_num - 1) // max_points_num + query_points = torch.chunk(query_points, num_splits, dim=1) + else: + query_points = [query_points] + + pred_track, pred_vis, _ = predict_tracks_in_chunks( + tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking + ) + + pred_track, pred_vis = switch_tensor_order( + [pred_track, pred_vis], reorder_index, dim=1 + ) + + pred_track = pred_track.squeeze(0).float().cpu().numpy() + pred_vis = pred_vis.squeeze(0).float().cpu().numpy() + + return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color + + +def _augment_non_visible_frames( + pred_tracks: list, # ← running list of np.ndarrays + pred_vis_scores: list, # ← running list of np.ndarrays + pred_confs: list, # ← running list of np.ndarrays for confidence scores + pred_points_3d: list, # ← running list of np.ndarrays for 3D points + pred_colors: list, # ← running list of np.ndarrays for colors + images: torch.Tensor, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num: int, + fine_tracking: bool, + *, + min_vis: int = 500, + non_vis_thresh: float = 0.1, + device: torch.device = None, +): + """ + Augment tracking for frames with insufficient visibility. + + Args: + pred_tracks: List of numpy arrays containing predicted tracks. + pred_vis_scores: List of numpy arrays containing visibility scores. + pred_confs: List of numpy arrays containing confidence scores. + pred_points_3d: List of numpy arrays containing 3D points. + pred_colors: List of numpy arrays containing point colors. + images: Tensor of shape [S, 3, H, W] containing the input images. + conf: Tensor of shape [S, 1, H, W] containing confidence scores + points_3d: Tensor containing 3D points + fmaps_for_tracker: Feature maps for the tracker + keypoint_extractors: Initialized feature extractors + tracker: VGG-SFM tracker + max_points_num: Maximum number of points to process at once + fine_tracking: Whether to use fine tracking + min_vis: Minimum visibility threshold + non_vis_thresh: Non-visibility threshold + device: Device to use for computation + + Returns: + Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists. 
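+
+    Note:
+        Under-visible frames are re-processed one at a time, in frame-index order;
+        if the same frame stays below the visibility threshold twice in a row, one
+        final attempt is made over all remaining frames with a combined
+        "sp+sift+aliked" extractor set, after which augmentation stops.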
+ """ + last_query = -1 + final_trial = False + cur_extractors = keypoint_extractors # may be replaced on the final trial + + while True: + # Visibility per frame + vis_array = np.concatenate(pred_vis_scores, axis=1) + + # Count frames with sufficient visibility using numpy + sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1) + non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist() + + if len(non_vis_frames) == 0: + break + + print("Processing non visible frames:", non_vis_frames) + + # Decide the frames & extractor for this round + if non_vis_frames[0] == last_query: + # Same frame failed twice - final "all-in" attempt + final_trial = True + cur_extractors = initialize_feature_extractors( + 2048, extractor_method="sp+sift+aliked", device=device + ) + query_frame_list = non_vis_frames # blast them all at once + else: + query_frame_list = [non_vis_frames[0]] # Process one at a time + + last_query = non_vis_frames[0] + + # Run the tracker for every selected frame + for query_index in query_frame_list: + new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + cur_extractors, + tracker, + max_points_num, + fine_tracking, + device, + ) + pred_tracks.append(new_track) + pred_vis_scores.append(new_vis) + pred_confs.append(new_conf) + pred_points_3d.append(new_point_3d) + pred_colors.append(new_color) + + if final_trial: + break # Stop after final attempt + + return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors diff --git a/mapanything/third_party/vggsfm_tracker.py b/mapanything/third_party/vggsfm_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..360aec9a62a1ba44c8b7ddbcc8e8e78ea0a05e97 --- /dev/null +++ b/mapanything/third_party/vggsfm_tracker.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Modified from https://github.com/facebookresearch/vggt + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .track_modules.base_track_predictor import BaseTrackerPredictor +from .track_modules.blocks import BasicEncoder, ShallowEncoder +from .track_modules.track_refine import refine_track + + +class TrackerPredictor(nn.Module): + def __init__(self, **extra_args): + super(TrackerPredictor, self).__init__() + """ + Initializes the tracker predictor. 
+ + Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor, + check track_modules/base_track_predictor.py + + Both coarse_fnet and fine_fnet are constructed as a 2D CNN network + check track_modules/blocks.py for BasicEncoder and ShallowEncoder + """ + # Define coarse predictor configuration + coarse_stride = 4 + self.coarse_down_ratio = 2 + + # Create networks directly instead of using instantiate + self.coarse_fnet = BasicEncoder(stride=coarse_stride) + self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride) + + # Create fine predictor with stride = 1 + self.fine_fnet = ShallowEncoder(stride=1) + self.fine_predictor = BaseTrackerPredictor( + stride=1, + depth=4, + corr_levels=3, + corr_radius=3, + latent_dim=32, + hidden_size=256, + fine=True, + use_spaceatt=False, + ) + + def forward( + self, + images, + query_points, + fmaps=None, + coarse_iters=6, + inference=True, + fine_tracking=True, + fine_chunk=40960, + ): + """ + Args: + images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W. + query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2. + fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None. + coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6. + inference (bool, optional): Whether to perform inference. Defaults to True. + fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True. + + Returns: + tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score. + """ + + if fmaps is None: + batch_num, frame_num, image_dim, height, width = images.shape + reshaped_image = images.reshape( + batch_num * frame_num, image_dim, height, width + ) + fmaps = self.process_images_to_fmaps(reshaped_image) + fmaps = fmaps.reshape( + batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1] + ) + + if inference: + torch.cuda.empty_cache() + + # Coarse prediction + coarse_pred_track_lists, pred_vis = self.coarse_predictor( + query_points=query_points, + fmaps=fmaps, + iters=coarse_iters, + down_ratio=self.coarse_down_ratio, + ) + coarse_pred_track = coarse_pred_track_lists[-1] + + if inference: + torch.cuda.empty_cache() + + if fine_tracking: + # Refine the coarse prediction + fine_pred_track, pred_score = refine_track( + images, + self.fine_fnet, + self.fine_predictor, + coarse_pred_track, + compute_score=False, + chunk=fine_chunk, + ) + + if inference: + torch.cuda.empty_cache() + else: + fine_pred_track = coarse_pred_track + pred_score = torch.ones_like(pred_vis) + + return fine_pred_track, coarse_pred_track, pred_vis, pred_score + + def process_images_to_fmaps(self, images): + """ + This function processes images for inference. + + Args: + images (torch.Tensor): The images to be processed with shape S x 3 x H x W. + + Returns: + torch.Tensor: The processed feature maps. 
+ """ + if self.coarse_down_ratio > 1: + # whether or not scale down the input images to save memory + fmaps = self.coarse_fnet( + F.interpolate( + images, + scale_factor=1 / self.coarse_down_ratio, + mode="bilinear", + align_corners=True, + ) + ) + else: + fmaps = self.coarse_fnet(images) + + return fmaps diff --git a/mapanything/third_party/vggsfm_utils.py b/mapanything/third_party/vggsfm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c598ae40c41fb8ad10e14e401352d6342dbfd06 --- /dev/null +++ b/mapanything/third_party/vggsfm_utils.py @@ -0,0 +1,340 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Modified from https://github.com/facebookresearch/vggt + +import logging +import warnings + +import torch +import torch.nn.functional as F +from lightglue import ALIKED, SIFT, SuperPoint + +from .vggsfm_tracker import TrackerPredictor + +# Suppress verbose logging from dependencies +logging.getLogger("dinov2").setLevel(logging.WARNING) +warnings.filterwarnings("ignore", message="xFormers is available") +warnings.filterwarnings("ignore", message="dinov2") + +# Constants +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +def build_vggsfm_tracker(model_path=None): + """ + Build and initialize the VGGSfM tracker. + + Args: + model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace. + + Returns: + Initialized tracker model in eval mode. + """ + tracker = TrackerPredictor() + + if model_path is None: + default_url = ( + "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt" + ) + tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url)) + else: + tracker.load_state_dict(torch.load(model_path)) + + tracker.eval() + return tracker + + +def generate_rank_by_dino( + images, + query_frame_num, + image_size=336, + model_name="dinov2_vitb14_reg", + device="cuda", + spatial_similarity=False, +): + """ + Generate a ranking of frames using DINO ViT features. 
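+
+    Frames are scored by pairwise DINOv2 feature similarity; the frame most
+    similar to all others is used as the starting point, and the remaining query
+    frames are selected by farthest point sampling on the corresponding distance
+    matrix.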
+ + Args: + images: Tensor of shape (S, 3, H, W) with values in range [0, 1] + query_frame_num: Number of frames to select + image_size: Size to resize images to before processing + model_name: Name of the DINO model to use + device: Device to run the model on + spatial_similarity: Whether to use spatial token similarity or CLS token similarity + + Returns: + List of frame indices ranked by their representativeness + """ + # Resize images to the target size + images = F.interpolate( + images, (image_size, image_size), mode="bilinear", align_corners=False + ) + + # Load DINO model + dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name) + dino_v2_model.eval() + dino_v2_model = dino_v2_model.to(device) + + # Normalize images using ResNet normalization + resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1) + resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1) + images_resnet_norm = (images - resnet_mean) / resnet_std + + with torch.no_grad(): + frame_feat = dino_v2_model(images_resnet_norm, is_training=True) + + # Process features based on similarity type + if spatial_similarity: + frame_feat = frame_feat["x_norm_patchtokens"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + + # Compute the similarity matrix + frame_feat_norm = frame_feat_norm.permute(1, 0, 2) + similarity_matrix = torch.bmm( + frame_feat_norm, frame_feat_norm.transpose(-1, -2) + ) + similarity_matrix = similarity_matrix.mean(dim=0) + else: + frame_feat = frame_feat["x_norm_clstoken"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) + + distance_matrix = 100 - similarity_matrix.clone() + + # Ignore self-pairing + similarity_matrix.fill_diagonal_(-100) + similarity_sum = similarity_matrix.sum(dim=1) + + # Find the most common frame + most_common_frame_index = torch.argmax(similarity_sum).item() + + # Conduct FPS sampling starting from the most common frame + fps_idx = farthest_point_sampling( + distance_matrix, query_frame_num, most_common_frame_index + ) + + # Clean up all tensors and models to free memory + del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix + del dino_v2_model + torch.cuda.empty_cache() + + return fps_idx + + +def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0): + """ + Farthest point sampling algorithm to select diverse frames. 
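+
+    Starting from `most_common_frame_index`, each iteration picks the frame with
+    the largest distance to the most recently selected frame (already selected
+    frames are excluded), until `num_samples` frames are chosen or every frame
+    has been selected.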
+ + Args: + distance_matrix: Matrix of distances between frames + num_samples: Number of frames to select + most_common_frame_index: Index of the first frame to select + + Returns: + List of selected frame indices + """ + distance_matrix = distance_matrix.clamp(min=0) + N = distance_matrix.size(0) + + # Initialize with the most common frame + selected_indices = [most_common_frame_index] + check_distances = distance_matrix[selected_indices] + + while len(selected_indices) < num_samples: + # Find the farthest point from the current set of selected points + farthest_point = torch.argmax(check_distances) + selected_indices.append(farthest_point.item()) + + check_distances = distance_matrix[farthest_point] + # Mark already selected points to avoid selecting them again + check_distances[selected_indices] = 0 + + # Break if all points have been selected + if len(selected_indices) == N: + break + + return selected_indices + + +def calculate_index_mappings(query_index, S, device=None): + """ + Construct an order that switches [query_index] and [0] + so that the content of query_index would be placed at [0]. + + Args: + query_index: Index to swap with 0 + S: Total number of elements + device: Device to place the tensor on + + Returns: + Tensor of indices with the swapped order + """ + new_order = torch.arange(S) + new_order[0] = query_index + new_order[query_index] = 0 + if device is not None: + new_order = new_order.to(device) + return new_order + + +def switch_tensor_order(tensors, order, dim=1): + """ + Reorder tensors along a specific dimension according to the given order. + + Args: + tensors: List of tensors to reorder + order: Tensor of indices specifying the new order + dim: Dimension along which to reorder + + Returns: + List of reordered tensors + """ + return [ + torch.index_select(tensor, dim, order) if tensor is not None else None + for tensor in tensors + ] + + +def initialize_feature_extractors( + max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda" +): + """ + Initialize feature extractors that can be reused based on a method string. + + Args: + max_query_num: Maximum number of keypoints to extract + det_thres: Detection threshold for keypoint extraction + extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift") + device: Device to run extraction on + + Returns: + Dictionary of initialized extractors + """ + extractors = {} + methods = extractor_method.lower().split("+") + + for method in methods: + method = method.strip() + if method == "aliked": + aliked_extractor = ALIKED( + max_num_keypoints=max_query_num, detection_threshold=det_thres + ) + extractors["aliked"] = aliked_extractor.to(device).eval() + elif method == "sp": + sp_extractor = SuperPoint( + max_num_keypoints=max_query_num, detection_threshold=det_thres + ) + extractors["sp"] = sp_extractor.to(device).eval() + elif method == "sift": + sift_extractor = SIFT(max_num_keypoints=max_query_num) + extractors["sift"] = sift_extractor.to(device).eval() + else: + print(f"Warning: Unknown feature extractor '{method}', ignoring.") + + if not extractors: + print( + f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default." 
+ ) + aliked_extractor = ALIKED( + max_num_keypoints=max_query_num, detection_threshold=det_thres + ) + extractors["aliked"] = aliked_extractor.to(device).eval() + + return extractors + + +def extract_keypoints(query_image, extractors, round_keypoints=True): + """ + Extract keypoints using pre-initialized feature extractors. + + Args: + query_image: Input image tensor (3xHxW, range [0, 1]) + extractors: Dictionary of initialized extractors + + Returns: + Tensor of keypoint coordinates (1xNx2) + """ + query_points = None + + with torch.no_grad(): + for extractor_name, extractor in extractors.items(): + query_points_data = extractor.extract(query_image, invalid_mask=None) + extractor_points = query_points_data["keypoints"] + if round_keypoints: + extractor_points = extractor_points.round() + + if query_points is not None: + query_points = torch.cat([query_points, extractor_points], dim=1) + else: + query_points = extractor_points + + return query_points + + +def predict_tracks_in_chunks( + track_predictor, + images_feed, + query_points_list, + fmaps_feed, + fine_tracking, + num_splits=None, + fine_chunk=40960, +): + """ + Process a list of query points to avoid memory issues. + + Args: + track_predictor (object): The track predictor object used for predicting tracks. + images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images. + query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points. + fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker. + fine_tracking (bool): Whether to perform fine tracking. + num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility. + + Returns: + tuple: A tuple containing the concatenated predicted tracks, visibility, and scores. 
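+
+    Note:
+        The per-chunk predictions are concatenated back along the track dimension
+        (dim=2) before being returned.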
+ """ + # If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility + if not isinstance(query_points_list, (list, tuple)): + query_points = query_points_list + if num_splits is None: + num_splits = 1 + query_points_list = torch.chunk(query_points, num_splits, dim=1) + + # Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple) + if isinstance(query_points_list, tuple): + query_points_list = list(query_points_list) + + fine_pred_track_list = [] + pred_vis_list = [] + pred_score_list = [] + + for split_points in query_points_list: + # Feed into track predictor for each split + fine_pred_track, _, pred_vis, pred_score = track_predictor( + images_feed, + split_points, + fmaps=fmaps_feed, + fine_tracking=fine_tracking, + fine_chunk=fine_chunk, + ) + fine_pred_track_list.append(fine_pred_track) + pred_vis_list.append(pred_vis) + pred_score_list.append(pred_score) + + # Concatenate the results from all splits + fine_pred_track = torch.cat(fine_pred_track_list, dim=2) + pred_vis = torch.cat(pred_vis_list, dim=2) + + if pred_score is not None: + pred_score = torch.cat(pred_score_list, dim=2) + else: + pred_score = None + + return fine_pred_track, pred_vis, pred_score diff --git a/mapanything/train/__init__.py b/mapanything/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/train/losses.py b/mapanything/train/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..b5203bc7531e8348b4de2fbffc7286c04b643861 --- /dev/null +++ b/mapanything/train/losses.py @@ -0,0 +1,5068 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Multi-view geometric losses for training 3D reconstruction models. + +References: DUSt3R & MASt3R +""" + +import math +from copy import copy, deepcopy + +import einops as ein +import torch +import torch.nn as nn + +from mapanything.utils.geometry import ( + angle_diff_vec3, + apply_log_to_norm, + closed_form_pose_inverse, + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + geotrf, + normalize_multiple_pointclouds, + quaternion_inverse, + quaternion_multiply, + quaternion_to_rotation_matrix, + transform_pose_using_quats_and_trans_2_to_1, +) + + +def get_loss_terms_and_details( + losses_dict, valid_masks, self_name, n_views, flatten_across_image_only +): + """ + Helper function to generate loss terms and details for different loss types. + + Args: + losses_dict (dict): Dictionary mapping loss types to their values. + Format: { + 'loss_type': { + 'values': list_of_loss_tensors or single_tensor, + 'use_mask': bool, + 'is_multi_view': bool + } + } + valid_masks (list): List of valid masks for each view. + self_name (str): Name of the loss class. + n_views (int): Number of views. + flatten_across_image_only (bool): Whether flattening was done across image only. + + Returns: + tuple: (loss_terms, details) where loss_terms is a list of tuples (loss, mask, type) + and details is a dictionary of loss details. 
+ """ + loss_terms = [] + details = {} + + for loss_type, loss_info in losses_dict.items(): + values = loss_info["values"] + use_mask = loss_info["use_mask"] + is_multi_view = loss_info["is_multi_view"] + if is_multi_view: + # Handle multi-view losses (list of tensors) + view_loss_details = [] + for i in range(n_views): + mask = valid_masks[i] if use_mask else None + loss_terms.append((values[i], mask, loss_type)) + + # Add details for individual view + if not flatten_across_image_only or not use_mask: + values_after_masking = values[i] + else: + values_after_masking = values[i][mask] + + if values_after_masking.numel() > 0: + view_loss_detail = float(values_after_masking.mean()) + if view_loss_detail > 0: + details[f"{self_name}_{loss_type}_view{i + 1}"] = ( + view_loss_detail + ) + view_loss_details.append(view_loss_detail) + # Add average across views + if len(view_loss_details) > 0: + details[f"{self_name}_{loss_type}_avg"] = sum(view_loss_details) / len( + view_loss_details + ) + else: + # Handle single tensor losses + if values is not None: + loss_terms.append((values, None, loss_type)) + if values.numel() > 0: + loss_detail = float(values.mean()) + if loss_detail > 0: + details[f"{self_name}_{loss_type}"] = loss_detail + + return loss_terms, details + + +def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor: + if beta == 0: + return err + else: + return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta) + + +def compute_normal_loss(points, gt_points, mask): + """ + Compute the normal loss between the predicted and ground truth points. + References: + https://github.com/microsoft/MoGe/blob/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/moge/train/losses.py#L205 + + Args: + points (torch.Tensor): Predicted points. Shape: (..., H, W, 3). + gt_points (torch.Tensor): Ground truth points. Shape: (..., H, W, 3). + mask (torch.Tensor): Mask indicating valid points. Shape: (..., H, W). + + Returns: + torch.Tensor: Normal loss. 
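+
+    Note:
+        Normals are approximated by the four cross products of edge vectors in
+        each 2x2 pixel neighborhood; the angular error w.r.t. the ground-truth
+        normals is clamped to [1, 90] degrees and passed through a smooth-L1
+        penalty with a beta of 3 degrees.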
+ """ + height, width = points.shape[-3:-1] + + leftup, rightup, leftdown, rightdown = ( + points[..., :-1, :-1, :], + points[..., :-1, 1:, :], + points[..., 1:, :-1, :], + points[..., 1:, 1:, :], + ) + upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1) + leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1) + downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1) + rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1) + + gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = ( + gt_points[..., :-1, :-1, :], + gt_points[..., :-1, 1:, :], + gt_points[..., 1:, :-1, :], + gt_points[..., 1:, 1:, :], + ) + gt_upxleft = torch.cross( + gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1 + ) + gt_leftxdown = torch.cross( + gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1 + ) + gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1) + gt_rightxup = torch.cross( + gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1 + ) + + mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = ( + mask[..., :-1, :-1], + mask[..., :-1, 1:], + mask[..., 1:, :-1], + mask[..., 1:, 1:], + ) + mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown + mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup + mask_downxright = mask_leftdown & mask_rightup & mask_leftup + mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown + + MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3) + + loss = ( + mask_upxleft + * _smooth( + angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE), + beta=BETA_RAD, + ) + + mask_leftxdown + * _smooth( + angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE), + beta=BETA_RAD, + ) + + mask_downxright + * _smooth( + angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE), + beta=BETA_RAD, + ) + + mask_rightxup + * _smooth( + angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE), + beta=BETA_RAD, + ) + ) + + total_valid_mask = mask_upxleft | mask_leftxdown | mask_downxright | mask_rightxup + valid_count = total_valid_mask.sum() + if valid_count > 0: + loss = loss.sum() / (valid_count * (4 * max(points.shape[-3:-1]))) + else: + loss = 0 * loss.sum() + + return loss + + +def compute_gradient_loss(prediction, gt_target, mask): + """ + Compute the gradient loss between the prediction and GT target at valid points. + References: + https://docs.nerf.studio/_modules/nerfstudio/model_components/losses.html#GradientLoss + https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py + + Args: + prediction (torch.Tensor): Predicted scene representation. Shape: (B, H, W, C). + gt_target (torch.Tensor): Ground truth scene representation. Shape: (B, H, W, C). + mask (torch.Tensor): Mask indicating valid points. Shape: (B, H, W). 
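+
+    Returns:
+        torch.Tensor: Scalar gradient loss, normalized by the number of valid
+        pixels (zero if no pixels are valid).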
+ """ + # Expand mask to match number of channels in prediction + mask = mask[..., None].expand(-1, -1, -1, prediction.shape[-1]) + summed_mask = torch.sum(mask, (1, 2, 3)) + + # Compute the gradient of the prediction and GT target + diff = prediction - gt_target + diff = torch.mul(mask, diff) + + # Gradient in x direction + grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) + mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) + grad_x = torch.mul(mask_x, grad_x) + + # Gradient in y direction + grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) + mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) + grad_y = torch.mul(mask_y, grad_y) + + # Clamp the outlier gradients + grad_x = grad_x.clamp(max=100) + grad_y = grad_y.clamp(max=100) + + # Compute the total loss + image_loss = torch.sum(grad_x, (1, 2, 3)) + torch.sum(grad_y, (1, 2, 3)) + num_valid_pixels = torch.sum(summed_mask) + if num_valid_pixels > 0: + image_loss = torch.sum(image_loss) / num_valid_pixels + else: + image_loss = 0 * torch.sum(image_loss) + + return image_loss + + +def compute_gradient_matching_loss(prediction, gt_target, mask, scales=4): + """ + Compute the multi-scale gradient matching loss between the prediction and GT target at valid points. + This loss biases discontinuities to be sharp and to coincide with discontinuities in the ground truth. + More info in MiDAS: https://arxiv.org/pdf/1907.01341.pdf; Equation 11 + References: + https://docs.nerf.studio/_modules/nerfstudio/model_components/losses.html#GradientLoss + https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py + + Args: + prediction (torch.Tensor): Predicted scene representation. Shape: (B, H, W, C). + gt_target (torch.Tensor): Ground truth scene representation. Shape: (B, H, W, C). + mask (torch.Tensor): Mask indicating valid points. Shape: (B, H, W). + scales (int): Number of scales to compute the loss at. Default: 4. + """ + # Define total loss + total_loss = 0.0 + + # Compute the gradient loss at different scales + for scale in range(scales): + step = pow(2, scale) + grad_loss = compute_gradient_loss( + prediction[:, ::step, ::step], + gt_target[:, ::step, ::step], + mask[:, ::step, ::step], + ) + total_loss += grad_loss + + return total_loss + + +def Sum(*losses_and_masks): + """ + Aggregates multiple losses into a single loss value or returns the original losses. 
+ + Args: + *losses_and_masks: Variable number of tuples, each containing (loss, mask, rep_type) + - loss: Tensor containing loss values + - mask: Mask indicating valid pixels/regions + - rep_type: String indicating the type of representation (e.g., 'pts3d', 'depth') + + Returns: + If the first loss has dimensions > 0: + Returns the original list of (loss, mask, rep_type) tuples + Otherwise: + Returns a scalar tensor that is the sum of all loss values + """ + loss, mask, rep_type = losses_and_masks[0] + if loss.ndim > 0: + # we are actually returning the loss for every pixels + return losses_and_masks + else: + # we are returning the global loss + for loss2, mask2, rep_type2 in losses_and_masks[1:]: + loss = loss + loss2 + return loss + + +class BaseCriterion(nn.Module): + "Base Criterion to support different reduction methods" + + def __init__(self, reduction="mean"): + super().__init__() + self.reduction = reduction + + +class LLoss(BaseCriterion): + "L-norm loss" + + def forward(self, a, b, **kwargs): + assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 4, ( + f"Bad shape = {a.shape}" + ) + dist = self.distance(a, b, **kwargs) + assert dist.ndim == a.ndim - 1 # one dimension less + if self.reduction == "none": + return dist + if self.reduction == "sum": + return dist.sum() + if self.reduction == "mean": + return dist.mean() if dist.numel() > 0 else dist.new_zeros(()) + raise ValueError(f"bad {self.reduction=} mode") + + def distance(self, a, b, **kwargs): + raise NotImplementedError() + + +class L1Loss(LLoss): + "L1 distance" + + def distance(self, a, b, **kwargs): + return torch.abs(a - b).sum(dim=-1) + + +class L2Loss(LLoss): + "Euclidean (L2 Norm) distance" + + def distance(self, a, b, **kwargs): + return torch.norm(a - b, dim=-1) + + +class GenericLLoss(LLoss): + "Criterion that supports different L-norms" + + def distance(self, a, b, loss_type, **kwargs): + if loss_type == "l1": + # L1 distance + return torch.abs(a - b).sum(dim=-1) + elif loss_type == "l2": + # Euclidean (L2 norm) distance + return torch.norm(a - b, dim=-1) + else: + raise ValueError( + f"Unsupported loss type: {loss_type}. Supported types are 'l1' and 'l2'." 
+ ) + + +class FactoredLLoss(LLoss): + "Criterion that supports different L-norms for the factored loss functions" + + def __init__( + self, + reduction="mean", + points_loss_type="l2", + depth_loss_type="l1", + ray_directions_loss_type="l1", + pose_quats_loss_type="l1", + pose_trans_loss_type="l1", + scale_loss_type="l1", + ): + super().__init__(reduction) + self.points_loss_type = points_loss_type + self.depth_loss_type = depth_loss_type + self.ray_directions_loss_type = ray_directions_loss_type + self.pose_quats_loss_type = pose_quats_loss_type + self.pose_trans_loss_type = pose_trans_loss_type + self.scale_loss_type = scale_loss_type + + def _distance(self, a, b, loss_type): + if loss_type == "l1": + # L1 distance + return torch.abs(a - b).sum(dim=-1) + elif loss_type == "l2": + # Euclidean (L2 norm) distance + return torch.norm(a - b, dim=-1) + else: + raise ValueError(f"Unsupported loss type: {loss_type}.") + + def distance(self, a, b, factor, **kwargs): + if factor == "points": + return self._distance(a, b, self.points_loss_type) + elif factor == "depth": + return self._distance(a, b, self.depth_loss_type) + elif factor == "ray_directions": + return self._distance(a, b, self.ray_directions_loss_type) + elif factor == "pose_quats": + return self._distance(a, b, self.pose_quats_loss_type) + elif factor == "pose_trans": + return self._distance(a, b, self.pose_trans_loss_type) + elif factor == "scale": + return self._distance(a, b, self.scale_loss_type) + else: + raise ValueError(f"Unsupported factor type: {factor}.") + + +class RobustRegressionLoss(LLoss): + """ + Generalized Robust Loss introduced in https://arxiv.org/abs/1701.03077. + """ + + def __init__(self, alpha=0.5, scaling_c=0.25, reduction="mean"): + """ + Initialize the Robust Regression Loss. + + Args: + alpha (float): Shape parameter controlling the robustness of the loss. + Lower values make the loss more robust to outliers. Default: 0.5. + scaling_c (float): Scale parameter controlling the transition between + quadratic and robust behavior. Default: 0.1. + reduction (str): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + super().__init__(reduction) + self.alpha = alpha + self.scaling_c = scaling_c + + def distance(self, a, b, **kwargs): + error_scaled = torch.sum(((a - b) / self.scaling_c) ** 2, dim=-1) + robust_loss = (abs(self.alpha - 2) / self.alpha) * ( + torch.pow((error_scaled / abs(self.alpha - 2)) + 1, self.alpha / 2) - 1 + ) + return robust_loss + + +class BCELoss(BaseCriterion): + """Binary Cross Entropy loss""" + + def forward(self, predicted_logits, reference_mask): + """ + Args: + predicted_logits: (B, H, W) tensor of predicted logits for the mask + reference_mask: (B, H, W) tensor of reference mask + + Returns: + loss: scalar tensor of the BCE loss + """ + bce_loss = torch.nn.functional.binary_cross_entropy_with_logits( + predicted_logits, reference_mask.float() + ) + + return bce_loss + + +class Criterion(nn.Module): + """ + Base class for all criterion modules that wrap a BaseCriterion. + + This class serves as a wrapper around BaseCriterion objects, providing + additional functionality like naming and reduction mode control. + + Args: + criterion (BaseCriterion): The base criterion to wrap. + """ + + def __init__(self, criterion=None): + super().__init__() + assert isinstance(criterion, BaseCriterion), ( + f"{criterion} is not a proper criterion!" 
+ ) + self.criterion = copy(criterion) + + def get_name(self): + """ + Returns a string representation of this criterion. + + Returns: + str: A string containing the class name and the wrapped criterion. + """ + return f"{type(self).__name__}({self.criterion})" + + def with_reduction(self, mode="none"): + """ + Creates a deep copy of this criterion with the specified reduction mode. + + This method recursively sets the reduction mode for this criterion and + any chained MultiLoss criteria. + + Args: + mode (str): The reduction mode to set. Default: "none". + + Returns: + Criterion: A new criterion with the specified reduction mode. + """ + res = loss = deepcopy(self) + while loss is not None: + assert isinstance(loss, Criterion) + loss.criterion.reduction = mode # make it return the loss for each sample + loss = loss._loss2 # we assume loss is a Multiloss + return res + + +class MultiLoss(nn.Module): + """ + Base class for combinable loss functions with automatic tracking of individual loss values. + + This class enables easy combination of multiple loss functions through arithmetic operations: + loss = MyLoss1() + 0.1*MyLoss2() + + The combined loss functions maintain their individual weights and the forward pass + automatically computes and aggregates all losses while tracking individual loss values. + + Usage: + Inherit from this class and override get_name() and compute_loss() methods. + + Attributes: + _alpha (float): Weight multiplier for this loss component. + _loss2 (MultiLoss): Reference to the next loss in the chain, if any. + """ + + def __init__(self): + """Initialize the MultiLoss with default weight of 1 and no chained loss.""" + super().__init__() + self._alpha = 1 + self._loss2 = None + + def compute_loss(self, *args, **kwargs): + """ + Compute the loss value for this specific loss component. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + torch.Tensor or tuple: Either the loss tensor or a tuple of (loss, details_dict). + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError() + + def get_name(self): + """ + Get the name of this loss component. + + Returns: + str: The name of the loss. + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError() + + def __mul__(self, alpha): + """ + Multiply the loss by a scalar weight. + + Args: + alpha (int or float): The weight to multiply the loss by. + + Returns: + MultiLoss: A new loss object with the updated weight. + + Raises: + AssertionError: If alpha is not a number. + """ + assert isinstance(alpha, (int, float)) + res = copy(self) + res._alpha = alpha + return res + + __rmul__ = __mul__ # Support both loss*alpha and alpha*loss + + def __add__(self, loss2): + """ + Add another loss to this loss, creating a chain of losses. + + Args: + loss2 (MultiLoss): Another loss to add to this one. + + Returns: + MultiLoss: A new loss object representing the combined losses. + + Raises: + AssertionError: If loss2 is not a MultiLoss. + """ + assert isinstance(loss2, MultiLoss) + res = cur = copy(self) + # Find the end of the chain + while cur._loss2 is not None: + cur = cur._loss2 + cur._loss2 = loss2 + return res + + def __repr__(self): + """ + Create a string representation of the loss, including weights and chained losses. + + Returns: + str: String representation of the loss. 
+ """ + name = self.get_name() + if self._alpha != 1: + name = f"{self._alpha:g}*{name}" + if self._loss2: + name = f"{name} + {self._loss2}" + return name + + def forward(self, *args, **kwargs): + """ + Compute the weighted loss and aggregate with any chained losses. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + tuple: A tuple containing: + - torch.Tensor: The total weighted loss. + - dict: Details about individual loss components. + """ + loss = self.compute_loss(*args, **kwargs) + if isinstance(loss, tuple): + loss, details = loss + elif loss.ndim == 0: + details = {self.get_name(): float(loss)} + else: + details = {} + loss = loss * self._alpha + + if self._loss2: + loss2, details2 = self._loss2(*args, **kwargs) + loss = loss + loss2 + details |= details2 + + return loss, details + + +class NonAmbiguousMaskLoss(Criterion, MultiLoss): + """ + Loss on non-ambiguous mask prediction logits. + """ + + def __init__(self, criterion): + super().__init__(criterion) + + def compute_loss(self, batch, preds, **kw): + """ + Args: + batch: list of dicts with the gt data + preds: list of dicts with the predictions + + Returns: + loss: Sum class of the lossses for N-views and the loss details + """ + # Init loss list to keep track of individual losses for each view + loss_list = [] + mask_loss_details = {} + mask_loss_total = 0 + self_name = type(self).__name__ + + # Loop over the views + for view_idx, (gt, pred) in enumerate(zip(batch, preds)): + # Get the GT non-ambiguous masks + gt_non_ambiguous_mask = gt["non_ambiguous_mask"] + + # Get the predicted non-ambiguous mask logits + pred_non_ambiguous_mask_logits = pred["non_ambiguous_mask_logits"] + + # Compute the loss for the current view + loss = self.criterion(pred_non_ambiguous_mask_logits, gt_non_ambiguous_mask) + + # Add the loss to the list + loss_list.append((loss, None, "non_ambiguous_mask")) + + # Add the loss details to the dictionary + mask_loss_details[f"{self_name}_mask_view{view_idx + 1}"] = float(loss) + mask_loss_total += float(loss) + + # Compute the average loss across all views + mask_loss_details[f"{self_name}_mask_avg"] = mask_loss_total / len(batch) + + return Sum(*loss_list), (mask_loss_details | {}) + + +class ConfLoss(MultiLoss): + """ + Applies confidence-weighted regression loss using model-predicted confidence values. + + The confidence-weighted loss has the form: + conf_loss = raw_loss * conf - alpha * log(conf) + + Where: + - raw_loss is the original per-pixel loss + - conf is the predicted confidence (higher values = higher confidence) + - alpha is a hyperparameter controlling the regularization strength + + This loss can be selectively applied to specific loss components in factored and multi-view settings. + """ + + def __init__(self, pixel_loss, alpha=1, loss_set_indices=None): + """ + Args: + pixel_loss (MultiLoss): The pixel-level regression loss to be used. + alpha (float): Hyperparameter controlling the confidence regularization strength. + loss_set_indices (list or None): Indices of the loss sets to apply confidence weighting to. + Each index selects a specific loss set across all views (with the same rep_type). + If None, defaults to [0] which applies to the first loss set only. 
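+
+        Note:
+            The per-pixel losses are laid out as `n_views` consecutive entries per
+            loss set, so an index `i` here selects entries
+            `[i * n_views, (i + 1) * n_views)` of the flattened loss list.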
+ """ + super().__init__() + assert alpha > 0 + self.alpha = alpha + self.pixel_loss = pixel_loss.with_reduction("none") + self.loss_set_indices = [0] if loss_set_indices is None else loss_set_indices + + def get_name(self): + return f"ConfLoss({self.pixel_loss})" + + def get_conf_log(self, x): + return x, torch.log(x) + + def compute_loss(self, batch, preds, **kw): + # Init loss list and details + total_loss = 0 + conf_loss_details = {} + running_avg_dict = {} + self_name = type(self.pixel_loss).__name__ + n_views = len(batch) + + # Compute per-pixel loss for each view + losses, pixel_loss_details = self.pixel_loss(batch, preds, **kw) + + # Select specific loss sets based on indices + selected_losses = [] + processed_indices = set() + for idx in self.loss_set_indices: + start_idx = idx * n_views + end_idx = min((idx + 1) * n_views, len(losses)) + selected_losses.extend(losses[start_idx:end_idx]) + processed_indices.update(range(start_idx, end_idx)) + + # Process selected losses with confidence weighting + for loss_idx, (loss, msk, rep_type) in enumerate(selected_losses): + view_idx = loss_idx % n_views # Map to corresponding view index + + if loss.numel() == 0: + # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True) + continue + + # Get the confidence and log confidence + if ( + hasattr(self.pixel_loss, "flatten_across_image_only") + and self.pixel_loss.flatten_across_image_only + ): + # Reshape confidence to match the flattened dimensions + conf_reshaped = preds[view_idx]["conf"].view( + preds[view_idx]["conf"].shape[0], -1 + ) + conf, log_conf = self.get_conf_log(conf_reshaped[msk]) + loss = loss[msk] + else: + conf, log_conf = self.get_conf_log(preds[view_idx]["conf"][msk]) + + # Weight the loss by the confidence + conf_loss = loss * conf - self.alpha * log_conf + + # Only add to total loss and store details if there are valid elements + if conf_loss.numel() > 0: + conf_loss = conf_loss.mean() + total_loss = total_loss + conf_loss + + # Store details + conf_loss_details[ + f"{self_name}_{rep_type}_conf_loss_view{view_idx + 1}" + ] = float(conf_loss) + + # Initialize or update running average directly + avg_key = f"{self_name}_{rep_type}_conf_loss_avg" + if avg_key not in conf_loss_details: + conf_loss_details[avg_key] = float(conf_loss) + running_avg_dict[ + f"{self_name}_{rep_type}_conf_loss_valid_views" + ] = 1 + else: + valid_views = ( + running_avg_dict[ + f"{self_name}_{rep_type}_conf_loss_valid_views" + ] + + 1 + ) + running_avg_dict[ + f"{self_name}_{rep_type}_conf_loss_valid_views" + ] = valid_views + conf_loss_details[avg_key] += ( + float(conf_loss) - conf_loss_details[avg_key] + ) / valid_views + + # Add unmodified losses for sets not in selected_losses + for idx, (loss, msk, rep_type) in enumerate(losses): + if idx not in processed_indices: + if msk is not None: + loss_after_masking = loss[msk] + else: + loss_after_masking = loss + if loss_after_masking.numel() > 0: + loss_mean = loss_after_masking.mean() + else: + # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True) + loss_mean = 0 + total_loss = total_loss + loss_mean + + return total_loss, dict(**conf_loss_details, **pixel_loss_details) + + +class ExcludeTopNPercentPixelLoss(MultiLoss): + """ + Pixel-level regression loss where for each instance in a batch the top N% of per-pixel loss values are ignored + for the mean loss computation. + Allows selecting which pixel-level regression loss sets to apply the exclusion to. 
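+    By default the exclusion is only applied to real-world samples
+    (`apply_to_real_data_only=True`); synthetic samples keep all of their valid
+    per-pixel loss values.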
+ """ + + def __init__( + self, + pixel_loss, + top_n_percent=5, + apply_to_real_data_only=True, + loss_set_indices=None, + ): + """ + Args: + pixel_loss (MultiLoss): The pixel-level regression loss to be used. + top_n_percent (float): The percentage of top per-pixel loss values to ignore. Range: [0, 100]. Default: 5. + apply_to_real_data_only (bool): Whether to apply the loss only to real world data. Default: True. + loss_set_indices (list or None): Indices of the loss sets to apply the exclusion to. + Each index selects a specific loss set across all views (with the same rep_type). + If None, defaults to [0] which applies to the first loss set only. + """ + super().__init__() + self.pixel_loss = pixel_loss.with_reduction("none") + self.top_n_percent = top_n_percent + self.bottom_n_percent = 100 - top_n_percent + self.apply_to_real_data_only = apply_to_real_data_only + self.loss_set_indices = [0] if loss_set_indices is None else loss_set_indices + + def get_name(self): + return f"ExcludeTopNPercentPixelLoss({self.pixel_loss})" + + def keep_bottom_n_percent(self, tensor, mask, bottom_n_percent): + """ + Function to compute the mask for keeping the bottom n percent of per-pixel loss values. + + Args: + tensor (torch.Tensor): The tensor containing the per-pixel loss values. + Shape: (B, N) where B is the batch size and N is the number of total pixels. + mask (torch.Tensor): The mask indicating valid pixels. Shape: (B, N). + + Returns: + torch.Tensor: Flattened tensor containing the bottom n percent of per-pixel loss values. + """ + B, N = tensor.shape + + # Calculate the number of valid elements (where mask is True) + num_valid = mask.sum(dim=1) + + # Calculate the number of elements to keep (n% of valid elements) + num_keep = (num_valid * bottom_n_percent / 100).long() + + # Create a mask for the bottom n% elements + keep_mask = torch.arange(N, device=tensor.device).unsqueeze( + 0 + ) < num_keep.unsqueeze(1) + + # Create a tensor with inf where mask is False + masked_tensor = torch.where( + mask, tensor, torch.tensor(float("inf"), device=tensor.device) + ) + + # Sort the masked tensor along the N dimension + sorted_tensor, _ = torch.sort(masked_tensor, dim=1, descending=False) + + # Get the bottom n% elements + bottom_n_percent_elements = sorted_tensor[keep_mask] + + return bottom_n_percent_elements + + def compute_loss(self, batch, preds, **kw): + # Compute per-pixel loss + losses, details = self.pixel_loss(batch, preds, **kw) + n_views = len(batch) + + # Select specific loss sets based on indices + selected_losses = [] + processed_indices = set() + for idx in self.loss_set_indices: + start_idx = idx * n_views + end_idx = min((idx + 1) * n_views, len(losses)) + selected_losses.extend(losses[start_idx:end_idx]) + processed_indices.update(range(start_idx, end_idx)) + + # Initialize total loss + total_loss = 0.0 + loss_details = {} + running_avg_dict = {} + self_name = type(self.pixel_loss).__name__ + + # Process selected losses with top N percent exclusion + for loss_idx, (loss, msk, rep_type) in enumerate(selected_losses): + view_idx = loss_idx % n_views # Map to corresponding view index + + if loss.numel() == 0: + # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True) + continue + + # Create empty list for current view's aggregated tensors + aggregated_losses = [] + + if self.apply_to_real_data_only: + # Get the synthetic and real world data mask + synthetic_mask = batch[view_idx]["is_synthetic"] + real_data_mask = 
~batch[view_idx]["is_synthetic"] + else: + # Apply the filtering to all data + synthetic_mask = torch.zeros_like(batch[view_idx]["is_synthetic"]) + real_data_mask = torch.ones_like(batch[view_idx]["is_synthetic"]) + + # Process synthetic data + if synthetic_mask.any(): + synthetic_loss = loss[synthetic_mask] + synthetic_msk = msk[synthetic_mask] + aggregated_losses.append(synthetic_loss[synthetic_msk]) + + # Process real data + if real_data_mask.any(): + real_loss = loss[real_data_mask] + real_msk = msk[real_data_mask] + real_bottom_n_percent_loss = self.keep_bottom_n_percent( + real_loss, real_msk, self.bottom_n_percent + ) + aggregated_losses.append(real_bottom_n_percent_loss) + + # Compute view loss + view_loss = torch.cat(aggregated_losses, dim=0) + + # Only add to total loss and store details if there are valid elements + if view_loss.numel() > 0: + view_loss = view_loss.mean() + total_loss = total_loss + view_loss + + # Store details + loss_details[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_view{view_idx + 1}" + ] = float(view_loss) + + # Initialize or update running average directly + avg_key = f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_avg" + if avg_key not in loss_details: + loss_details[avg_key] = float(view_loss) + running_avg_dict[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views" + ] = 1 + else: + valid_views = ( + running_avg_dict[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views" + ] + + 1 + ) + running_avg_dict[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views" + ] = valid_views + loss_details[avg_key] += ( + float(view_loss) - loss_details[avg_key] + ) / valid_views + + # Add unmodified losses for sets not in selected_losses + for idx, (loss, msk, rep_type) in enumerate(losses): + if idx not in processed_indices: + if msk is not None: + loss_after_masking = loss[msk] + else: + loss_after_masking = loss + if loss_after_masking.numel() > 0: + loss_mean = loss_after_masking.mean() + else: + # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True) + loss_mean = 0 + total_loss = total_loss + loss_mean + + return total_loss, dict(**loss_details, **details) + + +class ConfAndExcludeTopNPercentPixelLoss(MultiLoss): + """ + Combined loss that applies ConfLoss to one set of pixel-level regression losses + and ExcludeTopNPercentPixelLoss to another set of pixel-level regression losses. + """ + + def __init__( + self, + pixel_loss, + conf_alpha=1, + top_n_percent=5, + apply_to_real_data_only=True, + conf_loss_set_indices=None, + exclude_loss_set_indices=None, + ): + """ + Args: + pixel_loss (MultiLoss): The pixel-level regression loss to be used. + conf_alpha (float): Alpha parameter for ConfLoss. Default: 1. + top_n_percent (float): Percentage of top per-pixel loss values to ignore. Range: [0, 100]. Default: 5. + apply_to_real_data_only (bool): Whether to apply the exclude loss only to real world data. Default: True. + conf_loss_set_indices (list or None): Indices of the loss sets to apply confidence weighting to. + Each index selects a specific loss set across all views (with the same rep_type). + If None, defaults to [0] which applies to the first loss set only. + exclude_loss_set_indices (list or None): Indices of the loss sets to apply top N percent exclusion to. + Each index selects a specific loss set across all views (with the same rep_type). + If None, defaults to [1] which applies to the second loss set only. 
+ """ + super().__init__() + self.pixel_loss = pixel_loss.with_reduction("none") + assert conf_alpha > 0 + self.conf_alpha = conf_alpha + self.top_n_percent = top_n_percent + self.bottom_n_percent = 100 - top_n_percent + self.apply_to_real_data_only = apply_to_real_data_only + self.conf_loss_set_indices = ( + [0] if conf_loss_set_indices is None else conf_loss_set_indices + ) + self.exclude_loss_set_indices = ( + [1] if exclude_loss_set_indices is None else exclude_loss_set_indices + ) + + def get_name(self): + return f"ConfAndExcludeTopNPercentPixelLoss({self.pixel_loss})" + + def get_conf_log(self, x): + return x, torch.log(x) + + def keep_bottom_n_percent(self, tensor, mask, bottom_n_percent): + """ + Function to compute the mask for keeping the bottom n percent of per-pixel loss values. + """ + B, N = tensor.shape + + # Calculate the number of valid elements (where mask is True) + num_valid = mask.sum(dim=1) + + # Calculate the number of elements to keep (n% of valid elements) + num_keep = (num_valid * bottom_n_percent / 100).long() + + # Create a mask for the bottom n% elements + keep_mask = torch.arange(N, device=tensor.device).unsqueeze( + 0 + ) < num_keep.unsqueeze(1) + + # Create a tensor with inf where mask is False + masked_tensor = torch.where( + mask, tensor, torch.tensor(float("inf"), device=tensor.device) + ) + + # Sort the masked tensor along the N dimension + sorted_tensor, _ = torch.sort(masked_tensor, dim=1, descending=False) + + # Get the bottom n% elements + bottom_n_percent_elements = sorted_tensor[keep_mask] + + return bottom_n_percent_elements + + def compute_loss(self, batch, preds, **kw): + # Compute per-pixel loss + losses, pixel_loss_details = self.pixel_loss(batch, preds, **kw) + n_views = len(batch) + + # Select specific loss sets for confidence weighting + conf_selected_losses = [] + conf_processed_indices = set() + for idx in self.conf_loss_set_indices: + start_idx = idx * n_views + end_idx = min((idx + 1) * n_views, len(losses)) + conf_selected_losses.extend(losses[start_idx:end_idx]) + conf_processed_indices.update(range(start_idx, end_idx)) + + # Select specific loss sets for top N percent exclusion + exclude_selected_losses = [] + exclude_processed_indices = set() + for idx in self.exclude_loss_set_indices: + start_idx = idx * n_views + end_idx = min((idx + 1) * n_views, len(losses)) + exclude_selected_losses.extend(losses[start_idx:end_idx]) + exclude_processed_indices.update(range(start_idx, end_idx)) + + # Initialize total loss and details + total_loss = 0 + loss_details = {} + running_avg_dict = {} + self_name = type(self.pixel_loss).__name__ + + # Process selected losses with confidence weighting + for loss_idx, (loss, msk, rep_type) in enumerate(conf_selected_losses): + view_idx = loss_idx % n_views # Map to corresponding view index + + if loss.numel() == 0: + # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views}) for conf loss", force=True) + continue + + # Get the confidence and log confidence + if ( + hasattr(self.pixel_loss, "flatten_across_image_only") + and self.pixel_loss.flatten_across_image_only + ): + # Reshape confidence to match the flattened dimensions + conf_reshaped = preds[view_idx]["conf"].view( + preds[view_idx]["conf"].shape[0], -1 + ) + conf, log_conf = self.get_conf_log(conf_reshaped[msk]) + loss = loss[msk] + else: + conf, log_conf = self.get_conf_log(preds[view_idx]["conf"][msk]) + + # Weight the loss by the confidence + conf_loss = loss * conf - self.conf_alpha * log_conf + + # Only add 
to total loss and store details if there are valid elements + if conf_loss.numel() > 0: + conf_loss = conf_loss.mean() + total_loss = total_loss + conf_loss + + # Store details + loss_details[f"{self_name}_{rep_type}_conf_loss_view{view_idx + 1}"] = ( + float(conf_loss) + ) + + # Initialize or update running average directly + avg_key = f"{self_name}_{rep_type}_conf_loss_avg" + if avg_key not in loss_details: + loss_details[avg_key] = float(conf_loss) + running_avg_dict[ + f"{self_name}_{rep_type}_conf_loss_valid_views" + ] = 1 + else: + valid_views = ( + running_avg_dict[ + f"{self_name}_{rep_type}_conf_loss_valid_views" + ] + + 1 + ) + running_avg_dict[ + f"{self_name}_{rep_type}_conf_loss_valid_views" + ] = valid_views + loss_details[avg_key] += ( + float(conf_loss) - loss_details[avg_key] + ) / valid_views + + # Process selected losses with top N percent exclusion + for loss_idx, (loss, msk, rep_type) in enumerate(exclude_selected_losses): + view_idx = loss_idx % n_views # Map to corresponding view index + + if loss.numel() == 0: + # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views}) for exclude loss", force=True) + continue + + # Create empty list for current view's aggregated tensors + aggregated_losses = [] + + if self.apply_to_real_data_only: + # Get the synthetic and real world data mask + synthetic_mask = batch[view_idx]["is_synthetic"] + real_data_mask = ~batch[view_idx]["is_synthetic"] + else: + # Apply the filtering to all data + synthetic_mask = torch.zeros_like(batch[view_idx]["is_synthetic"]) + real_data_mask = torch.ones_like(batch[view_idx]["is_synthetic"]) + + # Process synthetic data + if synthetic_mask.any(): + synthetic_loss = loss[synthetic_mask] + synthetic_msk = msk[synthetic_mask] + aggregated_losses.append(synthetic_loss[synthetic_msk]) + + # Process real data + if real_data_mask.any(): + real_loss = loss[real_data_mask] + real_msk = msk[real_data_mask] + real_bottom_n_percent_loss = self.keep_bottom_n_percent( + real_loss, real_msk, self.bottom_n_percent + ) + aggregated_losses.append(real_bottom_n_percent_loss) + + # Compute view loss + view_loss = torch.cat(aggregated_losses, dim=0) + + # Only add to total loss and store details if there are valid elements + if view_loss.numel() > 0: + view_loss = view_loss.mean() + total_loss = total_loss + view_loss + + # Store details + loss_details[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_view{view_idx + 1}" + ] = float(view_loss) + + # Initialize or update running average directly + avg_key = f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_avg" + if avg_key not in loss_details: + loss_details[avg_key] = float(view_loss) + running_avg_dict[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views" + ] = 1 + else: + valid_views = ( + running_avg_dict[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views" + ] + + 1 + ) + running_avg_dict[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views" + ] = valid_views + loss_details[avg_key] += ( + float(view_loss) - loss_details[avg_key] + ) / valid_views + + # Add unmodified losses for sets not processed with either confidence or exclusion + all_processed_indices = conf_processed_indices.union(exclude_processed_indices) + for idx, (loss, msk, rep_type) in enumerate(losses): + if idx not in all_processed_indices: + if msk is not None: + loss_after_masking = loss[msk] + else: + loss_after_masking = loss + if loss_after_masking.numel() > 0: + loss_mean = 
loss_after_masking.mean() + else: + # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True) + loss_mean = 0 + total_loss = total_loss + loss_mean + + return total_loss, dict(**loss_details, **pixel_loss_details) + + +class Regr3D(Criterion, MultiLoss): + """ + Regression Loss for World Frame Pointmaps. + Asymmetric loss where view 1 is supposed to be the anchor. + + For each view i: + Pi = RTi @ Di + lossi = (RTi1 @ pred_Di) - (RT1^-1 @ RTi @ Di) + where RT1 is the anchor view camera pose + """ + + def __init__( + self, + criterion, + norm_mode="?avg_dis", + gt_scale=False, + ambiguous_loss_value=0, + max_metric_scale=False, + loss_in_log=True, + flatten_across_image_only=False, + ): + """ + Initialize the loss criterion for World Frame Pointmaps. + + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. + norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis". + If prefixed with "?", normalization is only applied to non-metric scale data. + gt_scale (bool): If True, enforce predictions to have the same scale as ground truth. + If False, both GT and predictions are normalized independently. Default: False. + ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. + If 0, ambiguous pixels are ignored. Default: 0. + max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this + value, it will be treated as non-metric. Default: False (no limit). + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for pointmaps. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. 
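+
+        Example (illustrative; `criterion` stands for any BaseCriterion, typically constructed
+            with reduction="none" when this loss is wrapped by a confidence-based loss):
+                regr3d_loss = Regr3D(criterion, norm_mode="?avg_dis", loss_in_log=True)
+            With the "?" prefix, only non-metric-scale samples have their predictions
+            normalized independently; metric-scale predictions are instead rescaled by the
+            ground-truth normalization factor, so their absolute scale is supervised.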
+ """ + super().__init__(criterion) + if norm_mode.startswith("?"): + # Do no norm pts from metric scale datasets + self.norm_all = False + self.norm_mode = norm_mode[1:] + else: + self.norm_all = True + self.norm_mode = norm_mode + self.gt_scale = gt_scale + self.ambiguous_loss_value = ambiguous_loss_value + self.max_metric_scale = max_metric_scale + self.loss_in_log = loss_in_log + self.flatten_across_image_only = flatten_across_image_only + + def get_all_info(self, batch, preds, dist_clip=None): + n_views = len(batch) + in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) + + # Initialize lists to store points and masks + no_norm_gt_pts = [] + valid_masks = [] + + # Process ground truth points and valid masks + for view_idx in range(n_views): + no_norm_gt_pts.append( + geotrf(in_camera0, batch[view_idx]["pts3d"]) + ) # B,H,W,3 + valid_masks.append(batch[view_idx]["valid_mask"].clone()) + + if dist_clip is not None: + # Points that are too far-away == invalid + for view_idx in range(n_views): + dis = no_norm_gt_pts[view_idx].norm(dim=-1) # (B, H, W) + valid_masks[view_idx] = valid_masks[view_idx] & (dis <= dist_clip) + + # Get predicted points + no_norm_pr_pts = [] + for view_idx in range(n_views): + no_norm_pr_pts.append(preds[view_idx]["pts3d"]) + + if not self.norm_all: + if self.max_metric_scale: + B = valid_masks[0].shape[0] + # Calculate distances to camera for all views + dists_to_cam1 = [] + for view_idx in range(n_views): + dist = torch.where( + valid_masks[view_idx], + torch.norm(no_norm_gt_pts[view_idx], dim=-1), + 0, + ).view(B, -1) + dists_to_cam1.append(dist) + + # Update metric scale flags + metric_scale_mask = batch[0]["is_metric_scale"] + for dist in dists_to_cam1: + metric_scale_mask = metric_scale_mask & ( + dist.max(dim=-1).values < self.max_metric_scale + ) + + for view_idx in range(n_views): + batch[view_idx]["is_metric_scale"] = metric_scale_mask + + non_metric_scale_mask = ~batch[0]["is_metric_scale"] + else: + non_metric_scale_mask = torch.ones_like(batch[0]["is_metric_scale"]) + + # Initialize normalized points + gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] + pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] + + # Normalize 3d points + if self.norm_mode and non_metric_scale_mask.any(): + normalized_pr_pts = normalize_multiple_pointclouds( + [pts[non_metric_scale_mask] for pts in no_norm_pr_pts], + [mask[non_metric_scale_mask] for mask in valid_masks], + self.norm_mode, + ) + for i in range(n_views): + pr_pts[i][non_metric_scale_mask] = normalized_pr_pts[i] + elif non_metric_scale_mask.any(): + for i in range(n_views): + pr_pts[i][non_metric_scale_mask] = no_norm_pr_pts[i][ + non_metric_scale_mask + ] + + if self.norm_mode and not self.gt_scale: + gt_normalization_output = normalize_multiple_pointclouds( + no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True + ) + normalized_gt_pts = gt_normalization_output[:-1] + norm_factor = gt_normalization_output[-1] + for i in range(n_views): + gt_pts[i] = normalized_gt_pts[i] + pr_pts[i][~non_metric_scale_mask] = ( + no_norm_pr_pts[i][~non_metric_scale_mask] + / norm_factor[~non_metric_scale_mask] + ) + elif ~non_metric_scale_mask.any(): + for i in range(n_views): + gt_pts[i] = no_norm_gt_pts[i] + pr_pts[i][~non_metric_scale_mask] = no_norm_pr_pts[i][ + ~non_metric_scale_mask + ] + else: + for i in range(n_views): + gt_pts[i] = no_norm_gt_pts[i] + + # Get ambiguous masks + ambiguous_masks = [] + for view_idx in range(n_views): + ambiguous_masks.append( + 
(~batch[view_idx]["non_ambiguous_mask"]) & (~valid_masks[view_idx]) + ) + + return gt_pts, pr_pts, valid_masks, ambiguous_masks, {} + + def compute_loss(self, batch, preds, **kw): + gt_pts, pred_pts, masks, ambiguous_masks, monitoring = self.get_all_info( + batch, preds, **kw + ) + n_views = len(batch) + + if self.ambiguous_loss_value > 0: + assert self.criterion.reduction == "none", ( + "ambiguous_loss_value should be 0 if no conf loss" + ) + # Add the ambiguous pixels as "valid" pixels + masks = [mask | amb_mask for mask, amb_mask in zip(masks, ambiguous_masks)] + + losses = [] + details = {} + running_avg_dict = {} + self_name = type(self).__name__ + + if not self.flatten_across_image_only: + for view_idx in range(n_views): + pred = pred_pts[view_idx][masks[view_idx]] + gt = gt_pts[view_idx][masks[view_idx]] + + if self.loss_in_log: + pred = apply_log_to_norm(pred) + gt = apply_log_to_norm(gt) + + loss = self.criterion(pred, gt) + + if self.ambiguous_loss_value > 0: + loss = torch.where( + ambiguous_masks[view_idx][masks[view_idx]], + self.ambiguous_loss_value, + loss, + ) + + losses.append((loss, masks[view_idx], "pts3d")) + if loss.numel() > 0: + loss_mean = float(loss.mean()) + details[f"{self_name}_pts3d_view{view_idx + 1}"] = loss_mean + # Initialize or update running average directly + avg_key = f"{self_name}_pts3d_avg" + if avg_key not in details: + details[avg_key] = loss_mean + running_avg_dict[f"{self_name}_pts3d_valid_views"] = 1 + else: + valid_views = ( + running_avg_dict[f"{self_name}_pts3d_valid_views"] + 1 + ) + running_avg_dict[f"{self_name}_pts3d_valid_views"] = valid_views + details[avg_key] += (loss_mean - details[avg_key]) / valid_views + else: + batch_size, _, _, dim = gt_pts[0].shape + + for view_idx in range(n_views): + gt = gt_pts[view_idx].view(batch_size, -1, dim) + pred = pred_pts[view_idx].view(batch_size, -1, dim) + view_mask = masks[view_idx].view(batch_size, -1) + amb_mask = ambiguous_masks[view_idx].view(batch_size, -1) + + if self.loss_in_log: + pred = apply_log_to_norm(pred) + gt = apply_log_to_norm(gt) + + loss = self.criterion(pred, gt) + + if self.ambiguous_loss_value > 0: + loss = torch.where(amb_mask, self.ambiguous_loss_value, loss) + + losses.append((loss, view_mask, "pts3d")) + loss_after_masking = loss[view_mask] + if loss_after_masking.numel() > 0: + loss_mean = float(loss_after_masking.mean()) + details[f"{self_name}_pts3d_view{view_idx + 1}"] = loss_mean + # Initialize or update running average directly + avg_key = f"{self_name}_pts3d_avg" + if avg_key not in details: + details[avg_key] = loss_mean + running_avg_dict[f"{self_name}_pts3d_valid_views"] = 1 + else: + valid_views = ( + running_avg_dict[f"{self_name}_pts3d_valid_views"] + 1 + ) + running_avg_dict[f"{self_name}_pts3d_valid_views"] = valid_views + details[avg_key] += (loss_mean - details[avg_key]) / valid_views + + return Sum(*losses), (details | monitoring) + + +class PointsPlusScaleRegr3D(Criterion, MultiLoss): + """ + Regression Loss for World Frame Pointmaps & Scale. + """ + + def __init__( + self, + criterion, + norm_predictions=True, + norm_mode="avg_dis", + ambiguous_loss_value=0, + loss_in_log=True, + flatten_across_image_only=False, + world_frame_points_loss_weight=1, + scale_loss_weight=1, + ): + """ + Initialize the loss criterion for World Frame Pointmaps & Scale. + The predicted scene representation is always normalized w.r.t. the frame of view0. + Loss is applied between the predicted metric scale and the ground truth metric scale. 
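+        When the network predicts a "metric_scaling_factor", the pointmap loss is computed on
+        the predictions divided by that factor (detaching the factor from the geometry
+        gradient), while the scale loss compares the factor applied to the detached points
+        against the ground-truth normalization factor.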
+ + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. + norm_predictions (bool): If True, normalize the predictions before computing the loss. + norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". + ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. + If 0, ambiguous pixels are ignored. Default: 0. + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth, pointmaps and scale. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. + scale_loss_weight (float): Weight to use for the scale loss. Default: 1. + """ + super().__init__(criterion) + self.norm_predictions = norm_predictions + self.norm_mode = norm_mode + self.ambiguous_loss_value = ambiguous_loss_value + self.loss_in_log = loss_in_log + self.flatten_across_image_only = flatten_across_image_only + self.world_frame_points_loss_weight = world_frame_points_loss_weight + self.scale_loss_weight = scale_loss_weight + + def get_all_info(self, batch, preds, dist_clip=None): + """ + Function to get all the information needed to compute the loss. + Returns all quantities normalized w.r.t. camera of view0. + """ + n_views = len(batch) + + # Everything is normalized w.r.t. camera of view0 + # Initialize lists to store data for all views + # Ground truth quantities + in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) + no_norm_gt_pts = [] + valid_masks = [] + # Predicted quantities + no_norm_pr_pts = [] + metric_pr_pts_to_compute_scale = [] + + # Get ground truth & prediction info for all views + for i in range(n_views): + # Get the ground truth + no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"])) + valid_masks.append(batch[i]["valid_mask"].clone()) + + # Get predictions for normalized loss + if "metric_scaling_factor" in preds[i].keys(): + # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans + # This detaches the predicted metric scaling factor from the geometry based loss + curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + else: + curr_view_no_norm_pr_pts = preds[i]["pts3d"] + no_norm_pr_pts.append(curr_view_no_norm_pr_pts) + + # Get the predicted metric scale points + if "metric_scaling_factor" in preds[i].keys(): + # Detach the raw predicted points so that the scale loss is only applied to the scaling factor + curr_view_metric_pr_pts_to_compute_scale = ( + curr_view_no_norm_pr_pts.detach() + * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1) + ) + else: + curr_view_metric_pr_pts_to_compute_scale = ( + curr_view_no_norm_pr_pts.clone() + ) + metric_pr_pts_to_compute_scale.append( + curr_view_metric_pr_pts_to_compute_scale + ) + + if dist_clip is not None: + # Points that are too far-away == invalid + for i in range(n_views): + dis = no_norm_gt_pts[i].norm(dim=-1) + valid_masks[i] = valid_masks[i] & (dis <= dist_clip) + + # Initialize normalized tensors + gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] + pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] + + # Normalize the predicted points if specified + if self.norm_predictions: + pr_normalization_output = 
normalize_multiple_pointclouds( + no_norm_pr_pts, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_pts_norm = pr_normalization_output[:-1] + + # Normalize the ground truth points + gt_normalization_output = normalize_multiple_pointclouds( + no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True + ) + gt_pts_norm = gt_normalization_output[:-1] + gt_norm_factor = gt_normalization_output[-1] + + for i in range(n_views): + if self.norm_predictions: + # Assign the normalized predictions + pr_pts[i] = pr_pts_norm[i] + else: + pr_pts[i] = no_norm_pr_pts[i] + # Assign the normalized ground truth quantities + gt_pts[i] = gt_pts_norm[i] + + # Get the mask indicating ground truth metric scale quantities + metric_scale_mask = batch[0]["is_metric_scale"] + valid_gt_norm_factor_mask = ( + gt_norm_factor[:, 0, 0, 0] > 1e-8 + ) # Mask out cases where depth for all views is invalid + valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask + + if valid_metric_scale_mask.any(): + # Compute the scale norm factor using the predicted metric scale points + metric_pr_normalization_output = normalize_multiple_pointclouds( + metric_pr_pts_to_compute_scale, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_metric_norm_factor = metric_pr_normalization_output[-1] + + # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities + gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask] + pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask] + else: + gt_metric_norm_factor = None + pr_metric_norm_factor = None + + # Get ambiguous masks + ambiguous_masks = [] + for i in range(n_views): + ambiguous_masks.append( + (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i]) + ) + + # Pack into info dicts + gt_info = [] + pred_info = [] + for i in range(n_views): + gt_info.append( + { + "pts3d": gt_pts[i], + } + ) + pred_info.append( + { + "pts3d": pr_pts[i], + } + ) + + return ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) + + def compute_loss(self, batch, preds, **kw): + ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) = self.get_all_info(batch, preds, **kw) + n_views = len(batch) + + if self.ambiguous_loss_value > 0: + assert self.criterion.reduction == "none", ( + "ambiguous_loss_value should be 0 if no conf loss" + ) + # Add the ambiguous pixel as "valid" pixels... 
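+            # so that the per-pixel loss tensors also contain entries at ambiguous locations;
+            # those entries are overwritten with ambiguous_loss_value further below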
+ valid_masks = [ + mask | ambig_mask + for mask, ambig_mask in zip(valid_masks, ambiguous_masks) + ] + + pts3d_losses = [] + + for i in range(n_views): + # Get the predicted dense quantities + if not self.flatten_across_image_only: + # Flatten the points across the entire batch with the masks + pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] + gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W + batch_size, _, _, pts_dim = gt_info[i]["pts3d"].shape + gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) + pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space if specified + if self.loss_in_log: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_pts3d = apply_log_to_norm(pred_pts3d) + + # Compute point loss + pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") + pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight + pts3d_losses.append(pts3d_loss) + + # Handle ambiguous pixels + if self.ambiguous_loss_value > 0: + if not self.flatten_across_image_only: + pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + pts3d_losses[i], + ) + else: + pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + pts3d_losses[i], + ) + + # Compute the scale loss + if gt_metric_norm_factor is not None: + if self.loss_in_log: + gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) + pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) + scale_loss = ( + self.criterion( + pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" + ) + * self.scale_loss_weight + ) + else: + scale_loss = None + + # Use helper function to generate loss terms and details + + losses_dict = { + "pts3d": { + "values": pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + "scale": { + "values": scale_loss, + "use_mask": False, + "is_multi_view": False, + }, + } + + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) + + +class NormalGMLoss(MultiLoss): + """ + Normal & Gradient Matching Loss for Monocular Depth Training. + """ + + def __init__( + self, + norm_predictions=True, + norm_mode="avg_dis", + apply_normal_and_gm_loss_to_synthetic_data_only=True, + ): + """ + Initialize the loss criterion for Normal & Gradient Matching Loss (currently only valid for 1 view). + Computes: + (1) Normal Loss over the PointMap (naturally will be in local frame) in euclidean coordinates, + (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space) + + Args: + norm_predictions (bool): If True, normalize the predictions before computing the loss. + norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". + apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. + If False, apply the normal and gm loss to all data. Default: True. 
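+
+        Example (illustrative):
+            monocular_loss = NormalGMLoss(
+                norm_mode="avg_dis",
+                apply_normal_and_gm_loss_to_synthetic_data_only=True,
+            )
+            With the flag enabled, both terms are evaluated only on valid pixels of samples
+            flagged as "is_synthetic" in the batch.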
+ """ + super().__init__() + self.norm_predictions = norm_predictions + self.norm_mode = norm_mode + self.apply_normal_and_gm_loss_to_synthetic_data_only = ( + apply_normal_and_gm_loss_to_synthetic_data_only + ) + + def get_all_info(self, batch, preds, dist_clip=None): + """ + Function to get all the information needed to compute the loss. + Returns all quantities normalized. + """ + n_views = len(batch) + assert n_views == 1, ( + "Normal & Gradient Matching Loss Class only supports 1 view" + ) + + # Everything is normalized w.r.t. camera of view1 + in_camera1 = closed_form_pose_inverse(batch[0]["camera_pose"]) + + # Initialize lists to store data for all views + no_norm_gt_pts = [] + valid_masks = [] + no_norm_pr_pts = [] + + # Get ground truth & prediction info for all views + for i in range(n_views): + # Get ground truth + no_norm_gt_pts.append(geotrf(in_camera1, batch[i]["pts3d"])) + valid_masks.append(batch[i]["valid_mask"].clone()) + + # Get predictions for normalized loss + if "metric_scaling_factor" in preds[i].keys(): + # Divide by the predicted metric scaling factor to get the raw predicted points + # This detaches the predicted metric scaling factor from the geometry based loss + curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + else: + curr_view_no_norm_pr_pts = preds[i]["pts3d"] + no_norm_pr_pts.append(curr_view_no_norm_pr_pts) + + if dist_clip is not None: + # Points that are too far-away == invalid + for i in range(n_views): + dis = no_norm_gt_pts[i].norm(dim=-1) + valid_masks[i] = valid_masks[i] & (dis <= dist_clip) + + # Initialize normalized tensors + gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] + pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] + + # Normalize the predicted points if specified + if self.norm_predictions: + pr_normalization_output = normalize_multiple_pointclouds( + no_norm_pr_pts, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_pts_norm = pr_normalization_output[:-1] + + # Normalize the ground truth points + gt_normalization_output = normalize_multiple_pointclouds( + no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True + ) + gt_pts_norm = gt_normalization_output[:-1] + + for i in range(n_views): + if self.norm_predictions: + # Assign the normalized predictions + pr_pts[i] = pr_pts_norm[i] + else: + # Assign the raw predicted points + pr_pts[i] = no_norm_pr_pts[i] + # Assign the normalized ground truth + gt_pts[i] = gt_pts_norm[i] + + return gt_pts, pr_pts, valid_masks + + def compute_loss(self, batch, preds, **kw): + gt_pts, pred_pts, valid_masks = self.get_all_info(batch, preds, **kw) + n_views = len(batch) + assert n_views == 1, ( + "Normal & Gradient Matching Loss Class only supports 1 view" + ) + + normal_losses = [] + gradient_matching_losses = [] + details = {} + running_avg_dict = {} + self_name = type(self).__name__ + + for i in range(n_views): + # Get the local frame points, log space depth_z & valid masks + pred_local_pts3d = pred_pts[i] + pred_depth_z = pred_local_pts3d[..., 2:] + pred_depth_z = apply_log_to_norm(pred_depth_z) + gt_local_pts3d = gt_pts[i] + gt_depth_z = gt_local_pts3d[..., 2:] + gt_depth_z = apply_log_to_norm(gt_depth_z) + valid_mask_for_normal_gm_loss = valid_masks[i].clone() + + # Update the validity mask for normal & gm loss based on the synthetic data mask if required + if self.apply_normal_and_gm_loss_to_synthetic_data_only: + synthetic_mask = batch[i]["is_synthetic"] # (B, ) + synthetic_mask = 
synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1) + synthetic_mask = synthetic_mask.expand( + -1, pred_depth_z.shape[1], pred_depth_z.shape[2] + ) # (B, H, W) + valid_mask_for_normal_gm_loss = ( + valid_mask_for_normal_gm_loss & synthetic_mask + ) + + # Compute the normal loss + normal_loss = compute_normal_loss( + pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone() + ) + normal_losses.append(normal_loss) + + # Compute the gradient matching loss + gradient_matching_loss = compute_gradient_matching_loss( + pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone() + ) + gradient_matching_losses.append(gradient_matching_loss) + + # Add loss details if only valid values are present + # Initialize or update running average directly + # Normal loss details + if float(normal_loss) > 0: + details[f"{self_name}_normal_view{i + 1}"] = float(normal_loss) + normal_avg_key = f"{self_name}_normal_avg" + if normal_avg_key not in details: + details[normal_avg_key] = float(normal_losses[i]) + running_avg_dict[f"{self_name}_normal_valid_views"] = 1 + else: + normal_valid_views = ( + running_avg_dict[f"{self_name}_normal_valid_views"] + 1 + ) + running_avg_dict[f"{self_name}_normal_valid_views"] = ( + normal_valid_views + ) + details[normal_avg_key] += ( + float(normal_losses[i]) - details[normal_avg_key] + ) / normal_valid_views + + # Gradient Matching loss details + if float(gradient_matching_loss) > 0: + details[f"{self_name}_gradient_matching_view{i + 1}"] = float( + gradient_matching_loss + ) + # For gradient matching loss + gm_avg_key = f"{self_name}_gradient_matching_avg" + if gm_avg_key not in details: + details[gm_avg_key] = float(gradient_matching_losses[i]) + running_avg_dict[f"{self_name}_gm_valid_views"] = 1 + else: + gm_valid_views = running_avg_dict[f"{self_name}_gm_valid_views"] + 1 + running_avg_dict[f"{self_name}_gm_valid_views"] = gm_valid_views + details[gm_avg_key] += ( + float(gradient_matching_losses[i]) - details[gm_avg_key] + ) / gm_valid_views + + # Put the losses together + loss_terms = [] + for i in range(n_views): + loss_terms.append((normal_losses[i], None, "normal")) + loss_terms.append((gradient_matching_losses[i], None, "gradient_matching")) + losses = Sum(*loss_terms) + + return losses, details + + +class FactoredGeometryRegr3D(Criterion, MultiLoss): + """ + Regression Loss for Factored Geometry. + """ + + def __init__( + self, + criterion, + norm_mode="?avg_dis", + gt_scale=False, + ambiguous_loss_value=0, + max_metric_scale=False, + loss_in_log=True, + flatten_across_image_only=False, + depth_type_for_loss="depth_along_ray", + cam_frame_points_loss_weight=1, + depth_loss_weight=1, + ray_directions_loss_weight=1, + pose_quats_loss_weight=1, + pose_trans_loss_weight=1, + compute_pairwise_relative_pose_loss=False, + convert_predictions_to_view0_frame=False, + compute_world_frame_points_loss=True, + world_frame_points_loss_weight=1, + ): + """ + Initialize the loss criterion for Factored Geometry (Ray Directions, Depth, Pose), + and the Collective Geometry i.e. Local Frame Pointmaps & optionally World Frame Pointmaps. + If world-frame pointmap loss is computed, the pixel-level losses are computed in the following order: + (1) world points, (2) cam points, (3) depth, (4) ray directions, (5) pose quats, (6) pose trans. + Else, the pixel-level losses are returned in the following order: + (1) cam points, (2) depth, (3) ray directions, (4) pose quats, (5) pose trans. 
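+        This ordering is what wrapper losses such as ExcludeTopNPercentPixelLoss and
+        ConfAndExcludeTopNPercentPixelLoss index into through their loss-set index arguments
+        (index 0 selects the first listed set across all views, index 1 the second, and so on).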
+ + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. + norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis". + If prefixed with "?", normalization is only applied to non-metric scale data. + gt_scale (bool): If True, enforce predictions to have the same scale as ground truth. + If False, both GT and predictions are normalized independently. Default: False. + ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. + If 0, ambiguous pixels are ignored. Default: 0. + max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this + value, it will be treated as non-metric. Default: False (no limit). + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth and pointmaps. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". + Options: "depth_along_ray", "depth_z" + cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1. + depth_loss_weight (float): Weight to use for the depth loss. Default: 1. + ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. + pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. + pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the + exhaustive pairwise relative poses. Default: False. + convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. + Use this if the predictions are not already in the view0 frame. Default: False. + compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True. + world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. + """ + super().__init__(criterion) + if norm_mode.startswith("?"): + # Do no norm pts from metric scale datasets + self.norm_all = False + self.norm_mode = norm_mode[1:] + else: + self.norm_all = True + self.norm_mode = norm_mode + self.gt_scale = gt_scale + self.ambiguous_loss_value = ambiguous_loss_value + self.max_metric_scale = max_metric_scale + self.loss_in_log = loss_in_log + self.flatten_across_image_only = flatten_across_image_only + self.depth_type_for_loss = depth_type_for_loss + assert self.depth_type_for_loss in [ + "depth_along_ray", + "depth_z", + ], "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" + self.cam_frame_points_loss_weight = cam_frame_points_loss_weight + self.depth_loss_weight = depth_loss_weight + self.ray_directions_loss_weight = ray_directions_loss_weight + self.pose_quats_loss_weight = pose_quats_loss_weight + self.pose_trans_loss_weight = pose_trans_loss_weight + self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss + self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame + self.compute_world_frame_points_loss = compute_world_frame_points_loss + self.world_frame_points_loss_weight = world_frame_points_loss_weight + + def get_all_info(self, batch, preds, dist_clip=None): + """ + Function to get all the information needed to compute the loss. + Returns all quantities normalized w.r.t. camera of view0. 
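+
+        Returns:
+            Tuple of (gt_info, pred_info, valid_masks, ambiguous_masks), where gt_info and
+            pred_info are per-view dicts containing "ray_directions", the selected depth type,
+            "pose_trans", "pose_quats", "pts3d" and "pts3d_cam".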
+ """ + n_views = len(batch) + + # Everything is normalized w.r.t. camera of view0 + # Initialize lists to store data for all views + # Ground truth quantities + in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) + no_norm_gt_pts = [] + no_norm_gt_pts_cam = [] + no_norm_gt_depth = [] + no_norm_gt_pose_trans = [] + valid_masks = [] + gt_ray_directions = [] + gt_pose_quats = [] + # Predicted quantities + if self.convert_predictions_to_view0_frame: + # Get the camera transform to convert quantities to view0 frame + pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze( + 0 + ) + batch_size = preds[0]["cam_quats"].shape[0] + pred_camera0 = pred_camera0.repeat(batch_size, 1, 1) + pred_camera0_rot = quaternion_to_rotation_matrix( + preds[0]["cam_quats"].clone() + ) + pred_camera0[..., :3, :3] = pred_camera0_rot + pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone() + pred_in_camera0 = closed_form_pose_inverse(pred_camera0) + no_norm_pr_pts = [] + no_norm_pr_pts_cam = [] + no_norm_pr_depth = [] + no_norm_pr_pose_trans = [] + pr_ray_directions = [] + pr_pose_quats = [] + + # Get ground truth & prediction info for all views + for i in range(n_views): + # Get ground truth + no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"])) + valid_masks.append(batch[i]["valid_mask"].clone()) + no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"]) + gt_ray_directions.append(batch[i]["ray_directions_cam"]) + if self.depth_type_for_loss == "depth_along_ray": + no_norm_gt_depth.append(batch[i]["depth_along_ray"]) + elif self.depth_type_for_loss == "depth_z": + no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:]) + if i == 0: + # For view0, initialize identity pose + gt_pose_quats.append( + torch.tensor( + [0, 0, 0, 1], + dtype=gt_ray_directions[0].dtype, + device=gt_ray_directions[0].device, + ) + .unsqueeze(0) + .repeat(gt_ray_directions[0].shape[0], 1) + ) + no_norm_gt_pose_trans.append( + torch.tensor( + [0, 0, 0], + dtype=gt_ray_directions[0].dtype, + device=gt_ray_directions[0].device, + ) + .unsqueeze(0) + .repeat(gt_ray_directions[0].shape[0], 1) + ) + else: + # For other views, transform pose to view0's frame + gt_pose_quats_world = batch[i]["camera_pose_quats"] + no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"] + gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = ( + transform_pose_using_quats_and_trans_2_to_1( + batch[0]["camera_pose_quats"], + batch[0]["camera_pose_trans"], + gt_pose_quats_world, + no_norm_gt_pose_trans_world, + ) + ) + gt_pose_quats.append(gt_pose_quats_in_view0) + no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0) + + # Get the local predictions + no_norm_pr_pts_cam.append(preds[i]["pts3d_cam"]) + pr_ray_directions.append(preds[i]["ray_directions"]) + if self.depth_type_for_loss == "depth_along_ray": + no_norm_pr_depth.append(preds[i]["depth_along_ray"]) + elif self.depth_type_for_loss == "depth_z": + no_norm_pr_depth.append(preds[i]["pts3d_cam"][..., 2:]) + + # Get the predicted global predictions in view0's frame + if self.convert_predictions_to_view0_frame: + # Convert predictions to view0 frame + pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"]) + pr_pose_quats_in_view0, pr_pose_trans_in_view0 = ( + transform_pose_using_quats_and_trans_2_to_1( + preds[0]["cam_quats"], + preds[0]["cam_trans"], + preds[i]["cam_quats"], + preds[i]["cam_trans"], + ) + ) + no_norm_pr_pts.append(pr_pts3d_in_view0) + no_norm_pr_pose_trans.append(pr_pose_trans_in_view0) + pr_pose_quats.append(pr_pose_quats_in_view0) + else: + 
# Predictions are already in view0 frame + no_norm_pr_pts.append(preds[i]["pts3d"]) + no_norm_pr_pose_trans.append(preds[i]["cam_trans"]) + pr_pose_quats.append(preds[i]["cam_quats"]) + + if dist_clip is not None: + # Points that are too far-away == invalid + for i in range(n_views): + dis = no_norm_gt_pts[i].norm(dim=-1) + valid_masks[i] = valid_masks[i] & (dis <= dist_clip) + + # Handle metric scale + if not self.norm_all: + if self.max_metric_scale: + B = valid_masks[0].shape[0] + dists_to_cam1 = [] + for i in range(n_views): + dists_to_cam1.append( + torch.where( + valid_masks[i], torch.norm(no_norm_gt_pts[i], dim=-1), 0 + ).view(B, -1) + ) + + batch[0]["is_metric_scale"] = batch[0]["is_metric_scale"] + for dist in dists_to_cam1: + batch[0]["is_metric_scale"] &= ( + dist.max(dim=-1).values < self.max_metric_scale + ) + + for i in range(1, n_views): + batch[i]["is_metric_scale"] = batch[0]["is_metric_scale"] + + non_metric_scale_mask = ~batch[0]["is_metric_scale"] + else: + non_metric_scale_mask = torch.ones_like(batch[0]["is_metric_scale"]) + + # Initialize normalized tensors + gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] + gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam] + gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth] + gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans] + + pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] + pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam] + pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth] + pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans] + + # Normalize points + if self.norm_mode and non_metric_scale_mask.any(): + pr_normalization_output = normalize_multiple_pointclouds( + [pts[non_metric_scale_mask] for pts in no_norm_pr_pts], + [mask[non_metric_scale_mask] for mask in valid_masks], + self.norm_mode, + ret_factor=True, + ) + pr_pts_norm = pr_normalization_output[:-1] + pr_norm_factor = pr_normalization_output[-1] + + for i in range(n_views): + pr_pts[i][non_metric_scale_mask] = pr_pts_norm[i] + pr_pts_cam[i][non_metric_scale_mask] = ( + no_norm_pr_pts_cam[i][non_metric_scale_mask] / pr_norm_factor + ) + pr_depth[i][non_metric_scale_mask] = ( + no_norm_pr_depth[i][non_metric_scale_mask] / pr_norm_factor + ) + pr_pose_trans[i][non_metric_scale_mask] = ( + no_norm_pr_pose_trans[i][non_metric_scale_mask] + / pr_norm_factor[:, :, 0, 0] + ) + + elif non_metric_scale_mask.any(): + for i in range(n_views): + pr_pts[i][non_metric_scale_mask] = no_norm_pr_pts[i][ + non_metric_scale_mask + ] + pr_pts_cam[i][non_metric_scale_mask] = no_norm_pr_pts_cam[i][ + non_metric_scale_mask + ] + pr_depth[i][non_metric_scale_mask] = no_norm_pr_depth[i][ + non_metric_scale_mask + ] + pr_pose_trans[i][non_metric_scale_mask] = no_norm_pr_pose_trans[i][ + non_metric_scale_mask + ] + + if self.norm_mode and not self.gt_scale: + gt_normalization_output = normalize_multiple_pointclouds( + no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True + ) + gt_pts_norm = gt_normalization_output[:-1] + norm_factor = gt_normalization_output[-1] + + for i in range(n_views): + gt_pts[i] = gt_pts_norm[i] + gt_pts_cam[i] = no_norm_gt_pts_cam[i] / norm_factor + gt_depth[i] = no_norm_gt_depth[i] / norm_factor + gt_pose_trans[i] = no_norm_gt_pose_trans[i] / norm_factor[:, :, 0, 0] + + pr_pts[i][~non_metric_scale_mask] = ( + no_norm_pr_pts[i][~non_metric_scale_mask] + / norm_factor[~non_metric_scale_mask] + ) + 
pr_pts_cam[i][~non_metric_scale_mask] = ( + no_norm_pr_pts_cam[i][~non_metric_scale_mask] + / norm_factor[~non_metric_scale_mask] + ) + pr_depth[i][~non_metric_scale_mask] = ( + no_norm_pr_depth[i][~non_metric_scale_mask] + / norm_factor[~non_metric_scale_mask] + ) + pr_pose_trans[i][~non_metric_scale_mask] = ( + no_norm_pr_pose_trans[i][~non_metric_scale_mask] + / norm_factor[~non_metric_scale_mask][:, :, 0, 0] + ) + + elif ~non_metric_scale_mask.any(): + for i in range(n_views): + gt_pts[i] = no_norm_gt_pts[i] + gt_pts_cam[i] = no_norm_gt_pts_cam[i] + gt_depth[i] = no_norm_gt_depth[i] + gt_pose_trans[i] = no_norm_gt_pose_trans[i] + pr_pts[i][~non_metric_scale_mask] = no_norm_pr_pts[i][ + ~non_metric_scale_mask + ] + pr_pts_cam[i][~non_metric_scale_mask] = no_norm_pr_pts_cam[i][ + ~non_metric_scale_mask + ] + pr_depth[i][~non_metric_scale_mask] = no_norm_pr_depth[i][ + ~non_metric_scale_mask + ] + pr_pose_trans[i][~non_metric_scale_mask] = no_norm_pr_pose_trans[i][ + ~non_metric_scale_mask + ] + else: + for i in range(n_views): + gt_pts[i] = no_norm_gt_pts[i] + gt_pts_cam[i] = no_norm_gt_pts_cam[i] + gt_depth[i] = no_norm_gt_depth[i] + gt_pose_trans[i] = no_norm_gt_pose_trans[i] + + # Get ambiguous masks + ambiguous_masks = [] + for i in range(n_views): + ambiguous_masks.append( + (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i]) + ) + + # Pack into info dicts + gt_info = [] + pred_info = [] + for i in range(n_views): + gt_info.append( + { + "ray_directions": gt_ray_directions[i], + self.depth_type_for_loss: gt_depth[i], + "pose_trans": gt_pose_trans[i], + "pose_quats": gt_pose_quats[i], + "pts3d": gt_pts[i], + "pts3d_cam": gt_pts_cam[i], + } + ) + pred_info.append( + { + "ray_directions": pr_ray_directions[i], + self.depth_type_for_loss: pr_depth[i], + "pose_trans": pr_pose_trans[i], + "pose_quats": pr_pose_quats[i], + "pts3d": pr_pts[i], + "pts3d_cam": pr_pts_cam[i], + } + ) + + return gt_info, pred_info, valid_masks, ambiguous_masks + + def compute_loss(self, batch, preds, **kw): + gt_info, pred_info, valid_masks, ambiguous_masks = self.get_all_info( + batch, preds, **kw + ) + n_views = len(batch) + + # Mask out samples in the batch where the gt depth validity mask is entirely zero + valid_norm_factor_masks = [ + mask.sum(dim=(1, 2)) > 0 for mask in valid_masks + ] # List of (B,) + + if self.ambiguous_loss_value > 0: + assert self.criterion.reduction == "none", ( + "ambiguous_loss_value should be 0 if no conf loss" + ) + # Add the ambiguous pixel as "valid" pixels... 
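+            # Note: valid_norm_factor_masks above was computed from the original validity
+            # masks, so the pose translation terms below still only use samples that have
+            # at least one valid depth pixel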
+ valid_masks = [ + mask | ambig_mask + for mask, ambig_mask in zip(valid_masks, ambiguous_masks) + ] + + pose_trans_losses = [] + pose_quats_losses = [] + ray_directions_losses = [] + depth_losses = [] + cam_pts3d_losses = [] + if self.compute_world_frame_points_loss: + pts3d_losses = [] + + for i in range(n_views): + # Get the predicted dense quantities + if not self.flatten_across_image_only: + # Flatten the points across the entire batch with the masks + pred_ray_directions = pred_info[i]["ray_directions"] + gt_ray_directions = gt_info[i]["ray_directions"] + pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]] + gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]] + pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]] + gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]] + if self.compute_world_frame_points_loss: + pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] + gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W + batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape + gt_ray_directions = gt_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + pred_ray_directions = pred_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1] + gt_depth = gt_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + pred_depth = pred_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1] + gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim) + pred_cam_pts3d = pred_info[i]["pts3d_cam"].view( + batch_size, -1, cam_pts_dim + ) + if self.compute_world_frame_points_loss: + pts_dim = gt_info[i]["pts3d"].shape[-1] + gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) + pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space for depth if specified + if self.loss_in_log: + gt_depth = apply_log_to_norm(gt_depth) + pred_depth = apply_log_to_norm(pred_depth) + gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d) + pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d) + if self.compute_world_frame_points_loss: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_pts3d = apply_log_to_norm(pred_pts3d) + + if self.compute_pairwise_relative_pose_loss: + # Get the inverse of current view predicted pose + pred_inv_curr_view_pose_quats = quaternion_inverse( + pred_info[i]["pose_quats"] + ) + pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + pred_inv_curr_view_pose_quats + ) + pred_inv_curr_view_pose_trans = -1 * ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the inverse of the current view GT pose + gt_inv_curr_view_pose_quats = quaternion_inverse( + gt_info[i]["pose_quats"] + ) + gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + gt_inv_curr_view_pose_quats + ) + gt_inv_curr_view_pose_trans = -1 * ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the other N-1 relative poses using the current pose as reference frame + pred_rel_pose_quats = [] + pred_rel_pose_trans = [] + gt_rel_pose_quats = [] + gt_rel_pose_trans = [] + for ov_idx in range(n_views): + if ov_idx == i: + continue + # Get the relative predicted pose + pred_ov_rel_pose_quats = quaternion_multiply( + pred_inv_curr_view_pose_quats, 
pred_info[ov_idx]["pose_quats"] + ) + pred_ov_rel_pose_trans = ( + ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + pred_inv_curr_view_pose_trans + ) + + # Get the relative GT pose + gt_ov_rel_pose_quats = quaternion_multiply( + gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] + ) + gt_ov_rel_pose_trans = ( + ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + gt_inv_curr_view_pose_trans + ) + + # Get the valid translations using valid_norm_factor_masks for current view and other view + overall_valid_mask_for_trans = ( + valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] + ) + + # Append the relative poses + pred_rel_pose_quats.append(pred_ov_rel_pose_quats) + pred_rel_pose_trans.append( + pred_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + gt_rel_pose_quats.append(gt_ov_rel_pose_quats) + gt_rel_pose_trans.append( + gt_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + + # Cat the N-1 relative poses along the batch dimension + pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) + pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) + gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) + gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion( + pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" + ), + self.criterion( + pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + else: + # Get the pose info for the current view + pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] + gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] + pred_pose_quats = pred_info[i]["pose_quats"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_pose_trans, gt_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), + self.criterion( + pred_pose_quats, -gt_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions, gt_ray_directions, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute camera frame point loss + cam_pts3d_loss = self.criterion( + pred_cam_pts3d, gt_cam_pts3d, factor="points" + ) + cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight + 
cam_pts3d_losses.append(cam_pts3d_loss) + + if self.compute_world_frame_points_loss: + # Compute point loss + pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") + pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight + pts3d_losses.append(pts3d_loss) + + # Handle ambiguous pixels + if self.ambiguous_loss_value > 0: + if not self.flatten_across_image_only: + depth_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + pts3d_losses[i], + ) + else: + depth_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + pts3d_losses[i], + ) + + # Use helper function to generate loss terms and details + if self.compute_world_frame_points_loss: + losses_dict = { + "pts3d": { + "values": pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + } + else: + losses_dict = {} + losses_dict.update( + { + "cam_pts3d": { + "values": cam_pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + self.depth_type_for_loss: { + "values": depth_losses, + "use_mask": True, + "is_multi_view": True, + }, + "ray_directions": { + "values": ray_directions_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_quats": { + "values": pose_quats_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_trans": { + "values": pose_trans_losses, + "use_mask": False, + "is_multi_view": True, + }, + } + ) + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) + + +class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D): + """ + Regression, Normals & Gradient Matching Loss for Factored Geometry. + """ + + def __init__( + self, + criterion, + norm_mode="?avg_dis", + gt_scale=False, + ambiguous_loss_value=0, + max_metric_scale=False, + loss_in_log=True, + flatten_across_image_only=False, + depth_type_for_loss="depth_along_ray", + cam_frame_points_loss_weight=1, + depth_loss_weight=1, + ray_directions_loss_weight=1, + pose_quats_loss_weight=1, + pose_trans_loss_weight=1, + compute_pairwise_relative_pose_loss=False, + convert_predictions_to_view0_frame=False, + compute_world_frame_points_loss=True, + world_frame_points_loss_weight=1, + apply_normal_and_gm_loss_to_synthetic_data_only=True, + normal_loss_weight=1, + gm_loss_weight=1, + ): + """ + Initialize the loss criterion for Factored Geometry (see parent class for details). + Additionally computes: + (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates, + (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space) + + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. + norm_mode (str): Normalization mode for scene representation. 
Default: "avg_dis". + If prefixed with "?", normalization is only applied to non-metric scale data. + gt_scale (bool): If True, enforce predictions to have the same scale as ground truth. + If False, both GT and predictions are normalized independently. Default: False. + ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. + If 0, ambiguous pixels are ignored. Default: 0. + max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this + value, it will be treated as non-metric. Default: False (no limit). + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth and pointmaps. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". + Options: "depth_along_ray", "depth_z" + cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1. + depth_loss_weight (float): Weight to use for the depth loss. Default: 1. + ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. + pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. + pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the + exhaustive pairwise relative poses. Default: False. + convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. + Use this if the predictions are not already in the view0 frame. Default: False. + compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True. + world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. + apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. + If False, apply the normal and gm loss to all data. Default: True. + normal_loss_weight (float): Weight to use for the normal loss. Default: 1. + gm_loss_weight (float): Weight to use for the gm loss. Default: 1. 
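+
+        Compared to FactoredGeometryRegr3D, two extra per-view loss terms are added: a normal
+        loss on the camera-frame pointmaps and a gradient-matching loss on log-space depth Z,
+        weighted by normal_loss_weight and gm_loss_weight and restricted to synthetic samples
+        when apply_normal_and_gm_loss_to_synthetic_data_only is True.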
+ """ + super().__init__( + criterion=criterion, + norm_mode=norm_mode, + gt_scale=gt_scale, + ambiguous_loss_value=ambiguous_loss_value, + max_metric_scale=max_metric_scale, + loss_in_log=loss_in_log, + flatten_across_image_only=flatten_across_image_only, + depth_type_for_loss=depth_type_for_loss, + cam_frame_points_loss_weight=cam_frame_points_loss_weight, + depth_loss_weight=depth_loss_weight, + ray_directions_loss_weight=ray_directions_loss_weight, + pose_quats_loss_weight=pose_quats_loss_weight, + pose_trans_loss_weight=pose_trans_loss_weight, + compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss, + convert_predictions_to_view0_frame=convert_predictions_to_view0_frame, + compute_world_frame_points_loss=compute_world_frame_points_loss, + world_frame_points_loss_weight=world_frame_points_loss_weight, + ) + self.apply_normal_and_gm_loss_to_synthetic_data_only = ( + apply_normal_and_gm_loss_to_synthetic_data_only + ) + self.normal_loss_weight = normal_loss_weight + self.gm_loss_weight = gm_loss_weight + + def compute_loss(self, batch, preds, **kw): + gt_info, pred_info, valid_masks, ambiguous_masks = self.get_all_info( + batch, preds, **kw + ) + n_views = len(batch) + + # Mask out samples in the batch where the gt depth validity mask is entirely zero + valid_norm_factor_masks = [ + mask.sum(dim=(1, 2)) > 0 for mask in valid_masks + ] # List of (B,) + + if self.ambiguous_loss_value > 0: + assert self.criterion.reduction == "none", ( + "ambiguous_loss_value should be 0 if no conf loss" + ) + # Add the ambiguous pixel as "valid" pixels... + valid_masks = [ + mask | ambig_mask + for mask, ambig_mask in zip(valid_masks, ambiguous_masks) + ] + + normal_losses = [] + gradient_matching_losses = [] + pose_trans_losses = [] + pose_quats_losses = [] + ray_directions_losses = [] + depth_losses = [] + cam_pts3d_losses = [] + if self.compute_world_frame_points_loss: + pts3d_losses = [] + + for i in range(n_views): + # Get the camera frame points, log space depth_z & valid masks + pred_local_pts3d = pred_info[i]["pts3d_cam"] + pred_depth_z = pred_local_pts3d[..., 2:] + pred_depth_z = apply_log_to_norm(pred_depth_z) + gt_local_pts3d = gt_info[i]["pts3d_cam"] + gt_depth_z = gt_local_pts3d[..., 2:] + gt_depth_z = apply_log_to_norm(gt_depth_z) + valid_mask_for_normal_gm_loss = valid_masks[i].clone() + + # Update the validity mask for normal & gm loss based on the synthetic data mask if required + if self.apply_normal_and_gm_loss_to_synthetic_data_only: + synthetic_mask = batch[i]["is_synthetic"] # (B, ) + synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1) + synthetic_mask = synthetic_mask.expand( + -1, pred_depth_z.shape[1], pred_depth_z.shape[2] + ) # (B, H, W) + valid_mask_for_normal_gm_loss = ( + valid_mask_for_normal_gm_loss & synthetic_mask + ) + + # Compute the normal loss + normal_loss = compute_normal_loss( + pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone() + ) + normal_loss = normal_loss * self.normal_loss_weight + normal_losses.append(normal_loss) + + # Compute the gradient matching loss + gradient_matching_loss = compute_gradient_matching_loss( + pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone() + ) + gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight + gradient_matching_losses.append(gradient_matching_loss) + + # Get the predicted dense quantities + if not self.flatten_across_image_only: + # Flatten the points across the entire batch with the masks + pred_ray_directions = 
pred_info[i]["ray_directions"] + gt_ray_directions = gt_info[i]["ray_directions"] + pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]] + gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]] + pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]] + gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]] + if self.compute_world_frame_points_loss: + pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] + gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W + batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape + gt_ray_directions = gt_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + pred_ray_directions = pred_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1] + gt_depth = gt_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + pred_depth = pred_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1] + gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim) + pred_cam_pts3d = pred_info[i]["pts3d_cam"].view( + batch_size, -1, cam_pts_dim + ) + if self.compute_world_frame_points_loss: + pts_dim = gt_info[i]["pts3d"].shape[-1] + gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) + pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space for depth if specified + if self.loss_in_log: + gt_depth = apply_log_to_norm(gt_depth) + pred_depth = apply_log_to_norm(pred_depth) + gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d) + pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d) + if self.compute_world_frame_points_loss: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_pts3d = apply_log_to_norm(pred_pts3d) + + if self.compute_pairwise_relative_pose_loss: + # Get the inverse of current view predicted pose + pred_inv_curr_view_pose_quats = quaternion_inverse( + pred_info[i]["pose_quats"] + ) + pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + pred_inv_curr_view_pose_quats + ) + pred_inv_curr_view_pose_trans = -1 * ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the inverse of the current view GT pose + gt_inv_curr_view_pose_quats = quaternion_inverse( + gt_info[i]["pose_quats"] + ) + gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + gt_inv_curr_view_pose_quats + ) + gt_inv_curr_view_pose_trans = -1 * ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the other N-1 relative poses using the current pose as reference frame + pred_rel_pose_quats = [] + pred_rel_pose_trans = [] + gt_rel_pose_quats = [] + gt_rel_pose_trans = [] + for ov_idx in range(n_views): + if ov_idx == i: + continue + # Get the relative predicted pose + pred_ov_rel_pose_quats = quaternion_multiply( + pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"] + ) + pred_ov_rel_pose_trans = ( + ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + pred_inv_curr_view_pose_trans + ) + + # Get the relative GT pose + gt_ov_rel_pose_quats = quaternion_multiply( + gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] + ) + gt_ov_rel_pose_trans = ( + ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", 
+ ) + + gt_inv_curr_view_pose_trans + ) + + # Get the valid translations using valid_norm_factor_masks for current view and other view + overall_valid_mask_for_trans = ( + valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] + ) + + # Append the relative poses + pred_rel_pose_quats.append(pred_ov_rel_pose_quats) + pred_rel_pose_trans.append( + pred_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + gt_rel_pose_quats.append(gt_ov_rel_pose_quats) + gt_rel_pose_trans.append( + gt_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + + # Cat the N-1 relative poses along the batch dimension + pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) + pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) + gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) + gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion( + pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" + ), + self.criterion( + pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + else: + # Get the pose info for the current view + pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] + gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] + pred_pose_quats = pred_info[i]["pose_quats"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_pose_trans, gt_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), + self.criterion( + pred_pose_quats, -gt_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions, gt_ray_directions, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute camera frame point loss + cam_pts3d_loss = self.criterion( + pred_cam_pts3d, gt_cam_pts3d, factor="points" + ) + cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight + cam_pts3d_losses.append(cam_pts3d_loss) + + if self.compute_world_frame_points_loss: + # Compute point loss + pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") + pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight + pts3d_losses.append(pts3d_loss) + + # Handle ambiguous pixels + if self.ambiguous_loss_value > 0: + if not self.flatten_across_image_only: + depth_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + depth_losses[i], + ) + 
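# NOTE (illustrative aside): quaternions double-cover rotations, i.e. q and -q encode the
# same orientation, which is why the rotation loss above takes the element-wise minimum of
# the criterion evaluated against gt_pose_quats and against -gt_pose_quats. A minimal
# standalone sketch of the same idea with a plain L2 criterion (the loss above uses
# self.criterion(..., factor="pose_quats") instead); the function name is illustrative:

import torch

def sign_invariant_quat_loss_sketch(pred_q, gt_q):
    """Quaternion regression loss invariant to the sign of the target, (B, 4) -> (B,)."""
    err_pos = (pred_q - gt_q).norm(dim=-1)  # distance to q
    err_neg = (pred_q + gt_q).norm(dim=-1)  # distance to -q
    return torch.minimum(err_pos, err_neg)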
cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + pts3d_losses[i], + ) + else: + depth_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + pts3d_losses[i], + ) + + # Use helper function to generate loss terms and details + if self.compute_world_frame_points_loss: + losses_dict = { + "pts3d": { + "values": pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + } + else: + losses_dict = {} + losses_dict.update( + { + "cam_pts3d": { + "values": cam_pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + self.depth_type_for_loss: { + "values": depth_losses, + "use_mask": True, + "is_multi_view": True, + }, + "ray_directions": { + "values": ray_directions_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_quats": { + "values": pose_quats_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_trans": { + "values": pose_trans_losses, + "use_mask": False, + "is_multi_view": True, + }, + "normal": { + "values": normal_losses, + "use_mask": False, + "is_multi_view": True, + }, + "gradient_matching": { + "values": gradient_matching_losses, + "use_mask": False, + "is_multi_view": True, + }, + } + ) + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) + + +class FactoredGeometryScaleRegr3D(Criterion, MultiLoss): + """ + Regression Loss for Factored Geometry & Scale. + """ + + def __init__( + self, + criterion, + norm_predictions=True, + norm_mode="avg_dis", + ambiguous_loss_value=0, + loss_in_log=True, + flatten_across_image_only=False, + depth_type_for_loss="depth_along_ray", + cam_frame_points_loss_weight=1, + depth_loss_weight=1, + ray_directions_loss_weight=1, + pose_quats_loss_weight=1, + pose_trans_loss_weight=1, + scale_loss_weight=1, + compute_pairwise_relative_pose_loss=False, + convert_predictions_to_view0_frame=False, + compute_world_frame_points_loss=True, + world_frame_points_loss_weight=1, + ): + """ + Initialize the loss criterion for Factored Geometry (Ray Directions, Depth, Pose), Scale + and the Collective Geometry i.e. Local Frame Pointmaps & optionally World Frame Pointmaps. + If world-frame pointmap loss is computed, the pixel-level losses are computed in the following order: + (1) world points, (2) cam points, (3) depth, (4) ray directions, (5) pose quats, (6) pose trans, (7) scale. + Else, the pixel-level losses are returned in the following order: + (1) cam points, (2) depth, (3) ray directions, (4) pose quats, (5) pose trans, (6) scale. + The predicited scene representation is always normalized w.r.t. the frame of view0. + Loss is applied between the predicted metric scale and the ground truth metric scale. + + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. 
+ norm_predictions (bool): If True, normalize the predictions before computing the loss. + norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". + ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. + If 0, ambiguous pixels are ignored. Default: 0. + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth, pointmaps and scale. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". + Options: "depth_along_ray", "depth_z" + cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1. + depth_loss_weight (float): Weight to use for the depth loss. Default: 1. + ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. + pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. + pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + scale_loss_weight (float): Weight to use for the scale loss. Default: 1. + compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the + exhaustive pairwise relative poses. Default: False. + convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. + Use this if the predictions are not already in the view0 frame. Default: False. + compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True. + world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. + """ + super().__init__(criterion) + self.norm_predictions = norm_predictions + self.norm_mode = norm_mode + self.ambiguous_loss_value = ambiguous_loss_value + self.loss_in_log = loss_in_log + self.flatten_across_image_only = flatten_across_image_only + self.depth_type_for_loss = depth_type_for_loss + assert self.depth_type_for_loss in [ + "depth_along_ray", + "depth_z", + ], "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" + self.cam_frame_points_loss_weight = cam_frame_points_loss_weight + self.depth_loss_weight = depth_loss_weight + self.ray_directions_loss_weight = ray_directions_loss_weight + self.pose_quats_loss_weight = pose_quats_loss_weight + self.pose_trans_loss_weight = pose_trans_loss_weight + self.scale_loss_weight = scale_loss_weight + self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss + self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame + self.compute_world_frame_points_loss = compute_world_frame_points_loss + self.world_frame_points_loss_weight = world_frame_points_loss_weight + + def get_all_info(self, batch, preds, dist_clip=None): + """ + Function to get all the information needed to compute the loss. + Returns all quantities normalized w.r.t. camera of view0. + """ + n_views = len(batch) + + # Everything is normalized w.r.t. 
camera of view0 + # Intialize lists to store data for all views + # Ground truth quantities + in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) + no_norm_gt_pts = [] + no_norm_gt_pts_cam = [] + no_norm_gt_depth = [] + no_norm_gt_pose_trans = [] + valid_masks = [] + gt_ray_directions = [] + gt_pose_quats = [] + # Predicted quantities + if self.convert_predictions_to_view0_frame: + # Get the camera transform to convert quantities to view0 frame + pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze( + 0 + ) + batch_size = preds[0]["cam_quats"].shape[0] + pred_camera0 = pred_camera0.repeat(batch_size, 1, 1) + pred_camera0_rot = quaternion_to_rotation_matrix( + preds[0]["cam_quats"].clone() + ) + pred_camera0[..., :3, :3] = pred_camera0_rot + pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone() + pred_in_camera0 = closed_form_pose_inverse(pred_camera0) + no_norm_pr_pts = [] + no_norm_pr_pts_cam = [] + no_norm_pr_depth = [] + no_norm_pr_pose_trans = [] + pr_ray_directions = [] + pr_pose_quats = [] + metric_pr_pts_to_compute_scale = [] + + # Get ground truth & prediction info for all views + for i in range(n_views): + # Get the ground truth + no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"])) + valid_masks.append(batch[i]["valid_mask"].clone()) + no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"]) + gt_ray_directions.append(batch[i]["ray_directions_cam"]) + if self.depth_type_for_loss == "depth_along_ray": + no_norm_gt_depth.append(batch[i]["depth_along_ray"]) + elif self.depth_type_for_loss == "depth_z": + no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:]) + if i == 0: + # For view0, initialize identity pose + gt_pose_quats.append( + torch.tensor( + [0, 0, 0, 1], + dtype=gt_ray_directions[0].dtype, + device=gt_ray_directions[0].device, + ) + .unsqueeze(0) + .repeat(gt_ray_directions[0].shape[0], 1) + ) + no_norm_gt_pose_trans.append( + torch.tensor( + [0, 0, 0], + dtype=gt_ray_directions[0].dtype, + device=gt_ray_directions[0].device, + ) + .unsqueeze(0) + .repeat(gt_ray_directions[0].shape[0], 1) + ) + else: + # For other views, transform pose to view0's frame + gt_pose_quats_world = batch[i]["camera_pose_quats"] + no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"] + gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = ( + transform_pose_using_quats_and_trans_2_to_1( + batch[0]["camera_pose_quats"], + batch[0]["camera_pose_trans"], + gt_pose_quats_world, + no_norm_gt_pose_trans_world, + ) + ) + gt_pose_quats.append(gt_pose_quats_in_view0) + no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0) + + # Get the global predictions in view0's frame + if self.convert_predictions_to_view0_frame: + # Convert predictions to view0 frame + pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"]) + pr_pose_quats_in_view0, pr_pose_trans_in_view0 = ( + transform_pose_using_quats_and_trans_2_to_1( + preds[0]["cam_quats"], + preds[0]["cam_trans"], + preds[i]["cam_quats"], + preds[i]["cam_trans"], + ) + ) + else: + # Predictions are already in view0 frame + pr_pts3d_in_view0 = preds[i]["pts3d"] + pr_pose_trans_in_view0 = preds[i]["cam_trans"] + pr_pose_quats_in_view0 = preds[i]["cam_quats"] + + # Get predictions for normalized loss + if self.depth_type_for_loss == "depth_along_ray": + curr_view_no_norm_depth = preds[i]["depth_along_ray"] + elif self.depth_type_for_loss == "depth_z": + curr_view_no_norm_depth = preds[i]["pts3d_cam"][..., 2:] + if "metric_scaling_factor" in preds[i].keys(): + # Divide by the predicted metric scaling 
factor to get the raw predicted points, depth_along_ray, and pose_trans + # This detaches the predicted metric scaling factor from the geometry based loss + curr_view_no_norm_pr_pts = pr_pts3d_in_view0 / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + curr_view_no_norm_depth = curr_view_no_norm_depth / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + curr_view_no_norm_pr_pose_trans = ( + pr_pose_trans_in_view0 / preds[i]["metric_scaling_factor"] + ) + else: + curr_view_no_norm_pr_pts = pr_pts3d_in_view0 + curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] + curr_view_no_norm_depth = curr_view_no_norm_depth + curr_view_no_norm_pr_pose_trans = pr_pose_trans_in_view0 + no_norm_pr_pts.append(curr_view_no_norm_pr_pts) + no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam) + no_norm_pr_depth.append(curr_view_no_norm_depth) + no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans) + pr_ray_directions.append(preds[i]["ray_directions"]) + pr_pose_quats.append(pr_pose_quats_in_view0) + + # Get the predicted metric scale points + if "metric_scaling_factor" in preds[i].keys(): + # Detach the raw predicted points so that the scale loss is only applied to the scaling factor + curr_view_metric_pr_pts_to_compute_scale = ( + curr_view_no_norm_pr_pts.detach() + * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1) + ) + else: + curr_view_metric_pr_pts_to_compute_scale = ( + curr_view_no_norm_pr_pts.clone() + ) + metric_pr_pts_to_compute_scale.append( + curr_view_metric_pr_pts_to_compute_scale + ) + + if dist_clip is not None: + # Points that are too far-away == invalid + for i in range(n_views): + dis = no_norm_gt_pts[i].norm(dim=-1) + valid_masks[i] = valid_masks[i] & (dis <= dist_clip) + + # Initialize normalized tensors + gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] + gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam] + gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth] + gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans] + + pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] + pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam] + pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth] + pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans] + + # Normalize the predicted points if specified + if self.norm_predictions: + pr_normalization_output = normalize_multiple_pointclouds( + no_norm_pr_pts, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_pts_norm = pr_normalization_output[:-1] + pr_norm_factor = pr_normalization_output[-1] + + # Normalize the ground truth points + gt_normalization_output = normalize_multiple_pointclouds( + no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True + ) + gt_pts_norm = gt_normalization_output[:-1] + gt_norm_factor = gt_normalization_output[-1] + + for i in range(n_views): + if self.norm_predictions: + # Assign the normalized predictions + pr_pts[i] = pr_pts_norm[i] + pr_pts_cam[i] = no_norm_pr_pts_cam[i] / pr_norm_factor + pr_depth[i] = no_norm_pr_depth[i] / pr_norm_factor + pr_pose_trans[i] = no_norm_pr_pose_trans[i] / pr_norm_factor[:, :, 0, 0] + else: + pr_pts[i] = no_norm_pr_pts[i] + pr_pts_cam[i] = no_norm_pr_pts_cam[i] + pr_depth[i] = no_norm_pr_depth[i] + pr_pose_trans[i] = no_norm_pr_pose_trans[i] + # Assign the 
normalized ground truth quantities + gt_pts[i] = gt_pts_norm[i] + gt_pts_cam[i] = no_norm_gt_pts_cam[i] / gt_norm_factor + gt_depth[i] = no_norm_gt_depth[i] / gt_norm_factor + gt_pose_trans[i] = no_norm_gt_pose_trans[i] / gt_norm_factor[:, :, 0, 0] + + # Get the mask indicating ground truth metric scale quantities + metric_scale_mask = batch[0]["is_metric_scale"] + valid_gt_norm_factor_mask = ( + gt_norm_factor[:, 0, 0, 0] > 1e-8 + ) # Mask out cases where depth for all views is invalid + valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask + + if valid_metric_scale_mask.any(): + # Compute the scale norm factor using the predicted metric scale points + metric_pr_normalization_output = normalize_multiple_pointclouds( + metric_pr_pts_to_compute_scale, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_metric_norm_factor = metric_pr_normalization_output[-1] + + # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities + gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask] + pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask] + else: + gt_metric_norm_factor = None + pr_metric_norm_factor = None + + # Get ambiguous masks + ambiguous_masks = [] + for i in range(n_views): + ambiguous_masks.append( + (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i]) + ) + + # Pack into info dicts + gt_info = [] + pred_info = [] + for i in range(n_views): + gt_info.append( + { + "ray_directions": gt_ray_directions[i], + self.depth_type_for_loss: gt_depth[i], + "pose_trans": gt_pose_trans[i], + "pose_quats": gt_pose_quats[i], + "pts3d": gt_pts[i], + "pts3d_cam": gt_pts_cam[i], + } + ) + pred_info.append( + { + "ray_directions": pr_ray_directions[i], + self.depth_type_for_loss: pr_depth[i], + "pose_trans": pr_pose_trans[i], + "pose_quats": pr_pose_quats[i], + "pts3d": pr_pts[i], + "pts3d_cam": pr_pts_cam[i], + } + ) + + return ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) + + def compute_loss(self, batch, preds, **kw): + ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) = self.get_all_info(batch, preds, **kw) + n_views = len(batch) + + # Mask out samples in the batch where the gt depth validity mask is entirely zero + valid_norm_factor_masks = [ + mask.sum(dim=(1, 2)) > 0 for mask in valid_masks + ] # List of (B,) + + if self.ambiguous_loss_value > 0: + assert self.criterion.reduction == "none", ( + "ambiguous_loss_value should be 0 if no conf loss" + ) + # Add the ambiguous pixel as "valid" pixels... 
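# NOTE (illustrative aside): normalize_multiple_pointclouds comes from the shared utils and
# its exact behaviour is defined there; under the "avg_dis" mode it is assumed to rescale
# the scene by one factor per batch element, roughly the mean distance of all valid points
# (pooled across views) from the view0 origin. The sketch below illustrates that assumption
# with a simplified (B,)-shaped factor; the real helper appears to return the factor with
# trailing singleton dims, hence the [:, :, 0, 0] / [:, 0, 0, 0] indexing used above.

import torch

def avg_dis_normalize_sketch(pts_list, mask_list, eps=1e-8):
    """pts_list: per-view (B, H, W, 3) points; mask_list: per-view (B, H, W) bool masks."""
    dist_sum, count = 0.0, 0.0
    for pts, mask in zip(pts_list, mask_list):
        dist_sum = dist_sum + (pts.norm(dim=-1) * mask).sum(dim=(1, 2))
        count = count + mask.sum(dim=(1, 2))
    factor = dist_sum / count.clamp(min=1)  # (B,); stays zero if no view has a valid pixel
    safe_factor = factor.clamp(min=eps).view(-1, 1, 1, 1)
    return [pts / safe_factor for pts in pts_list], factor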
+ valid_masks = [ + mask | ambig_mask + for mask, ambig_mask in zip(valid_masks, ambiguous_masks) + ] + + pose_trans_losses = [] + pose_quats_losses = [] + ray_directions_losses = [] + depth_losses = [] + cam_pts3d_losses = [] + if self.compute_world_frame_points_loss: + pts3d_losses = [] + + for i in range(n_views): + # Get the predicted dense quantities + if not self.flatten_across_image_only: + # Flatten the points across the entire batch with the masks + pred_ray_directions = pred_info[i]["ray_directions"] + gt_ray_directions = gt_info[i]["ray_directions"] + pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]] + gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]] + pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]] + gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]] + if self.compute_world_frame_points_loss: + pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] + gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W + batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape + gt_ray_directions = gt_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + pred_ray_directions = pred_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1] + gt_depth = gt_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + pred_depth = pred_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1] + gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim) + pred_cam_pts3d = pred_info[i]["pts3d_cam"].view( + batch_size, -1, cam_pts_dim + ) + if self.compute_world_frame_points_loss: + pts_dim = gt_info[i]["pts3d"].shape[-1] + gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) + pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space for depth if specified + if self.loss_in_log: + gt_depth = apply_log_to_norm(gt_depth) + pred_depth = apply_log_to_norm(pred_depth) + gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d) + pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d) + if self.compute_world_frame_points_loss: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_pts3d = apply_log_to_norm(pred_pts3d) + + if self.compute_pairwise_relative_pose_loss: + # Get the inverse of current view predicted pose + pred_inv_curr_view_pose_quats = quaternion_inverse( + pred_info[i]["pose_quats"] + ) + pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + pred_inv_curr_view_pose_quats + ) + pred_inv_curr_view_pose_trans = -1 * ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the inverse of the current view GT pose + gt_inv_curr_view_pose_quats = quaternion_inverse( + gt_info[i]["pose_quats"] + ) + gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + gt_inv_curr_view_pose_quats + ) + gt_inv_curr_view_pose_trans = -1 * ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the other N-1 relative poses using the current pose as reference frame + pred_rel_pose_quats = [] + pred_rel_pose_trans = [] + gt_rel_pose_quats = [] + gt_rel_pose_trans = [] + for ov_idx in range(n_views): + if ov_idx == i: + continue + # Get the relative predicted pose + pred_ov_rel_pose_quats = quaternion_multiply( + pred_inv_curr_view_pose_quats, 
pred_info[ov_idx]["pose_quats"] + ) + pred_ov_rel_pose_trans = ( + ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + pred_inv_curr_view_pose_trans + ) + + # Get the relative GT pose + gt_ov_rel_pose_quats = quaternion_multiply( + gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] + ) + gt_ov_rel_pose_trans = ( + ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + gt_inv_curr_view_pose_trans + ) + + # Get the valid translations using valid_norm_factor_masks for current view and other view + overall_valid_mask_for_trans = ( + valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] + ) + + # Append the relative poses + pred_rel_pose_quats.append(pred_ov_rel_pose_quats) + pred_rel_pose_trans.append( + pred_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + gt_rel_pose_quats.append(gt_ov_rel_pose_quats) + gt_rel_pose_trans.append( + gt_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + + # Cat the N-1 relative poses along the batch dimension + pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) + pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) + gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) + gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion( + pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" + ), + self.criterion( + pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + else: + # Get the pose info for the current view + pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] + gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] + pred_pose_quats = pred_info[i]["pose_quats"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_pose_trans, gt_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), + self.criterion( + pred_pose_quats, -gt_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions, gt_ray_directions, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute camera frame point loss + cam_pts3d_loss = self.criterion( + pred_cam_pts3d, gt_cam_pts3d, factor="points" + ) + cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight + 
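# NOTE (illustrative aside): the pairwise branch above expresses every other view j in the
# frame of the current view i, i.e. it composes T_i^{-1} @ T_j, but stays in quaternion +
# translation form (quaternion_inverse / quaternion_multiply) so the result feeds the same
# pose_quats / pose_trans criteria. The equivalent rotation-matrix sketch (hypothetical
# standalone helper, batched (B, 3, 3) rotations and (B, 3) translations):

import torch

def relative_pose_sketch(R_i, t_i, R_j, t_j):
    """Return the pose of camera j expressed in the frame of camera i."""
    R_i_inv = R_i.transpose(-1, -2)                      # inverse of a rotation matrix
    R_rel = R_i_inv @ R_j
    t_rel = torch.einsum("bij,bj->bi", R_i_inv, t_j - t_i)
    return R_rel, t_rel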
cam_pts3d_losses.append(cam_pts3d_loss) + + if self.compute_world_frame_points_loss: + # Compute point loss + pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") + pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight + pts3d_losses.append(pts3d_loss) + + # Handle ambiguous pixels + if self.ambiguous_loss_value > 0: + if not self.flatten_across_image_only: + depth_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + pts3d_losses[i], + ) + else: + depth_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + pts3d_losses[i], + ) + + # Compute the scale loss + if gt_metric_norm_factor is not None: + if self.loss_in_log: + gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) + pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) + scale_loss = ( + self.criterion( + pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" + ) + * self.scale_loss_weight + ) + else: + scale_loss = None + + # Use helper function to generate loss terms and details + if self.compute_world_frame_points_loss: + losses_dict = { + "pts3d": { + "values": pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + } + else: + losses_dict = {} + losses_dict.update( + { + "cam_pts3d": { + "values": cam_pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + self.depth_type_for_loss: { + "values": depth_losses, + "use_mask": True, + "is_multi_view": True, + }, + "ray_directions": { + "values": ray_directions_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_quats": { + "values": pose_quats_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_trans": { + "values": pose_trans_losses, + "use_mask": False, + "is_multi_view": True, + }, + "scale": { + "values": scale_loss, + "use_mask": False, + "is_multi_view": False, + }, + } + ) + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) + + +class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D): + """ + Regression, Normals & Gradient Matching Loss for Factored Geometry & Scale. 
+ """ + + def __init__( + self, + criterion, + norm_predictions=True, + norm_mode="avg_dis", + ambiguous_loss_value=0, + loss_in_log=True, + flatten_across_image_only=False, + depth_type_for_loss="depth_along_ray", + cam_frame_points_loss_weight=1, + depth_loss_weight=1, + ray_directions_loss_weight=1, + pose_quats_loss_weight=1, + pose_trans_loss_weight=1, + scale_loss_weight=1, + compute_pairwise_relative_pose_loss=False, + convert_predictions_to_view0_frame=False, + compute_world_frame_points_loss=True, + world_frame_points_loss_weight=1, + apply_normal_and_gm_loss_to_synthetic_data_only=True, + normal_loss_weight=1, + gm_loss_weight=1, + ): + """ + Initialize the loss criterion for Ray Directions, Depth, Pose, Pointmaps & Scale. + Additionally computes: + (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates, + (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space) + + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. + norm_predictions (bool): If True, normalize the predictions before computing the loss. + norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". + ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. + If 0, ambiguous pixels are ignored. Default: 0. + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth, pointmaps and scale. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". + Options: "depth_along_ray", "depth_z" + cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1. + depth_loss_weight (float): Weight to use for the depth loss. Default: 1. + ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. + pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. + pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + scale_loss_weight (float): Weight to use for the scale loss. Default: 1. + compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the + exhaustive pairwise relative poses. Default: False. + convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. + Use this if the predictions are not already in the view0 frame. Default: False. + compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True. + world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. + apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. + If False, apply the normal and gm loss to all data. Default: True. + normal_loss_weight (float): Weight to use for the normal loss. Default: 1. + gm_loss_weight (float): Weight to use for the gm loss. Default: 1. 
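            Example:
                The normal term compares surface normals derived from the predicted and
                ground-truth camera-frame pointmaps. ``compute_normal_loss`` is provided
                by the shared loss utilities and may differ in detail; the snippet below
                is only a minimal sketch of the idea (normals from finite differences of
                the pointmap, penalized with one minus cosine similarity), using an
                illustrative name:

                    import torch
                    import torch.nn.functional as F

                    def normal_loss_sketch(pred_pts, gt_pts, valid_mask):
                        # pred_pts, gt_pts: (B, H, W, 3); valid_mask: (B, H, W) bool
                        def normals(pts):
                            dx = pts[:, :-1, 1:] - pts[:, :-1, :-1]
                            dy = pts[:, 1:, :-1] - pts[:, :-1, :-1]
                            return F.normalize(torch.cross(dx, dy, dim=-1), dim=-1)

                        cos = (normals(pred_pts) * normals(gt_pts)).sum(dim=-1)
                        # Require the three pixels used by the finite differences to be valid
                        mask = (
                            valid_mask[:, :-1, :-1]
                            & valid_mask[:, :-1, 1:]
                            & valid_mask[:, 1:, :-1]
                        )
                        return (1.0 - cos)[mask].mean()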
+ """ + super().__init__( + criterion=criterion, + norm_predictions=norm_predictions, + norm_mode=norm_mode, + ambiguous_loss_value=ambiguous_loss_value, + loss_in_log=loss_in_log, + flatten_across_image_only=flatten_across_image_only, + depth_type_for_loss=depth_type_for_loss, + cam_frame_points_loss_weight=cam_frame_points_loss_weight, + depth_loss_weight=depth_loss_weight, + ray_directions_loss_weight=ray_directions_loss_weight, + pose_quats_loss_weight=pose_quats_loss_weight, + pose_trans_loss_weight=pose_trans_loss_weight, + scale_loss_weight=scale_loss_weight, + compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss, + convert_predictions_to_view0_frame=convert_predictions_to_view0_frame, + compute_world_frame_points_loss=compute_world_frame_points_loss, + world_frame_points_loss_weight=world_frame_points_loss_weight, + ) + self.apply_normal_and_gm_loss_to_synthetic_data_only = ( + apply_normal_and_gm_loss_to_synthetic_data_only + ) + self.normal_loss_weight = normal_loss_weight + self.gm_loss_weight = gm_loss_weight + + def compute_loss(self, batch, preds, **kw): + ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) = self.get_all_info(batch, preds, **kw) + n_views = len(batch) + + # Mask out samples in the batch where the gt depth validity mask is entirely zero + valid_norm_factor_masks = [ + mask.sum(dim=(1, 2)) > 0 for mask in valid_masks + ] # List of (B,) + + if self.ambiguous_loss_value > 0: + assert self.criterion.reduction == "none", ( + "ambiguous_loss_value should be 0 if no conf loss" + ) + # Add the ambiguous pixel as "valid" pixels... + valid_masks = [ + mask | ambig_mask + for mask, ambig_mask in zip(valid_masks, ambiguous_masks) + ] + + normal_losses = [] + gradient_matching_losses = [] + pose_trans_losses = [] + pose_quats_losses = [] + ray_directions_losses = [] + depth_losses = [] + cam_pts3d_losses = [] + if self.compute_world_frame_points_loss: + pts3d_losses = [] + + for i in range(n_views): + # Get the camera frame points, log space depth_z & valid masks + pred_local_pts3d = pred_info[i]["pts3d_cam"] + pred_depth_z = pred_local_pts3d[..., 2:] + pred_depth_z = apply_log_to_norm(pred_depth_z) + gt_local_pts3d = gt_info[i]["pts3d_cam"] + gt_depth_z = gt_local_pts3d[..., 2:] + gt_depth_z = apply_log_to_norm(gt_depth_z) + valid_mask_for_normal_gm_loss = valid_masks[i].clone() + + # Update the validity mask for normal & gm loss based on the synthetic data mask if required + if self.apply_normal_and_gm_loss_to_synthetic_data_only: + synthetic_mask = batch[i]["is_synthetic"] # (B, ) + synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1) + synthetic_mask = synthetic_mask.expand( + -1, pred_depth_z.shape[1], pred_depth_z.shape[2] + ) # (B, H, W) + valid_mask_for_normal_gm_loss = ( + valid_mask_for_normal_gm_loss & synthetic_mask + ) + + # Compute the normal loss + normal_loss = compute_normal_loss( + pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone() + ) + normal_loss = normal_loss * self.normal_loss_weight + normal_losses.append(normal_loss) + + # Compute the gradient matching loss + gradient_matching_loss = compute_gradient_matching_loss( + pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone() + ) + gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight + gradient_matching_losses.append(gradient_matching_loss) + + # Get the predicted dense quantities + if not self.flatten_across_image_only: + # Flatten the points across the 
entire batch with the masks and compute the metrics + pred_ray_directions = pred_info[i]["ray_directions"] + gt_ray_directions = gt_info[i]["ray_directions"] + pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]] + gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]] + pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]] + gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]] + if self.compute_world_frame_points_loss: + pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] + gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W and compute the metrics + batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape + gt_ray_directions = gt_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + pred_ray_directions = pred_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1] + gt_depth = gt_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + pred_depth = pred_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1] + gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim) + pred_cam_pts3d = pred_info[i]["pts3d_cam"].view( + batch_size, -1, cam_pts_dim + ) + if self.compute_world_frame_points_loss: + pts_dim = gt_info[i]["pts3d"].shape[-1] + gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) + pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space for depth if specified + if self.loss_in_log: + gt_depth = apply_log_to_norm(gt_depth) + pred_depth = apply_log_to_norm(pred_depth) + gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d) + pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d) + if self.compute_world_frame_points_loss: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_pts3d = apply_log_to_norm(pred_pts3d) + + if self.compute_pairwise_relative_pose_loss: + # Get the inverse of current view predicted pose + pred_inv_curr_view_pose_quats = quaternion_inverse( + pred_info[i]["pose_quats"] + ) + pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + pred_inv_curr_view_pose_quats + ) + pred_inv_curr_view_pose_trans = -1 * ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the inverse of the current view GT pose + gt_inv_curr_view_pose_quats = quaternion_inverse( + gt_info[i]["pose_quats"] + ) + gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + gt_inv_curr_view_pose_quats + ) + gt_inv_curr_view_pose_trans = -1 * ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the other N-1 relative poses using the current pose as reference frame + pred_rel_pose_quats = [] + pred_rel_pose_trans = [] + gt_rel_pose_quats = [] + gt_rel_pose_trans = [] + for ov_idx in range(n_views): + if ov_idx == i: + continue + # Get the relative predicted pose + pred_ov_rel_pose_quats = quaternion_multiply( + pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"] + ) + pred_ov_rel_pose_trans = ( + ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + pred_inv_curr_view_pose_trans + ) + + # Get the relative GT pose + gt_ov_rel_pose_quats = quaternion_multiply( + gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] + ) + gt_ov_rel_pose_trans = ( + 
ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + gt_inv_curr_view_pose_trans + ) + + # Get the valid translations using valid_norm_factor_masks for current view and other view + overall_valid_mask_for_trans = ( + valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] + ) + + # Append the relative poses + pred_rel_pose_quats.append(pred_ov_rel_pose_quats) + pred_rel_pose_trans.append( + pred_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + gt_rel_pose_quats.append(gt_ov_rel_pose_quats) + gt_rel_pose_trans.append( + gt_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + + # Cat the N-1 relative poses along the batch dimension + pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) + pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) + gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) + gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion( + pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" + ), + self.criterion( + pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + else: + # Get the pose info for the current view + pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] + gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] + pred_pose_quats = pred_info[i]["pose_quats"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_pose_trans, gt_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), + self.criterion( + pred_pose_quats, -gt_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions, gt_ray_directions, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute camera frame point loss + cam_pts3d_loss = self.criterion( + pred_cam_pts3d, gt_cam_pts3d, factor="points" + ) + cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight + cam_pts3d_losses.append(cam_pts3d_loss) + + if self.compute_world_frame_points_loss: + # Compute point loss + pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") + pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight + pts3d_losses.append(pts3d_loss) + + # Handle ambiguous pixels + if self.ambiguous_loss_value > 0: + if not self.flatten_across_image_only: + depth_losses[i] = torch.where( + 
ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + pts3d_losses[i], + ) + else: + depth_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + pts3d_losses[i], + ) + + # Compute the scale loss + if gt_metric_norm_factor is not None: + if self.loss_in_log: + gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) + pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) + scale_loss = ( + self.criterion( + pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" + ) + * self.scale_loss_weight + ) + else: + scale_loss = None + + # Use helper function to generate loss terms and details + if self.compute_world_frame_points_loss: + losses_dict = { + "pts3d": { + "values": pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + } + else: + losses_dict = {} + losses_dict.update( + { + "cam_pts3d": { + "values": cam_pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + self.depth_type_for_loss: { + "values": depth_losses, + "use_mask": True, + "is_multi_view": True, + }, + "ray_directions": { + "values": ray_directions_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_quats": { + "values": pose_quats_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_trans": { + "values": pose_trans_losses, + "use_mask": False, + "is_multi_view": True, + }, + "scale": { + "values": scale_loss, + "use_mask": False, + "is_multi_view": False, + }, + "normal": { + "values": normal_losses, + "use_mask": False, + "is_multi_view": True, + }, + "gradient_matching": { + "values": gradient_matching_losses, + "use_mask": False, + "is_multi_view": True, + }, + } + ) + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) + + +class DisentangledFactoredGeometryScaleRegr3D(Criterion, MultiLoss): + """ + Disentangled Regression Loss for Factored Geometry & Scale. + """ + + def __init__( + self, + criterion, + norm_predictions=True, + norm_mode="avg_dis", + loss_in_log=True, + flatten_across_image_only=False, + depth_type_for_loss="depth_along_ray", + depth_loss_weight=1, + ray_directions_loss_weight=1, + pose_quats_loss_weight=1, + pose_trans_loss_weight=1, + scale_loss_weight=1, + ): + """ + Initialize the disentangled loss criterion for Factored Geometry (Ray Directions, Depth, Pose) & Scale. + It isolates/disentangles the contribution of each factor to the final task of 3D reconstruction. + All the losses are in the same space where the loss for each factor is computed by constructing world-frame pointmaps. + This sidesteps the difficulty of finding a proper weighting. 
+ For insance, for predicted rays, the GT depth & pose is used to construct the predicted world-frame pointmaps on which the loss is computed. + Inspired by https://openaccess.thecvf.com/content_ICCV_2019/papers/Simonelli_Disentangling_Monocular_3D_Object_Detection_ICCV_2019_paper.pdf + + The pixel-level losses are computed in the following order: + (1) depth, (2) ray directions, (3) pose quats, (4) pose trans, (5) scale. + The predicited scene representation is always normalized w.r.t. the frame of view0. + Loss is applied between the predicted metric scale and the ground truth metric scale. + + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. + norm_predictions (bool): If True, normalize the predictions before computing the loss. + norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth, pointmaps and scale. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". + Options: "depth_along_ray", "depth_z" + depth_loss_weight (float): Weight to use for the depth loss. Default: 1. + ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. + pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. + pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + scale_loss_weight (float): Weight to use for the scale loss. Default: 1. + """ + super().__init__(criterion) + self.norm_predictions = norm_predictions + self.norm_mode = norm_mode + self.loss_in_log = loss_in_log + self.flatten_across_image_only = flatten_across_image_only + self.depth_type_for_loss = depth_type_for_loss + assert self.depth_type_for_loss in [ + "depth_along_ray", + "depth_z", + ], "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" + self.depth_loss_weight = depth_loss_weight + self.ray_directions_loss_weight = ray_directions_loss_weight + self.pose_quats_loss_weight = pose_quats_loss_weight + self.pose_trans_loss_weight = pose_trans_loss_weight + self.scale_loss_weight = scale_loss_weight + + def get_all_info(self, batch, preds, dist_clip=None): + """ + Function to get all the information needed to compute the loss. + Returns all quantities normalized w.r.t. camera of view0. + """ + n_views = len(batch) + + # Everything is normalized w.r.t. 
camera of view0 + # Intialize lists to store data for all views + # Ground truth quantities + in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) + no_norm_gt_pts = [] + no_norm_gt_pts_cam = [] + no_norm_gt_depth = [] + no_norm_gt_pose_trans = [] + valid_masks = [] + gt_ray_directions = [] + gt_pose_quats = [] + # Predicted quantities + no_norm_pr_pts = [] + no_norm_pr_pts_cam = [] + no_norm_pr_depth = [] + no_norm_pr_pose_trans = [] + pr_ray_directions = [] + pr_pose_quats = [] + metric_pr_pts_to_compute_scale = [] + + # Get ground truth & prediction info for all views + for i in range(n_views): + # Get the ground truth + no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"])) + valid_masks.append(batch[i]["valid_mask"].clone()) + no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"]) + gt_ray_directions.append(batch[i]["ray_directions_cam"]) + if self.depth_type_for_loss == "depth_along_ray": + no_norm_gt_depth.append(batch[i]["depth_along_ray"]) + elif self.depth_type_for_loss == "depth_z": + no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:]) + if i == 0: + # For view0, initialize identity pose + gt_pose_quats.append( + torch.tensor( + [0, 0, 0, 1], + dtype=gt_ray_directions[0].dtype, + device=gt_ray_directions[0].device, + ) + .unsqueeze(0) + .repeat(gt_ray_directions[0].shape[0], 1) + ) + no_norm_gt_pose_trans.append( + torch.tensor( + [0, 0, 0], + dtype=gt_ray_directions[0].dtype, + device=gt_ray_directions[0].device, + ) + .unsqueeze(0) + .repeat(gt_ray_directions[0].shape[0], 1) + ) + else: + # For other views, transform pose to view0's frame + gt_pose_quats_world = batch[i]["camera_pose_quats"] + no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"] + gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = ( + transform_pose_using_quats_and_trans_2_to_1( + batch[0]["camera_pose_quats"], + batch[0]["camera_pose_trans"], + gt_pose_quats_world, + no_norm_gt_pose_trans_world, + ) + ) + gt_pose_quats.append(gt_pose_quats_in_view0) + no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0) + + # Get predictions for normalized loss + if self.depth_type_for_loss == "depth_along_ray": + curr_view_no_norm_depth = preds[i]["depth_along_ray"] + elif self.depth_type_for_loss == "depth_z": + curr_view_no_norm_depth = preds[i]["pts3d_cam"][..., 2:] + if "metric_scaling_factor" in preds[i].keys(): + # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans + # This detaches the predicted metric scaling factor from the geometry based loss + curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + curr_view_no_norm_depth = curr_view_no_norm_depth / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + curr_view_no_norm_pr_pose_trans = ( + preds[i]["cam_trans"] / preds[i]["metric_scaling_factor"] + ) + else: + curr_view_no_norm_pr_pts = preds[i]["pts3d"] + curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] + curr_view_no_norm_depth = curr_view_no_norm_depth + curr_view_no_norm_pr_pose_trans = preds[i]["cam_trans"] + no_norm_pr_pts.append(curr_view_no_norm_pr_pts) + no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam) + no_norm_pr_depth.append(curr_view_no_norm_depth) + no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans) + pr_ray_directions.append(preds[i]["ray_directions"]) + 
pr_pose_quats.append(preds[i]["cam_quats"]) + + # Get the predicted metric scale points + if "metric_scaling_factor" in preds[i].keys(): + # Detach the raw predicted points so that the scale loss is only applied to the scaling factor + curr_view_metric_pr_pts_to_compute_scale = ( + curr_view_no_norm_pr_pts.detach() + * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1) + ) + else: + curr_view_metric_pr_pts_to_compute_scale = ( + curr_view_no_norm_pr_pts.clone() + ) + metric_pr_pts_to_compute_scale.append( + curr_view_metric_pr_pts_to_compute_scale + ) + + if dist_clip is not None: + # Points that are too far-away == invalid + for i in range(n_views): + dis = no_norm_gt_pts[i].norm(dim=-1) + valid_masks[i] = valid_masks[i] & (dis <= dist_clip) + + # Initialize normalized tensors + gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] + gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam] + gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth] + gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans] + + pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] + pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam] + pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth] + pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans] + + # Normalize the predicted points if specified + if self.norm_predictions: + pr_normalization_output = normalize_multiple_pointclouds( + no_norm_pr_pts, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_pts_norm = pr_normalization_output[:-1] + pr_norm_factor = pr_normalization_output[-1] + + # Normalize the ground truth points + gt_normalization_output = normalize_multiple_pointclouds( + no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True + ) + gt_pts_norm = gt_normalization_output[:-1] + gt_norm_factor = gt_normalization_output[-1] + + for i in range(n_views): + if self.norm_predictions: + # Assign the normalized predictions + pr_pts[i] = pr_pts_norm[i] + pr_pts_cam[i] = no_norm_pr_pts_cam[i] / pr_norm_factor + pr_depth[i] = no_norm_pr_depth[i] / pr_norm_factor + pr_pose_trans[i] = no_norm_pr_pose_trans[i] / pr_norm_factor[:, :, 0, 0] + else: + pr_pts[i] = no_norm_pr_pts[i] + pr_pts_cam[i] = no_norm_pr_pts_cam[i] + pr_depth[i] = no_norm_pr_depth[i] + pr_pose_trans[i] = no_norm_pr_pose_trans[i] + # Assign the normalized ground truth quantities + gt_pts[i] = gt_pts_norm[i] + gt_pts_cam[i] = no_norm_gt_pts_cam[i] / gt_norm_factor + gt_depth[i] = no_norm_gt_depth[i] / gt_norm_factor + gt_pose_trans[i] = no_norm_gt_pose_trans[i] / gt_norm_factor[:, :, 0, 0] + + # Get the mask indicating ground truth metric scale quantities + metric_scale_mask = batch[0]["is_metric_scale"] + valid_gt_norm_factor_mask = ( + gt_norm_factor[:, 0, 0, 0] > 1e-8 + ) # Mask out cases where depth for all views is invalid + valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask + + if valid_metric_scale_mask.any(): + # Compute the scale norm factor using the predicted metric scale points + metric_pr_normalization_output = normalize_multiple_pointclouds( + metric_pr_pts_to_compute_scale, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_metric_norm_factor = metric_pr_normalization_output[-1] + + # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities + gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask] + pr_metric_norm_factor = 
pr_metric_norm_factor[valid_metric_scale_mask] + else: + gt_metric_norm_factor = None + pr_metric_norm_factor = None + + # Get ambiguous masks + ambiguous_masks = [] + for i in range(n_views): + ambiguous_masks.append( + (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i]) + ) + + # Pack into info dicts + gt_info = [] + pred_info = [] + for i in range(n_views): + gt_info.append( + { + "ray_directions": gt_ray_directions[i], + self.depth_type_for_loss: gt_depth[i], + "pose_trans": gt_pose_trans[i], + "pose_quats": gt_pose_quats[i], + "pts3d": gt_pts[i], + "pts3d_cam": gt_pts_cam[i], + } + ) + pred_info.append( + { + "ray_directions": pr_ray_directions[i], + self.depth_type_for_loss: pr_depth[i], + "pose_trans": pr_pose_trans[i], + "pose_quats": pr_pose_quats[i], + "pts3d": pr_pts[i], + "pts3d_cam": pr_pts_cam[i], + } + ) + + return ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) + + def compute_loss(self, batch, preds, **kw): + ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) = self.get_all_info(batch, preds, **kw) + n_views = len(batch) + + pose_trans_losses = [] + pose_quats_losses = [] + ray_directions_losses = [] + depth_losses = [] + + for i in range(n_views): + # Get the GT factored quantities for the current view + gt_pts3d = gt_info[i]["pts3d"] + gt_ray_directions = gt_info[i]["ray_directions"] + gt_depth = gt_info[i][self.depth_type_for_loss] + gt_pose_trans = gt_info[i]["pose_trans"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Get the predicted factored quantities for the current view + pred_ray_directions = pred_info[i]["ray_directions"] + pred_depth = pred_info[i][self.depth_type_for_loss] + pred_pose_trans = pred_info[i]["pose_trans"] + pred_pose_quats = pred_info[i]["pose_quats"] + + # Get the predicted world-frame pointmaps using the different factors + if self.depth_type_for_loss == "depth_along_ray": + pred_ray_directions_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + pred_ray_directions, + gt_depth, + gt_pose_trans, + gt_pose_quats, + ) + ) + pred_depth_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + gt_ray_directions, + pred_depth, + gt_pose_trans, + gt_pose_quats, + ) + ) + pred_pose_trans_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + gt_ray_directions, + gt_depth, + pred_pose_trans, + gt_pose_quats, + ) + ) + pred_pose_quats_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + gt_ray_directions, + gt_depth, + gt_pose_trans, + pred_pose_quats, + ) + ) + else: + raise NotImplementedError + + # Mask out the valid quantities as required + if not self.flatten_across_image_only: + # Flatten the points across the entire batch with the masks + pred_ray_directions_pts3d = pred_ray_directions_pts3d[valid_masks[i]] + pred_depth_pts3d = pred_depth_pts3d[valid_masks[i]] + pred_pose_trans_pts3d = pred_pose_trans_pts3d[valid_masks[i]] + pred_pose_quats_pts3d = pred_pose_quats_pts3d[valid_masks[i]] + gt_pts3d = gt_pts3d[valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W + batch_size, _, _, pts_dim = gt_pts3d.shape + pred_ray_directions_pts3d = pred_ray_directions_pts3d.view( + batch_size, -1, pts_dim + ) + pred_depth_pts3d = pred_depth_pts3d.view(batch_size, -1, pts_dim) + pred_pose_trans_pts3d = pred_pose_trans_pts3d.view( + batch_size, -1, pts_dim + ) + pred_pose_quats_pts3d = pred_pose_quats_pts3d.view( + batch_size, -1, 
pts_dim + ) + gt_pts3d = gt_pts3d.view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space if specified + if self.loss_in_log: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_ray_directions_pts3d = apply_log_to_norm(pred_ray_directions_pts3d) + pred_depth_pts3d = apply_log_to_norm(pred_depth_pts3d) + pred_pose_trans_pts3d = apply_log_to_norm(pred_pose_trans_pts3d) + pred_pose_quats_pts3d = apply_log_to_norm(pred_pose_quats_pts3d) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_pose_trans_pts3d, gt_pts3d, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + pose_quats_loss = self.criterion( + pred_pose_quats_pts3d, gt_pts3d, factor="pose_quats" + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions_pts3d, gt_pts3d, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth_pts3d, gt_pts3d, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute the scale loss + if gt_metric_norm_factor is not None: + if self.loss_in_log: + gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) + pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) + scale_loss = ( + self.criterion( + pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" + ) + * self.scale_loss_weight + ) + else: + scale_loss = None + + # Use helper function to generate loss terms and details + losses_dict = {} + losses_dict.update( + { + self.depth_type_for_loss: { + "values": depth_losses, + "use_mask": True, + "is_multi_view": True, + }, + "ray_directions": { + "values": ray_directions_losses, + "use_mask": True, + "is_multi_view": True, + }, + "pose_quats": { + "values": pose_quats_losses, + "use_mask": True, + "is_multi_view": True, + }, + "pose_trans": { + "values": pose_trans_losses, + "use_mask": True, + "is_multi_view": True, + }, + "scale": { + "values": scale_loss, + "use_mask": False, + "is_multi_view": False, + }, + } + ) + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) + + +class DisentangledFactoredGeometryScaleRegr3DPlusNormalGMLoss( + DisentangledFactoredGeometryScaleRegr3D +): + """ + Disentangled Regression, Normals & Gradient Matching Loss for Factored Geometry & Scale. + """ + + def __init__( + self, + criterion, + norm_predictions=True, + norm_mode="avg_dis", + loss_in_log=True, + flatten_across_image_only=False, + depth_type_for_loss="depth_along_ray", + depth_loss_weight=1, + ray_directions_loss_weight=1, + pose_quats_loss_weight=1, + pose_trans_loss_weight=1, + scale_loss_weight=1, + apply_normal_and_gm_loss_to_synthetic_data_only=True, + normal_loss_weight=1, + gm_loss_weight=1, + ): + """ + Initialize the disentangled loss criterion for Factored Geometry (Ray Directions, Depth, Pose) & Scale. + See parent class (DisentangledFactoredGeometryScaleRegr3D) for more details. 
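For reference, a tiny standalone illustration (assumed shapes, not repository code) of the two flattening modes selected by flatten_across_image_only in the loss computation above: either valid points are pooled across the whole batch via boolean indexing, or the batch dimension is kept and only H x W is flattened, with the mask reshaped accordingly.

import torch

B, H, W = 2, 3, 4
pts = torch.randn(B, H, W, 3)
valid = torch.rand(B, H, W) > 0.3

# flatten_across_image_only = False: index with the mask, pooling valid points
# from every image in the batch into one flat set.
flat_all = pts[valid]                  # (num_valid, 3)

# flatten_across_image_only = True: keep the batch dimension and flatten H x W;
# the validity mask is reshaped the same way and applied per image later.
flat_per_image = pts.view(B, -1, 3)    # (B, H*W, 3)
mask_per_image = valid.view(B, -1)     # (B, H*W)

print(flat_all.shape, flat_per_image.shape, mask_per_image.shape)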
+ Additionally computes: + (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates, + (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space) + + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. + norm_predictions (bool): If True, normalize the predictions before computing the loss. + norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth, pointmaps and scale. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". + Options: "depth_along_ray", "depth_z" + depth_loss_weight (float): Weight to use for the depth loss. Default: 1. + ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. + pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. + pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + scale_loss_weight (float): Weight to use for the scale loss. Default: 1. + apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. + If False, apply the normal and gm loss to all data. Default: True. + normal_loss_weight (float): Weight to use for the normal loss. Default: 1. + gm_loss_weight (float): Weight to use for the gm loss. Default: 1. + """ + super().__init__( + criterion=criterion, + norm_predictions=norm_predictions, + norm_mode=norm_mode, + loss_in_log=loss_in_log, + flatten_across_image_only=flatten_across_image_only, + depth_type_for_loss=depth_type_for_loss, + depth_loss_weight=depth_loss_weight, + ray_directions_loss_weight=ray_directions_loss_weight, + pose_quats_loss_weight=pose_quats_loss_weight, + pose_trans_loss_weight=pose_trans_loss_weight, + scale_loss_weight=scale_loss_weight, + ) + self.apply_normal_and_gm_loss_to_synthetic_data_only = ( + apply_normal_and_gm_loss_to_synthetic_data_only + ) + self.normal_loss_weight = normal_loss_weight + self.gm_loss_weight = gm_loss_weight + + def compute_loss(self, batch, preds, **kw): + ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) = self.get_all_info(batch, preds, **kw) + n_views = len(batch) + + normal_losses = [] + gradient_matching_losses = [] + pose_trans_losses = [] + pose_quats_losses = [] + ray_directions_losses = [] + depth_losses = [] + + for i in range(n_views): + # Get the camera frame points, log space depth_z & valid masks + pred_local_pts3d = pred_info[i]["pts3d_cam"] + pred_depth_z = pred_local_pts3d[..., 2:] + pred_depth_z = apply_log_to_norm(pred_depth_z) + gt_local_pts3d = gt_info[i]["pts3d_cam"] + gt_depth_z = gt_local_pts3d[..., 2:] + gt_depth_z = apply_log_to_norm(gt_depth_z) + valid_mask_for_normal_gm_loss = valid_masks[i].clone() + + # Update the validity mask for normal & gm loss based on the synthetic data mask if required + if self.apply_normal_and_gm_loss_to_synthetic_data_only: + synthetic_mask = batch[i]["is_synthetic"] # (B, ) + synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1) + synthetic_mask = synthetic_mask.expand( + -1, pred_depth_z.shape[1], pred_depth_z.shape[2] 
+ ) # (B, H, W) + valid_mask_for_normal_gm_loss = ( + valid_mask_for_normal_gm_loss & synthetic_mask + ) + + # Compute the normal loss + normal_loss = compute_normal_loss( + pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone() + ) + normal_loss = normal_loss * self.normal_loss_weight + normal_losses.append(normal_loss) + + # Compute the gradient matching loss + gradient_matching_loss = compute_gradient_matching_loss( + pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone() + ) + gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight + gradient_matching_losses.append(gradient_matching_loss) + + # Get the GT factored quantities for the current view + gt_pts3d = gt_info[i]["pts3d"] + gt_ray_directions = gt_info[i]["ray_directions"] + gt_depth = gt_info[i][self.depth_type_for_loss] + gt_pose_trans = gt_info[i]["pose_trans"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Get the predicted factored quantities for the current view + pred_ray_directions = pred_info[i]["ray_directions"] + pred_depth = pred_info[i][self.depth_type_for_loss] + pred_pose_trans = pred_info[i]["pose_trans"] + pred_pose_quats = pred_info[i]["pose_quats"] + + # Get the predicted world-frame pointmaps using the different factors + if self.depth_type_for_loss == "depth_along_ray": + pred_ray_directions_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + pred_ray_directions, + gt_depth, + gt_pose_trans, + gt_pose_quats, + ) + ) + pred_depth_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + gt_ray_directions, + pred_depth, + gt_pose_trans, + gt_pose_quats, + ) + ) + pred_pose_trans_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + gt_ray_directions, + gt_depth, + pred_pose_trans, + gt_pose_quats, + ) + ) + pred_pose_quats_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + gt_ray_directions, + gt_depth, + gt_pose_trans, + pred_pose_quats, + ) + ) + else: + raise NotImplementedError + + # Mask out the valid quantities as required + if not self.flatten_across_image_only: + # Flatten the points across the entire batch with the masks + pred_ray_directions_pts3d = pred_ray_directions_pts3d[valid_masks[i]] + pred_depth_pts3d = pred_depth_pts3d[valid_masks[i]] + pred_pose_trans_pts3d = pred_pose_trans_pts3d[valid_masks[i]] + pred_pose_quats_pts3d = pred_pose_quats_pts3d[valid_masks[i]] + gt_pts3d = gt_pts3d[valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W + batch_size, _, _, pts_dim = gt_pts3d.shape + pred_ray_directions_pts3d = pred_ray_directions_pts3d.view( + batch_size, -1, pts_dim + ) + pred_depth_pts3d = pred_depth_pts3d.view(batch_size, -1, pts_dim) + pred_pose_trans_pts3d = pred_pose_trans_pts3d.view( + batch_size, -1, pts_dim + ) + pred_pose_quats_pts3d = pred_pose_quats_pts3d.view( + batch_size, -1, pts_dim + ) + gt_pts3d = gt_pts3d.view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space if specified + if self.loss_in_log: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_ray_directions_pts3d = apply_log_to_norm(pred_ray_directions_pts3d) + pred_depth_pts3d = apply_log_to_norm(pred_depth_pts3d) + pred_pose_trans_pts3d = apply_log_to_norm(pred_pose_trans_pts3d) + pred_pose_quats_pts3d = apply_log_to_norm(pred_pose_quats_pts3d) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_pose_trans_pts3d, gt_pts3d, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * 
self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + pose_quats_loss = self.criterion( + pred_pose_quats_pts3d, gt_pts3d, factor="pose_quats" + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions_pts3d, gt_pts3d, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth_pts3d, gt_pts3d, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute the scale loss + if gt_metric_norm_factor is not None: + if self.loss_in_log: + gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) + pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) + scale_loss = ( + self.criterion( + pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" + ) + * self.scale_loss_weight + ) + else: + scale_loss = None + + # Use helper function to generate loss terms and details + losses_dict = {} + losses_dict.update( + { + self.depth_type_for_loss: { + "values": depth_losses, + "use_mask": True, + "is_multi_view": True, + }, + "ray_directions": { + "values": ray_directions_losses, + "use_mask": True, + "is_multi_view": True, + }, + "pose_quats": { + "values": pose_quats_losses, + "use_mask": True, + "is_multi_view": True, + }, + "pose_trans": { + "values": pose_trans_losses, + "use_mask": True, + "is_multi_view": True, + }, + "scale": { + "values": scale_loss, + "use_mask": False, + "is_multi_view": False, + }, + "normal": { + "values": normal_losses, + "use_mask": False, + "is_multi_view": True, + }, + "gradient_matching": { + "values": gradient_matching_losses, + "use_mask": False, + "is_multi_view": True, + }, + } + ) + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) diff --git a/mapanything/train/profile_dataloading.py b/mapanything/train/profile_dataloading.py new file mode 100644 index 0000000000000000000000000000000000000000..881eb21fe131d37d6331a0f503cbea4d5bac1eda --- /dev/null +++ b/mapanything/train/profile_dataloading.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Debug script to profile dataloading for MapAnything training. + +This script measures and analyzes the performance of data loading operations +for MapAnything training workflows. It simulates the training process without +actual model training to isolate and profile the data loading components. 
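Not part of the diff: a toy version of the measurement this script performs. It iterates a loader, deliberately skips the training step, and records how long each batch takes to arrive; the dataset and batch size below are placeholder values.

import time

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 3, 32, 32))
loader = DataLoader(dataset, batch_size=16, num_workers=0)

fetch_times = []
t0 = time.perf_counter()
for batch in loader:
    fetch_times.append(time.perf_counter() - t0)
    # A real training step would run here; the profiler deliberately skips it.
    t0 = time.perf_counter()

print(f"mean batch fetch time: {sum(fetch_times) / len(fetch_times):.6f}s")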
+""" + +import datetime +import json +import os +import time +from pathlib import Path +from typing import Sized + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter + +import mapanything.utils.train_tools as train_tools +from mapanything.datasets import get_test_data_loader, get_train_data_loader +from mapanything.datasets.base.base_dataset import view_name + +# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12) +if hasattr(torch.backends.cuda, "matmul") and hasattr( + torch.backends.cuda.matmul, "allow_tf32" +): + torch.backends.cuda.matmul.allow_tf32 = True + + +def profile_dataloading(args): + """ + Main profiling function that simulates the training process to measure data loading performance. + + This function initializes the distributed environment, sets up datasets and data loaders, + and runs through training epochs to profile the data loading operations. It measures + the time taken for data loading without performing actual model training or optimization. + + In this simulation, an epoch represents a complete pass through a chunk of the dataset. + + Args: + args: Configuration object containing all parameters including: + - dataset: Dataset configuration (train_dataset, test_dataset, num_workers) + - train_params: Training parameters (batch_size, epochs, seed, etc.) + - distributed: Distributed training configuration + - output_dir: Directory for saving logs and profiling results + """ + # Initialize distributed training if required + train_tools.init_distributed_mode(args.distributed) + global_rank = train_tools.get_rank() + world_size = train_tools.get_world_size() # noqa + + # Init output directory and device + print("output_dir: " + args.output_dir) + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(", ", ",\n")) + + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # Fix the seed + seed = args.train_params.seed + train_tools.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = not args.train_params.disable_cudnn_benchmark + + # Datasets and Dataloaders + print("Building train dataset {:s}".format(args.dataset.train_dataset)) + data_loader_train = build_dataset( + dataset=args.dataset.train_dataset, + num_workers=args.dataset.num_workers, + test=False, + max_num_of_imgs_per_gpu=args.train_params.max_num_of_imgs_per_gpu, + ) + print("Building test dataset {:s}".format(args.dataset.test_dataset)) + test_batch_size = 2 * ( + args.train_params.max_num_of_imgs_per_gpu // args.dataset.num_views + ) # Since we don't have any backward overhead + data_loader_test = { + dataset.split("(")[0]: build_dataset( + dataset=dataset, + num_workers=args.dataset.num_workers, + test=True, + batch_size=test_batch_size, + ) + for dataset in args.dataset.test_dataset.split("+") + if "(" in dataset + } + + def write_log_stats(epoch, train_stats, test_stats): + """ + Writes profiling statistics to log files and TensorBoard. + + This function collects metrics from the training and testing phases and writes them + to log files and TensorBoard for visualization and analysis. It only executes on the + main process in a distributed setting. 
+ + Args: + epoch: int, current epoch number + train_stats: dict, containing training metrics and timing information + test_stats: dict, containing testing metrics for each test dataset + """ + if train_tools.is_main_process(): + if log_writer is not None: + log_writer.flush() + + log_stats = dict( + epoch=epoch, **{f"train_{k}": v for k, v in train_stats.items()} + ) + for test_name in data_loader_test: + if test_name not in test_stats: + continue + log_stats.update( + {test_name + "_" + k: v for k, v in test_stats[test_name].items()} + ) + + with open( + os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8" + ) as f: + f.write(json.dumps(log_stats) + "\n") + + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir) + else: + log_writer = None + + print(f"Start training for {args.train_params.epochs} epochs") + start_time = time.time() + train_stats = test_stats = {} + args.train_params.start_epoch = 0 + for epoch in range(args.train_params.start_epoch, args.train_params.epochs + 1): + # Save more stuff + write_log_stats(epoch, train_stats, test_stats) + + if epoch >= args.train_params.epochs: + break # exit after writing last test to disk + + # Train + train_stats = train_one_epoch( + data_loader_train, + device, + epoch, + log_writer=log_writer, + args=args, + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("Training time {}".format(total_time_str)) + + +def build_dataset( + dataset, num_workers, test, batch_size=None, max_num_of_imgs_per_gpu=None +): + """ + Builds data loaders for training or testing. + + Args: + dataset: Dataset specification string. + num_workers: Number of worker processes for data loading. + test: Boolean flag indicating whether this is a test dataset. + batch_size: Number of samples per batch. Defaults to None. Used only for testing. + max_num_of_imgs_per_gpu: Maximum number of images per GPU. Defaults to None. Used only for training. + + Returns: + DataLoader: PyTorch DataLoader configured for the specified dataset. + """ + split = ["Train", "Test"][test] + print(f"Building {split} Data loader for dataset: ", dataset) + if test: + assert batch_size is not None, ( + "batch_size must be specified for testing dataloader" + ) + loader = get_test_data_loader( + dataset=dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_mem=True, + shuffle=False, + drop_last=False, + ) + else: + assert max_num_of_imgs_per_gpu is not None, ( + "max_num_of_imgs_per_gpu must be specified for training dataloader" + ) + loader = get_train_data_loader( + dataset=dataset, + max_num_of_imgs_per_gpu=max_num_of_imgs_per_gpu, + num_workers=num_workers, + pin_mem=True, + shuffle=True, + drop_last=True, + ) + + print(f"{split} dataset length: ", len(loader)) + return loader + + +def train_one_epoch( + data_loader: Sized, + device: torch.device, + epoch: int, + args, + log_writer=None, +): + """ + Simulates training for one epoch to profile data loading performance. + + This function runs through a single epoch, simulating the data loading and device transfer + operations that would occur during actual training. It measures and logs the time taken + for these operations without performing actual model training. 
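For context, a small standalone restatement (hypothetical values, example key "eth3d") of the log.txt record produced by write_log_stats above: one JSON object per line, with train metrics prefixed by "train_" and each test dataset's metrics prefixed by its loader name.

import json

# Hypothetical stats; the test-set name "eth3d" is only an example key.
train_stats = {"loss": 0.42, "lr": 1e-4}
test_stats = {"eth3d": {"loss_avg": 0.40, "loss_med": 0.31}}

log_stats = dict(epoch=3, **{f"train_{k}": v for k, v in train_stats.items()})
for test_name, stats in test_stats.items():
    log_stats.update({f"{test_name}_{k}": v for k, v in stats.items()})

print(json.dumps(log_stats))
# {"epoch": 3, "train_loss": 0.42, "train_lr": 0.0001, "eth3d_loss_avg": 0.4, "eth3d_loss_med": 0.31}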
+ + Args: + data_loader: Sized, DataLoader providing the training data + device: torch.device, device to transfer data to (CPU or GPU) + epoch: int, current epoch number + args: object, configuration object containing training parameters including: + - train_params.print_freq: frequency of logging during the epoch + log_writer: Optional[SummaryWriter], TensorBoard SummaryWriter for logging metrics + + Returns: + dict: Dictionary containing profiling metrics averaged over the epoch + """ + metric_logger = train_tools.MetricLogger(delimiter=" ") + header = "Epoch: [{}]".format(epoch) + + if log_writer is not None: + print("log_dir: {}".format(log_writer.log_dir)) + + if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"): + data_loader.dataset.set_epoch(epoch) + if hasattr(data_loader, "sampler") and hasattr(data_loader.sampler, "set_epoch"): + data_loader.sampler.set_epoch(epoch) + if hasattr(data_loader, "batch_sampler") and hasattr( + data_loader.batch_sampler, "set_epoch" + ): + data_loader.batch_sampler.set_epoch(epoch) + + for data_iter_step, batch in enumerate( + metric_logger.log_every(data_loader, args.train_params.print_freq, header) + ): + epoch_f = epoch + data_iter_step / len(data_loader) + + # Simulate the device loading in loss_of_one_batch_multi_view + ignore_keys = set( + [ + "depthmap", + "dataset", + "label", + "instance", + "idx", + "true_shape", + "rng", + "data_norm_type", + ] + ) + for view in batch: + for name in view.keys(): + if name in ignore_keys: + continue + view[name] = view[name].to(device, non_blocking=True) + + local_rank = train_tools.get_rank() + n_views = len(batch) + batch_shape = batch[0]["img"].shape + first_sample_name = view_name(batch[0], batch_index=0) + print( + f"Rank: {local_rank}, Num views: {n_views}, Batch Shape: {batch_shape}, First Sample Name: {first_sample_name}", + force=True, + ) + + del batch + + metric_logger.update(epoch=epoch_f) + metric_logger.update(loss=0) + + # # Gather the stats from all processes + # metric_logger.synchronize_between_processes() + # print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} diff --git a/mapanything/train/training.py b/mapanything/train/training.py new file mode 100644 index 0000000000000000000000000000000000000000..109057bc67022ead3634ada41458ed77b7403bd7 --- /dev/null +++ b/mapanything/train/training.py @@ -0,0 +1,664 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Training Code for MapAnything. 
+ +References: +DUSt3R: https://github.com/naver/dust3r +""" + +import datetime +import json +import math +import os +import pickle +import sys +import time +from collections import defaultdict +from pathlib import Path +from typing import Sized + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter + +import mapanything.utils.train_tools as train_tools +from mapanything.datasets import get_test_data_loader, get_train_data_loader +from mapanything.models import init_model +from mapanything.train.losses import * # noqa +from mapanything.utils.inference import loss_of_one_batch_multi_view +from mapanything.utils.train_tools import NativeScalerWithGradNormCount as NativeScaler + +# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12) +if hasattr(torch.backends.cuda, "matmul") and hasattr( + torch.backends.cuda.matmul, "allow_tf32" +): + torch.backends.cuda.matmul.allow_tf32 = True + + +def train(args): + """ + Main training function that handles the entire training process. + + This function initializes the distributed training environment, sets up datasets, + initializes the model, optimizer, and loss functions, and manages the training + and evaluation loop across multiple epochs. + + In this training, an epoch is just a chunk of the entire dataset. + + Args: + args: Configuration object containing all training parameters including + dataset configs, model configs, training parameters, and loss functions. + """ + # Initialize distributed training if required + train_tools.init_distributed_mode(args.distributed) + global_rank = train_tools.get_rank() + world_size = train_tools.get_world_size() # noqa + + # Init output directory and device + print("output_dir: " + args.output_dir) + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(", ", ",\n")) + + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # Fix the seed + seed = args.train_params.seed + train_tools.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = not args.train_params.disable_cudnn_benchmark + + # Datasets and Dataloaders + print("Building train dataset {:s}".format(args.dataset.train_dataset)) + data_loader_train = build_dataset( + dataset=args.dataset.train_dataset, + num_workers=args.dataset.num_workers, + test=False, + max_num_of_imgs_per_gpu=args.train_params.max_num_of_imgs_per_gpu, + ) + print("Building test dataset {:s}".format(args.dataset.test_dataset)) + test_batch_size = 2 * ( + args.train_params.max_num_of_imgs_per_gpu // args.dataset.num_views + ) # Since we don't have any backward overhead + data_loader_test = { + dataset.split("(")[0]: build_dataset( + dataset=dataset, + num_workers=args.dataset.num_workers, + test=True, + batch_size=test_batch_size, + ) + for dataset in args.dataset.test_dataset.split("+") + if "(" in dataset + } + + # Load Model + if global_rank == 0: + model = init_model( + args.model.model_str, + args.model.model_config, + torch_hub_force_reload=args.model.torch_hub_force_reload, + ) + if torch.distributed.is_initialized(): + torch.distributed.barrier() # Make sure the model is initialized before proceeding + if global_rank != 0: + model = init_model( + args.model.model_str, args.model.model_config, torch_hub_force_reload=False + ) + model.to(device) # Move model to device + model_without_ddp = 
model + print("Model = %s" % str(model_without_ddp)) + + # Criterion + print(f">> Creating train criterion = {args.loss.train_criterion}") + train_criterion = eval(args.loss.train_criterion).to(device) + print( + f">> Creating test criterion = {args.loss.test_criterion or args.loss.train_criterion}" + ) + test_criterion = eval(args.loss.test_criterion or args.loss.train_criterion).to( + device + ) + + # Load pretrained model if provided + if args.model.pretrained: + print("Loading pretrained: ", args.model.pretrained) + ckpt = torch.load( + args.model.pretrained, map_location=device, weights_only=False + ) + print(model.load_state_dict(ckpt["model"], strict=False)) + del ckpt # in case it occupies memory + + # Init model for DDP training + if args.distributed.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[args.distributed.gpu], + find_unused_parameters=True, + static_graph=False, + ) + model_without_ddp = model.module + + # Optimizer and loss scaler for gradient accumulation + # Following timm: set wd as 0 for bias and norm layers + param_groups, param_groups_name_to_idx_map, param_groups_idx_to_name_map = ( + train_tools.get_parameter_groups( + model_without_ddp, + args.train_params.lr, + args.train_params.weight_decay, + submodule_configs=args.train_params.submodule_configs, + warn_not_in_submodule=args.train_params.warn_not_in_submodule, + ) + ) + optimizer = torch.optim.AdamW( + param_groups, lr=args.train_params.lr, betas=(0.9, 0.95) + ) + print(optimizer) + loss_scaler = NativeScaler() + + def write_log_stats(epoch, train_stats, test_stats): + """ + Writes training and testing statistics to log files and TensorBoard. + + Args: + epoch: Current epoch number. + train_stats: Dictionary containing training metrics. + test_stats: Dictionary containing testing metrics for each test dataset. + """ + if train_tools.is_main_process(): + if log_writer is not None: + log_writer.flush() + + log_stats = dict( + epoch=epoch, **{f"train_{k}": v for k, v in train_stats.items()} + ) + for test_name in data_loader_test: + if test_name not in test_stats: + continue + log_stats.update( + {test_name + "_" + k: v for k, v in test_stats[test_name].items()} + ) + + with open( + os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8" + ) as f: + f.write(json.dumps(log_stats) + "\n") + + def save_model(epoch, fname, best_so_far): + """ + Saves model checkpoint to disk. + + Args: + epoch: Current epoch number. + fname: Filename or identifier for the checkpoint. + best_so_far: Best validation metric achieved so far. 
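This is not the project's get_parameter_groups (whose implementation lives in train_tools and is not shown in this diff), but a minimal sketch of the timm-style rule referenced in the comment above: biases and 1-D (norm) parameters get zero weight decay, everything else gets the configured value, and the groups feed AdamW with betas (0.9, 0.95).

import torch

def simple_param_groups(model, lr, weight_decay):
    # Split parameters into decay / no-decay groups following the timm convention.
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.ndim <= 1 or name.endswith(".bias"):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "lr": lr, "weight_decay": weight_decay},
        {"params": no_decay, "lr": lr, "weight_decay": 0.0},
    ]

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
optimizer = torch.optim.AdamW(
    simple_param_groups(model, lr=1e-4, weight_decay=0.05), betas=(0.9, 0.95)
)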
+ """ + train_tools.save_model( + args=args, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + epoch=epoch, + fname=fname, + best_so_far=best_so_far, + ) + + # Resume from a checkpoint if needed + last_ckpt_fname = os.path.join(args.output_dir, "checkpoint-last.pth") + if args.train_params.resume and os.path.isfile(last_ckpt_fname): + args.train_params.resume_ckpt = last_ckpt_fname + else: + args.train_params.resume_ckpt = None + best_so_far = train_tools.load_model( + train_args=args.train_params, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + ) + if best_so_far is None: + best_so_far = float("inf") + + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir) + else: + log_writer = None + + print(f"Start training for {args.train_params.epochs} epochs") + start_time = time.time() + train_stats = test_stats = {} + for epoch in range(args.train_params.start_epoch, args.train_params.epochs + 1): + # Save immediately the last checkpoint + if epoch > args.train_params.start_epoch: + if ( + args.train_params.save_freq + and epoch % args.train_params.save_freq == 0 + or epoch == args.train_params.epochs + ): + save_model(epoch - 1, "last", best_so_far) + + # Test on multiple datasets + new_best = False + test_stats = {} + if ( + args.train_params.eval_freq > 0 + and epoch % args.train_params.eval_freq == 0 + and epoch > 0 + ): + for test_name, testset in data_loader_test.items(): + print(f"Testing on {test_name} ...") + stats = test_one_epoch( + model, + test_criterion, + testset, + device, + epoch, + log_writer=log_writer, + args=args, + prefix=test_name, + ) + test_stats[test_name] = stats + + # Calculate average test loss median + avg_test_loss_med = np.mean( + [stats["loss_med"] for stats in test_stats.values()] + ) + test_stats["Average Test Loss Median"] = avg_test_loss_med + # Save best + if avg_test_loss_med < best_so_far: + best_so_far = avg_test_loss_med + new_best = True + + # Save more stuff + write_log_stats(epoch, train_stats, test_stats) + + if epoch > args.train_params.start_epoch: + if args.train_params.keep_freq and epoch % args.train_params.keep_freq == 0: + save_model(epoch - 1, str(epoch), best_so_far) + if new_best: + save_model(epoch - 1, "best", best_so_far) + if epoch >= args.train_params.epochs: + break # exit after writing last test to disk + + # Train + train_stats = train_one_epoch( + model, + train_criterion, + data_loader_train, + optimizer, + device, + epoch, + loss_scaler, + log_writer=log_writer, + args=args, + param_groups_name_to_idx_map=param_groups_name_to_idx_map, + param_groups_idx_to_name_map=param_groups_idx_to_name_map, + model_without_ddp=model_without_ddp, + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("Training time {}".format(total_time_str)) + + save_final_model( + args, args.train_params.epochs, model_without_ddp, best_so_far=best_so_far + ) + + +def save_final_model(args, epoch, model_without_ddp, best_so_far=None): + """ + Saves the final model checkpoint after training completion. + + Args: + args: Configuration object containing output directory information. + epoch: Current epoch number. + model_without_ddp: Model state dictionary or model instance without DistributedDataParallel wrapper. + best_so_far: Optional; Best validation metric achieved during training. 
+ """ + output_dir = Path(args.output_dir) + checkpoint_path = output_dir / "checkpoint-final.pth" + to_save = { + "args": args, + "model": model_without_ddp + if isinstance(model_without_ddp, dict) + else model_without_ddp.cpu().state_dict(), + "epoch": epoch, + } + if best_so_far is not None: + to_save["best_so_far"] = best_so_far + print(f">> Saving model to {checkpoint_path} ...") + train_tools.save_on_master(to_save, checkpoint_path) + + +def build_dataset( + dataset, num_workers, test, batch_size=None, max_num_of_imgs_per_gpu=None +): + """ + Builds data loaders for training or testing. + + Args: + dataset: Dataset specification string. + num_workers: Number of worker processes for data loading. + test: Boolean flag indicating whether this is a test dataset. + batch_size: Number of samples per batch. Defaults to None. Used only for testing. + max_num_of_imgs_per_gpu: Maximum number of images per GPU. Defaults to None. Used only for training. + + Returns: + DataLoader: PyTorch DataLoader configured for the specified dataset. + """ + split = ["Train", "Test"][test] + print(f"Building {split} Data loader for dataset: ", dataset) + if test: + assert batch_size is not None, ( + "batch_size must be specified for testing dataloader" + ) + loader = get_test_data_loader( + dataset=dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_mem=True, + shuffle=False, + drop_last=False, + ) + else: + assert max_num_of_imgs_per_gpu is not None, ( + "max_num_of_imgs_per_gpu must be specified for training dataloader" + ) + loader = get_train_data_loader( + dataset=dataset, + max_num_of_imgs_per_gpu=max_num_of_imgs_per_gpu, + num_workers=num_workers, + pin_mem=True, + shuffle=True, + drop_last=True, + ) + + print(f"{split} dataset length: ", len(loader)) + return loader + + +def train_one_epoch( + model: torch.nn.Module, + criterion: torch.nn.Module, + data_loader: Sized, + optimizer: torch.optim.Optimizer, + device: torch.device, + epoch: int, + loss_scaler, + args, + log_writer=None, + param_groups_name_to_idx_map=None, + param_groups_idx_to_name_map=None, + model_without_ddp=None, +): + """ + Trains the model for one epoch. + Epoch is just a chunk of the entire dataset. + + This function handles the training loop for a single epoch, including forward/backward passes, + gradient accumulation, learning rate scheduling, and logging metrics. + + Args: + model: The neural network model to train. + criterion: Loss function to optimize. + data_loader: DataLoader providing the training data. + optimizer: Optimizer for updating model parameters. + device: Device to run training on (CPU or GPU). + epoch: Current epoch number. + loss_scaler: Scaler for gradient accumulation and mixed precision training. + args: Configuration object containing training parameters. + log_writer: Optional; TensorBoard SummaryWriter for logging. + param_groups_name_to_idx_map: Mapping from parameter group names to indices. + param_groups_idx_to_name_map: Mapping from parameter group indices to names. + model_without_ddp: Model without DistributedDataParallel wrapper for debugging. + + Returns: + dict: Dictionary containing training metrics averaged over the epoch. 
+ """ + model.train(True) + metric_logger = train_tools.MetricLogger(delimiter=" ") + for submodule_name in param_groups_name_to_idx_map: + lr_name = f"lr_{submodule_name}" if submodule_name != "default" else "lr" + metric_logger.add_meter( + lr_name, train_tools.SmoothedValue(window_size=1, fmt="{value:.6f}") + ) + header = "Epoch: [{}]".format(epoch) + accum_iter = args.train_params.accum_iter + + if log_writer is not None: + print("log_dir: {}".format(log_writer.log_dir)) + + if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"): + data_loader.dataset.set_epoch(epoch) + if hasattr(data_loader, "sampler") and hasattr(data_loader.sampler, "set_epoch"): + data_loader.sampler.set_epoch(epoch) + if hasattr(data_loader, "batch_sampler") and hasattr( + data_loader.batch_sampler, "set_epoch" + ): + data_loader.batch_sampler.set_epoch(epoch) + + optimizer.zero_grad() + + for data_iter_step, batch in enumerate( + metric_logger.log_every(data_loader, args.train_params.print_freq, header) + ): + n_views = len(batch) + epoch_f = epoch + data_iter_step / len(data_loader) + + # We use a per iteration (instead of per epoch) lr scheduler + if data_iter_step % accum_iter == 0: + train_tools.adjust_learning_rate( + optimizer, + epoch_f, + args.train_params, + param_groups_idx_to_name_map, + args.train_params.submodule_configs, + ) + + loss_tuple = loss_of_one_batch_multi_view( + batch, + model, + criterion, + device, + use_amp=bool(args.train_params.amp), + amp_dtype=args.train_params.amp_dtype, + ret="loss", + ) + loss, loss_details = loss_tuple # criterion returns two values + if n_views > 2: + loss = loss * ( + 2 / n_views + ) # scale the loss relative to the number of views (base is 2 views) + loss_value = float(loss) + + if not math.isfinite(loss_value) or (loss_value > 1000): + print("Loss is {}, stopping training".format(loss_value), force=True) + print(f"Loss Details: {loss_details}", force=True) + print(f"Epoch: {epoch}, Data Iteration: {data_iter_step}", force=True) + # Save the current batch to the output folder for further inspection + for view_idx, view in enumerate(batch): + view_cpu = {} + for k, v in view.items(): + view_cpu[k] = v.cpu() if isinstance(v, torch.Tensor) else v + with open( + os.path.join(args.output_dir, f"batch_view_{view_idx}.pkl"), "wb" + ) as f: + pickle.dump(view_cpu, f) + # Save the model to the output folder for further inspection + checkpoint_debug_path = os.path.join( + args.output_dir, "checkpoint-debug.pth" + ) + to_save_debug = { + "args": args, + "model": ( + model_without_ddp + if isinstance(model_without_ddp, dict) + else model_without_ddp.cpu().state_dict() + ), + "epoch": epoch, + "data_iter_step": data_iter_step, + } + torch.save(to_save_debug, checkpoint_debug_path) + print(f"Saved debugging material to {args.output_dir}", force=True) + sys.exit(1) + + # Scale the loss by the number of gradient accumulation iterations + loss /= accum_iter + + # Compute the scaled gradients (also clip the gradients to max norm of 1) + gradient_norm = loss_scaler( + loss, + optimizer, + parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0, + clip_grad=1.0, + ) + + # Zero out the gradients to prepare for the next iteration of gradient descent + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + del loss + del batch + + metric_logger.update(epoch=epoch_f) + for submodule_name in param_groups_name_to_idx_map: + lr_name = f"lr_{submodule_name}" if submodule_name != "default" else "lr" + log_lr = 
optimizer.param_groups[ + param_groups_name_to_idx_map[submodule_name][0] + ]["lr"] + metric_logger.meters[lr_name].update(log_lr) + metric_logger.update(loss=loss_value, **loss_details) + + if (data_iter_step + 1) % accum_iter == 0 and ( + (data_iter_step + 1) % (accum_iter * args.train_params.print_freq) + ) == 0: + loss_value_reduce = train_tools.all_reduce_mean( + loss_value + ) # MUST BE EXECUTED BY ALL NODES + if log_writer is None: + continue + """ + We use epoch_1000x as the x-axis in tensorboard. + This calibrates different curves when batch size changes. + """ + epoch_1000x = int(epoch_f * 1000) + log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x) + if gradient_norm is not None: + log_writer.add_scalar("train_grad_norm", gradient_norm, epoch_1000x) + for submodule_name in param_groups_name_to_idx_map: + lr_name = ( + f"train_lr_{submodule_name}" + if submodule_name != "default" + else "train_lr" + ) + log_lr = optimizer.param_groups[ + param_groups_name_to_idx_map[submodule_name][0] + ]["lr"] + log_writer.add_scalar(lr_name, log_lr, epoch_1000x) + log_writer.add_scalar("train_iter", epoch_1000x, epoch_1000x) + for name, val in loss_details.items(): + log_writer.add_scalar("train_" + name, val, epoch_1000x) + + # # Gather the stats from all processes + # metric_logger.synchronize_between_processes() + # print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def test_one_epoch( + model: torch.nn.Module, + criterion: torch.nn.Module, + data_loader: Sized, + device: torch.device, + epoch: int, + args, + log_writer=None, + prefix="test", +): + """ + Evaluates the model on a test dataset for one epoch. + Epoch is just a chunk of the entire dataset. + + This function runs evaluation on the test dataset without computing gradients, + and collects metrics for model performance assessment. + + Args: + model: The neural network model to evaluate. + criterion: Loss function for evaluation. + data_loader: DataLoader providing the test data. + device: Device to run evaluation on (CPU or GPU). + epoch: Current epoch number. + args: Configuration object containing evaluation parameters. + log_writer: Optional; TensorBoard SummaryWriter for logging. + prefix: String prefix for logging metrics. + + Returns: + dict: Dictionary containing evaluation metrics (average and median values). 
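A simplified stand-in for the accumulation and clipping pattern used above: the loss is divided by accum_iter, gradients are clipped to a max norm of 1, and the optimizer steps only every accum_iter iterations. It uses the stock torch.cuda.amp.GradScaler rather than the project's NativeScalerWithGradNormCount, which additionally returns the gradient norm for logging; model, data, and hyperparameters are placeholders.

import torch

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled so the sketch also runs on CPU
accum_iter = 4

optimizer.zero_grad()
for step in range(8):
    x, y = torch.randn(2, 16), torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(x), y) / accum_iter  # scale by accumulation steps
    scaler.scale(loss).backward()
    if (step + 1) % accum_iter == 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip to max norm 1
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()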
+ """ + model.eval() + metric_logger = train_tools.MetricLogger(delimiter=" ") + metric_logger.meters = defaultdict( + lambda: train_tools.SmoothedValue(window_size=9**9) + ) + header = "Test Epoch: [{}]".format(epoch) + + if log_writer is not None: + print("log_dir: {}".format(log_writer.log_dir)) + + if args.train_params.freeze_val_samples_across_all_epochs: + dataloader_epoch = 0 + else: + dataloader_epoch = epoch + if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"): + data_loader.dataset.set_epoch(dataloader_epoch) + if hasattr(data_loader, "sampler") and hasattr(data_loader.sampler, "set_epoch"): + data_loader.sampler.set_epoch(dataloader_epoch) + if hasattr(data_loader, "batch_sampler") and hasattr( + data_loader.batch_sampler, "set_epoch" + ): + data_loader.batch_sampler.set_epoch(dataloader_epoch) + + for _, batch in enumerate( + metric_logger.log_every(data_loader, args.train_params.print_freq, header) + ): + n_views = len(batch) + loss_tuple = loss_of_one_batch_multi_view( + batch, + model, + criterion, + device, + use_amp=bool(args.train_params.amp), + amp_dtype=args.train_params.amp_dtype, + ret="loss", + ) + loss_value, loss_details = loss_tuple # criterion returns two values + if n_views > 2: + loss_value = loss_value * ( + 2 / n_views + ) # scale the loss relative to the number of views (base is 2 views) + metric_logger.update(loss=float(loss_value), **loss_details) + + # # Gather the stats from all processes + # metric_logger.synchronize_between_processes() + # print("Averaged stats:", metric_logger) + + aggs = [("avg", "global_avg"), ("med", "median")] + results = { + f"{k}_{tag}": getattr(meter, attr) + for k, meter in metric_logger.meters.items() + for tag, attr in aggs + } + + if log_writer is not None: + for name, val in results.items(): + log_writer.add_scalar(prefix + "_" + name, val, 1000 * epoch) + + return results diff --git a/mapanything/utils/__init__.py b/mapanything/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/utils/__pycache__/__init__.cpython-312.pyc b/mapanything/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37b07eef2465cac9967f2db197a9e1f13df33dae Binary files /dev/null and b/mapanything/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/mapanything/utils/__pycache__/cropping.cpython-312.pyc b/mapanything/utils/__pycache__/cropping.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..561d5b0d3c3296ce4d2144e7ce465905a0db719c Binary files /dev/null and b/mapanything/utils/__pycache__/cropping.cpython-312.pyc differ diff --git a/mapanything/utils/__pycache__/geometry.cpython-312.pyc b/mapanything/utils/__pycache__/geometry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d4e2645f45007db1fd058b3b0d0a0148a419e82 Binary files /dev/null and b/mapanything/utils/__pycache__/geometry.cpython-312.pyc differ diff --git a/mapanything/utils/__pycache__/image.cpython-312.pyc b/mapanything/utils/__pycache__/image.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4e770980ec6ce01e10d5ffe694be00d1832b653 Binary files /dev/null and b/mapanything/utils/__pycache__/image.cpython-312.pyc differ diff --git a/mapanything/utils/__pycache__/inference.cpython-312.pyc b/mapanything/utils/__pycache__/inference.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..115b694644f562cd681ccd9e49d4f9f711b03a34 Binary files /dev/null and b/mapanything/utils/__pycache__/inference.cpython-312.pyc differ diff --git a/mapanything/utils/__pycache__/misc.cpython-312.pyc b/mapanything/utils/__pycache__/misc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17ef08ae32e3dd1cb7d34f52833c28abd8c3c03f Binary files /dev/null and b/mapanything/utils/__pycache__/misc.cpython-312.pyc differ diff --git a/mapanything/utils/__pycache__/warnings.cpython-312.pyc b/mapanything/utils/__pycache__/warnings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b8b9aec877c55d5ece038971032547c7b481f24 Binary files /dev/null and b/mapanything/utils/__pycache__/warnings.cpython-312.pyc differ diff --git a/mapanything/utils/colmap.py b/mapanything/utils/colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..c1965c0f345088dfd7666e4c455d83e9defa7257 --- /dev/null +++ b/mapanything/utils/colmap.py @@ -0,0 +1,662 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. 
Schoenberger (jsch-at-demuc-dot-de) + +""" +COLMAP Utils + +Source: https://github.com/colmap/colmap/blob/master/scripts/python/read_write_model.py (modified by Khiem Vuong) +""" + +import argparse +import collections +import os +import struct + +import numpy as np + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"] +) +Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] +) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] +) + + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), +} +CAMERA_MODEL_IDS = dict( + [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] +) +CAMERA_MODEL_NAMES = dict( + [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] +) + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): + """pack and write to a binary file. + :param fid: + :param data: data to send, if multiple elements are sent at the same time, + they should be encapsuled either in a list or a tuple + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 
+ should be the same length as the data list or tuple + :param endian_character: Any of {@, =, <, >, !} + """ + if isinstance(data, (list, tuple)): + bytes = struct.pack(endian_character + format_char_sequence, *data) + else: + bytes = struct.pack(endian_character + format_char_sequence, data) + fid.write(bytes) + + +def read_cameras_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera( + id=camera_id, model=model, width=width, height=height, params=params + ) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for camera_line_index in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ" + ) + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes( + fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params + ) + cameras[camera_id] = Camera( + id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params), + ) + assert len(cameras) == num_cameras + return cameras + + +def write_cameras_text(cameras, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + HEADER = [ + "# Camera list with one line of data per camera:\n" + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + "# Number of cameras: {}\n".format(len(cameras)) + ] + with open(path, "w") as fid: + fid.write("".join(HEADER)) + for _, cam in cameras.items(): + to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] + line = " ".join([str(elem) for elem in to_write]) + fid.write(line + "\n") + + +def write_cameras_binary(cameras, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(cameras), "Q") + for _, cam in cameras.items(): + model_id = CAMERA_MODEL_NAMES[cam.model].model_id + camera_properties = [cam.id, model_id, cam.width, cam.height] + write_next_bytes(fid, camera_properties, "iiQQ") + for p in cam.params: + write_next_bytes(fid, float(p), "d") + return cameras + + +def read_images_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = 
line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack( + [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))] + ) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for image_index in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi" + ) + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ + 0 + ] + x_y_id_s = read_next_bytes( + fid, + num_bytes=24 * num_points2D, + format_char_sequence="ddq" * num_points2D, + ) + xys = np.column_stack( + [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] + ) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def write_images_text(images, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + if len(images) == 0: + mean_observations = 0 + else: + mean_observations = sum( + (len(img.point3D_ids) for _, img in images.items()) + ) / len(images) + HEADER = [ + "# Image list with two lines of data per image:\n" + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + "# Number of images: {}, mean observations per image: {}\n".format( + len(images), mean_observations + ) + ] + + with open(path, "w") as fid: + fid.write("".join(HEADER)) + for _, img in images.items(): + image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] + first_line = " ".join(map(str, image_header)) + fid.write(first_line + "\n") + + points_strings = [] + for xy, point3D_id in zip(img.xys, img.point3D_ids): + points_strings.append(" ".join(map(str, [*xy, point3D_id]))) + fid.write(" ".join(points_strings) + "\n") + + +def write_images_binary(images, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(images), "Q") + for _, img in images.items(): + write_next_bytes(fid, img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + 
write_next_bytes(fid, img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for xy, p3d_id in zip(img.xys, img.point3D_ids): + write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def read_points3d_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for point_line_index in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd" + ) + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ + 0 + ] + track_elems = read_next_bytes( + fid, + num_bytes=8 * track_length, + format_char_sequence="ii" * track_length, + ) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def write_points3D_text(points3D, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + if len(points3D) == 0: + mean_track_length = 0 + else: + mean_track_length = sum( + (len(pt.image_ids) for _, pt in points3D.items()) + ) / len(points3D) + HEADER = [ + "# 3D point list with one line of data per point:\n" + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + "# Number of points: {}, mean track length: {}\n".format( + len(points3D), mean_track_length + ) + ] + + with open(path, "w") as fid: + fid.write("".join(HEADER)) + for _, pt in points3D.items(): + point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] + fid.write(" ".join(map(str, point_header)) + " ") + track_strings = [] + for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): + track_strings.append(" ".join(map(str, [image_id, point2D]))) + fid.write(" ".join(track_strings) + "\n") + + +def write_points3d_binary(points3D, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + 
void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + + +def read_model(path, ext): + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def write_model(cameras, images, points3D, path, ext): + if ext == ".txt": + write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) + write_images_text(images, os.path.join(path, "images" + ext)) + write_points3D_text(points3D, os.path.join(path, "points3D") + ext) + else: + write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) + write_images_binary(images, os.path.join(path, "images" + ext)) + write_points3d_binary(points3D, os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array( + [ + [ + 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, + ], + ] + ) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = ( + np.array( + [ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ] + ) + / 3.0 + ) + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +def get_camera_matrix(camera_params, camera_model): + """Get camera matrix and distortion coefficients from camera parameters in COLMAP format + Arguments + camera_params - Camera parameters in COLMAP format + camera_model - Camera model + Return + K - [3, 3] Camera matrix + dc - [12,] Distortion coefficients + """ + camera_params = np.asarray(camera_params) + K = np.zeros([3, 3]) + K[2, 2] = 1 + dc = np.zeros( + [ + 12, + ] + ) + if str.upper(camera_model) == "SIMPLE_PINHOLE": + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[0] + K[0, 2] = camera_params[1] + K[1, 2] = camera_params[2] + elif str.upper(camera_model) == "PINHOLE": + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[1] + K[0, 2] = camera_params[2] + K[1, 2] = camera_params[3] + elif str.upper(camera_model) in ("SIMPLE_RADIAL", "SIMPLE_RADIAL_FISHEYE"): + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[0] + K[0, 2] = camera_params[1] + K[1, 2] = camera_params[2] + dc[0] = camera_params[3] + elif str.upper(camera_model) 
in ("RADIAL", "RADIAL_FISHEYE"): + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[0] + K[0, 2] = camera_params[1] + K[1, 2] = camera_params[2] + dc[0:2] = camera_params[3:5] + elif str.upper(camera_model) == "OPENCV": + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[1] + K[0, 2] = camera_params[2] + K[1, 2] = camera_params[3] + dc[0:4] = camera_params[4:8] + elif str.upper(camera_model) == "FULL_OPENCV": + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[1] + K[0, 2] = camera_params[2] + K[1, 2] = camera_params[3] + dc[0:8] = camera_params[4:12] + elif str.upper(camera_model) == "OPENCV_FISHEYE": + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[1] + K[0, 2] = camera_params[2] + K[1, 2] = camera_params[3] + dc[0:2] = camera_params[4:6] + dc[4:6] = camera_params[6:8] + else: + raise ValueError("Unsupported camera model: " + camera_model) + return K, dc + + +def main(): + parser = argparse.ArgumentParser( + description="Read and write COLMAP binary and text models" + ) + parser.add_argument("--input_model", help="path to input model folder") + parser.add_argument( + "--input_format", + choices=[".bin", ".txt"], + help="input model format", + default="", + ) + parser.add_argument("--output_model", help="path to output model folder") + parser.add_argument( + "--output_format", + choices=[".bin", ".txt"], + help="output model format", + default=".txt", + ) + args = parser.parse_args() + + cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model( + cameras, images, points3D, path=args.output_model, ext=args.output_format + ) + + +def main(): # noqa + parser = argparse.ArgumentParser( + description="Read and write COLMAP binary and text models" + ) + parser.add_argument("input_model", help="path to input model folder") + parser.add_argument( + "input_format", choices=[".bin", ".txt"], help="input model format" + ) + parser.add_argument( + "--output_model", metavar="PATH", help="path to output model folder" + ) + parser.add_argument( + "--output_format", + choices=[".bin", ".txt"], + help="output model format", + default=".txt", + ) + args = parser.parse_args() + + cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model( + cameras, images, points3D, path=args.output_model, ext=args.output_format + ) + + +if __name__ == "__main__": + main() diff --git a/mapanything/utils/cropping.py b/mapanything/utils/cropping.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6c011f5b755f1a73907e95bd3f1f5f14dd3e78 --- /dev/null +++ b/mapanything/utils/cropping.py @@ -0,0 +1,467 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Utility functions for cropping and resizing data while maintaining proper cameras. 
+ +References: DUSt3R +""" + +import cv2 +import numpy as np +import PIL.Image + +try: + lanczos = PIL.Image.Resampling.LANCZOS + bicubic = PIL.Image.Resampling.BICUBIC +except AttributeError: + lanczos = PIL.Image.LANCZOS + bicubic = PIL.Image.BICUBIC + +from mapanything.utils.geometry import ( + colmap_to_opencv_intrinsics, + opencv_to_colmap_intrinsics, +) + + +class ImageList: + """ + Convenience class to apply the same operation to a whole set of images. + + This class wraps a list of PIL.Image objects and provides methods to perform + operations on all images simultaneously. + """ + + def __init__(self, images): + if not isinstance(images, (tuple, list, set)): + images = [images] + self.images = [] + for image in images: + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + self.images.append(image) + + def __len__(self): + """Return the number of images in the list.""" + return len(self.images) + + def to_pil(self): + """ + Convert ImageList back to PIL Image(s). + + Returns: + PIL.Image.Image or tuple: Single PIL Image if list contains one image, + or tuple of PIL Images if multiple images + """ + return tuple(self.images) if len(self.images) > 1 else self.images[0] + + @property + def size(self): + """ + Get the size of images in the list. + + Returns: + tuple: (width, height) of the images + + Raises: + AssertionError: If images have different sizes + """ + sizes = [im.size for im in self.images] + assert all(sizes[0] == s for s in sizes), "All images must have the same size" + return sizes[0] + + def resize(self, *args, **kwargs): + """ + Resize all images with the same parameters. + + Args: + *args, **kwargs: Arguments passed to PIL.Image.resize() + + Returns: + ImageList: New ImageList containing resized images + """ + return ImageList(self._dispatch("resize", *args, **kwargs)) + + def crop(self, *args, **kwargs): + """ + Crop all images with the same parameters. + + Args: + *args, **kwargs: Arguments passed to PIL.Image.crop() + + Returns: + ImageList: New ImageList containing cropped images + """ + return ImageList(self._dispatch("crop", *args, **kwargs)) + + def _dispatch(self, func, *args, **kwargs): + """ + Apply a PIL.Image method to all images in the list. + + Args: + func (str): Name of the PIL.Image method to call + *args, **kwargs: Arguments to pass to the method + + Returns: + list: List of results from applying the method to each image + """ + return [getattr(im, func)(*args, **kwargs) for im in self.images] + + +def resize_with_nearest_interpolation_to_match_aspect_ratio(input_data, img_h, img_w): + """ + Resize input map to match the aspect ratio of an image while ensuring + the input resolution never increases beyond the original. + Uses nearest interpolation for resizing. 
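+
+    Example (an illustrative numeric sketch added for clarity, not part of the
+    original docstring; the array below is an arbitrary placeholder):
+        depth = np.zeros((800, 1000), dtype=np.float32)  # (H, W) = (800, 1000)
+        resized, h, w = resize_with_nearest_interpolation_to_match_aspect_ratio(
+            depth, img_h=480, img_w=640
+        )
+        # The 4:3 target aspect ratio is matched by keeping the width fixed and
+        # shrinking the height, so (h, w) == (750, 1000); nothing is upsampled.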
+ + Args: + input_data (np.ndarray): The input map to resize + img_h (int): Height of the target image + img_w (int): Width of the target image + + Returns: + tuple: (resized_input, target_h, target_w) + - resized_input: The resized input map + - target_h: The target height used for resizing + - target_w: The target width used for resizing + """ + # Get the dimensions of the input map + input_h, input_w = input_data.shape[:2] + + # Calculate aspect ratios + img_aspect = img_w / img_h + + # Option 1: Keep input_w fixed and calculate new height + option1_h = int(input_w / img_aspect) + # Option 2: Keep input_h fixed and calculate new width + option2_w = int(input_h * img_aspect) + + # Check if either option would increase a dimension + option1_increases = option1_h > input_h + option2_increases = option2_w > input_w + + if option1_increases and option2_increases: + # Both options would increase a dimension, so we need to scale down both dimensions + # Find the scaling factor that preserves aspect ratio and ensures no dimension increases + scale_h = input_h / img_h + scale_w = input_w / img_w + scale = min(scale_h, scale_w) + + target_input_h = int(img_h * scale) + target_input_w = int(img_w * scale) + elif option1_increases: + # Option 1 would increase height, so use option 2 + target_input_h = input_h + target_input_w = option2_w + elif option2_increases: + # Option 2 would increase width, so use option 1 + target_input_w = input_w + target_input_h = option1_h + else: + # Neither option increases dimensions, choose the one that maintains resolution better + if abs(input_h * input_w - input_w * option1_h) < abs( + input_h * input_w - option2_w * input_h + ): + # Option 1 is better: keep width fixed, adjust height + target_input_w = input_w + target_input_h = option1_h + else: + # Option 2 is better: keep height fixed, adjust width + target_input_h = input_h + target_input_w = option2_w + + # Resize input using nearest interpolation to maintain input values + if target_input_h != input_h or target_input_w != input_w: + resized_input = cv2.resize( + input_data, + (target_input_w, target_input_h), + interpolation=cv2.INTER_NEAREST, + ) + else: + resized_input = input_data + + return resized_input, target_input_h, target_input_w + + +def rescale_image_and_other_optional_info( + image, + output_resolution, + depthmap=None, + camera_intrinsics=None, + force=True, + additional_quantities_to_be_resized_with_nearest=None, +): + """ + Rescale the image and depthmap to the output resolution. + If the image is larger than the output resolution, it is rescaled with lanczos interpolation. + If force is false and the image is smaller than the output resolution, it is not rescaled. + If force is true and the image is smaller than the output resolution, it is rescaled with bicubic interpolation. + Depth and other quantities are rescaled with nearest interpolation. + + Args: + image (PIL.Image.Image or np.ndarray): The input image to be rescaled. + output_resolution (tuple): The desired output resolution as a tuple (width, height). + depthmap (np.ndarray, optional): The depth map associated with the image. Defaults to None. + camera_intrinsics (np.ndarray, optional): The camera intrinsics matrix. Defaults to None. + force (bool, optional): If True, force rescaling even if the image is smaller than the output resolution. Defaults to True. + additional_quantities_to_be_resized_with_nearest (list of np.ndarray, optional): Additional quantities to be rescaled using nearest interpolation. Defaults to None. 
+ + Returns: + tuple: A tuple containing: + - The rescaled image (PIL.Image.Image) + - The rescaled depthmap (numpy.ndarray or None) + - The updated camera intrinsics (numpy.ndarray or None) + - The list of rescaled additional quantities (list of numpy.ndarray or None) + """ + image = ImageList(image) + input_resolution = np.array(image.size) # (W, H) + output_resolution = np.array(output_resolution) + if depthmap is not None: + assert tuple(depthmap.shape[:2]) == image.size[::-1] + if additional_quantities_to_be_resized_with_nearest is not None: + assert all( + tuple(additional_quantity.shape[:2]) == image.size[::-1] + for additional_quantity in additional_quantities_to_be_resized_with_nearest + ) + + # Define output resolution + assert output_resolution.shape == (2,) + scale_final = max(output_resolution / image.size) + 1e-8 + if scale_final >= 1 and not force: # image is already smaller than what is asked + output = ( + image.to_pil(), + depthmap, + camera_intrinsics, + additional_quantities_to_be_resized_with_nearest, + ) + return output + output_resolution = np.floor(input_resolution * scale_final).astype(int) + + # First rescale the image so that it contains the crop + image = image.resize( + tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic + ) + if depthmap is not None: + depthmap = cv2.resize( + depthmap, + output_resolution, + fx=scale_final, + fy=scale_final, + interpolation=cv2.INTER_NEAREST, + ) + if additional_quantities_to_be_resized_with_nearest is not None: + resized_additional_quantities = [] + for quantity in additional_quantities_to_be_resized_with_nearest: + resized_additional_quantities.append( + cv2.resize( + quantity, + output_resolution, + fx=scale_final, + fy=scale_final, + interpolation=cv2.INTER_NEAREST, + ) + ) + additional_quantities_to_be_resized_with_nearest = resized_additional_quantities + + # No offset here; simple rescaling + if camera_intrinsics is not None: + camera_intrinsics = camera_matrix_of_crop( + camera_intrinsics, input_resolution, output_resolution, scaling=scale_final + ) + + # Return + return ( + image.to_pil(), + depthmap, + camera_intrinsics, + additional_quantities_to_be_resized_with_nearest, + ) + + +def camera_matrix_of_crop( + input_camera_matrix, + input_resolution, + output_resolution, + scaling=1, + offset_factor=0.5, + offset=None, +): + """ + Calculate the camera matrix for a cropped image. + + Args: + input_camera_matrix (numpy.ndarray): Original camera intrinsics matrix + input_resolution (tuple or numpy.ndarray): Original image resolution as (width, height) + output_resolution (tuple or numpy.ndarray): Target image resolution as (width, height) + scaling (float, optional): Scaling factor for the image. Defaults to 1. + offset_factor (float, optional): Factor to determine crop offset. Defaults to 0.5 (centered). + offset (tuple or numpy.ndarray, optional): Explicit offset to use. If None, calculated from offset_factor. 
+ + Returns: + numpy.ndarray: Updated camera matrix for the cropped image + """ + # Margins to offset the origin + margins = np.asarray(input_resolution) * scaling - output_resolution + assert np.all(margins >= 0.0) + if offset is None: + offset = offset_factor * margins + + # Generate new camera parameters + output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) + output_camera_matrix_colmap[:2, :] *= scaling + output_camera_matrix_colmap[:2, 2] -= offset + output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) + + return output_camera_matrix + + +def crop_image_and_other_optional_info( + image, + crop_bbox, + depthmap=None, + camera_intrinsics=None, + additional_quantities=None, +): + """ + Return a crop of the input view and associated data. + + Args: + image (PIL.Image.Image or numpy.ndarray): The input image to be cropped + crop_bbox (tuple): Crop bounding box as (left, top, right, bottom) + depthmap (numpy.ndarray, optional): Depth map associated with the image + camera_intrinsics (numpy.ndarray, optional): Camera intrinsics matrix + additional_quantities (list of numpy.ndarray, optional): Additional data arrays to crop + + Returns: + tuple: A tuple containing: + - The cropped image + - The cropped depth map (if provided or None) + - Updated camera intrinsics (if provided or None) + - List of cropped additional quantities (if provided or None) + """ + image = ImageList(image) + left, top, right, bottom = crop_bbox + + image = image.crop((left, top, right, bottom)) + if depthmap is not None: + depthmap = depthmap[top:bottom, left:right] + if additional_quantities is not None: + additional_quantities = [ + quantity[top:bottom, left:right] for quantity in additional_quantities + ] + + if camera_intrinsics is not None: + camera_intrinsics = camera_intrinsics.copy() + camera_intrinsics[0, 2] -= left + camera_intrinsics[1, 2] -= top + + return (image.to_pil(), depthmap, camera_intrinsics, additional_quantities) + + +def bbox_from_intrinsics_in_out( + input_camera_matrix, output_camera_matrix, output_resolution +): + """ + Calculate the bounding box for cropping based on input and output camera intrinsics. + + Args: + input_camera_matrix (numpy.ndarray): Original camera intrinsics matrix + output_camera_matrix (numpy.ndarray): Target camera intrinsics matrix + output_resolution (tuple): Target resolution as (width, height) + + Returns: + tuple: Crop bounding box as (left, top, right, bottom) + """ + out_width, out_height = output_resolution + left, top = np.int32( + np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]) + ) + crop_bbox = (left, top, left + out_width, top + out_height) + return crop_bbox + + +def crop_resize_if_necessary( + image, + resolution, + depthmap=None, + intrinsics=None, + additional_quantities=None, +): + """ + First downsample image using LANCZOS and then crop if necessary to achieve target resolution. + + This function performs high-quality downsampling followed by cropping to achieve the + desired output resolution while maintaining proper camera intrinsics. 
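+
+    Example (an illustrative sketch; the file name, frame size and intrinsics are
+    hypothetical placeholders):
+        img = PIL.Image.open("frame.png")  # e.g. a 1920x1080 frame
+        K = np.array([[1000.0, 0.0, 960.0],
+                      [0.0, 1000.0, 540.0],
+                      [0.0, 0.0, 1.0]])
+        out_img, out_K = crop_resize_if_necessary(img, resolution=(640, 480), intrinsics=K)
+        # out_img is a 640x480 PIL image and out_K is the intrinsics matrix updated
+        # for the Lanczos rescale followed by the centered crop.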
+ + Args: + image (PIL.Image.Image or numpy.ndarray): The input image to be processed + resolution (tuple): Target resolution as (width, height) + depthmap (numpy.ndarray, optional): Depth map associated with the image + intrinsics (numpy.ndarray, optional): Camera intrinsics matrix + additional_quantities (list of numpy.ndarray, optional): Additional data arrays to process + + Returns: + tuple: A tuple containing the processed image and any provided additional data + (depthmap, intrinsics, additional_quantities) that have been similarly processed + """ + # Convert image to PIL.Image.Image if necessary + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # Get width and height of image + original_width, original_height = image.size + + # High-quality Lanczos down-scaling + target_rescale_resolution = np.array(resolution) + image, depthmap, intrinsics, additional_quantities = ( + rescale_image_and_other_optional_info( + image=image, + output_resolution=target_rescale_resolution, + depthmap=depthmap, + camera_intrinsics=intrinsics, + additional_quantities_to_be_resized_with_nearest=additional_quantities, + ) + ) + + # Actual cropping (if necessary) + if intrinsics is not None: + new_intrinsics = camera_matrix_of_crop( + input_camera_matrix=intrinsics, + input_resolution=image.size, + output_resolution=resolution, + offset_factor=0.5, + ) + crop_bbox = bbox_from_intrinsics_in_out( + input_camera_matrix=intrinsics, + output_camera_matrix=new_intrinsics, + output_resolution=resolution, + ) + else: + # Create a centered crop if no intrinsics are available + w, h = image.size + target_w, target_h = resolution + left = (w - target_w) // 2 + top = (h - target_h) // 2 + crop_bbox = (left, top, left + target_w, top + target_h) + + image, depthmap, new_intrinsics, additional_quantities = ( + crop_image_and_other_optional_info( + image=image, + crop_bbox=crop_bbox, + depthmap=depthmap, + camera_intrinsics=intrinsics, + additional_quantities=additional_quantities, + ) + ) + + # Return the output + output = (image,) + if depthmap is not None: + output += (depthmap,) + if new_intrinsics is not None: + output += (new_intrinsics,) + if additional_quantities is not None: + output += (additional_quantities,) + return output diff --git a/mapanything/utils/device.py b/mapanything/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..d28180666c386af9b675fd8dc29be28fe42d9f9f --- /dev/null +++ b/mapanything/utils/device.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Utility functions for managing computation device +""" + +import numpy as np +import torch + + +def to_device(batch, device, callback=None, non_blocking=False): + """ + Transfer data to another device (i.e. GPU, CPU:torch, CPU:numpy). + + This function recursively processes nested data structures (lists, tuples, dicts) + and transfers each tensor to the specified device. 
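+
+    Example (an illustrative sketch; assumes a CUDA device is available):
+        batch = {"img": torch.rand(2, 3, 224, 224), "meta": {"ids": torch.arange(2)}}
+        gpu_batch = to_device(batch, "cuda", non_blocking=True)  # nested structure preserved
+        np_batch = to_device(gpu_batch, "numpy")                 # tensors become numpy arrays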
+ + Args: + batch: Data to transfer (list, tuple, dict of tensors or other objects) + device: Target device - pytorch device (e.g., 'cuda', 'cpu') or 'numpy' + callback: Optional function that would be called on every element before processing + non_blocking: If True, allows asynchronous copy to GPU (may be faster) + + Returns: + Data with the same structure as input but with tensors transferred to target device + """ + if callback: + batch = callback(batch) + + if isinstance(batch, dict): + return { + k: to_device(v, device, non_blocking=non_blocking) for k, v in batch.items() + } + + if isinstance(batch, (tuple, list)): + return type(batch)( + to_device(x, device, non_blocking=non_blocking) for x in batch + ) + + x = batch + if device == "numpy": + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + elif x is not None: + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if torch.is_tensor(x): + x = x.to(device, non_blocking=non_blocking) + return x + + +def to_numpy(x): + """Convert data to numpy arrays. + + Args: + x: Input data (can be tensor, array, or nested structure) + + Returns: + Data with the same structure but with tensors converted to numpy arrays + """ + return to_device(x, "numpy") + + +def to_cpu(x): + """Transfer data to CPU. + + Args: + x: Input data (can be tensor, array, or nested structure) + + Returns: + Data with the same structure but with tensors moved to CPU + """ + return to_device(x, "cpu") + + +def to_cuda(x): + """Transfer data to CUDA device (GPU). + + Args: + x: Input data (can be tensor, array, or nested structure) + + Returns: + Data with the same structure but with tensors moved to GPU + """ + return to_device(x, "cuda") diff --git a/mapanything/utils/geometry.py b/mapanything/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..9452a13e7a9f97ef59cd4ae48f8d271bbac38792 --- /dev/null +++ b/mapanything/utils/geometry.py @@ -0,0 +1,2188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Utilities for geometry operations. + +References: DUSt3R, MoGe +""" + +from numbers import Number +from typing import Tuple, Union + +import einops as ein +import numpy as np +import torch +import torch.nn.functional as F + +from mapanything.utils.misc import invalid_to_zeros +from mapanything.utils.warnings import no_warnings + + +def depthmap_to_camera_frame(depthmap, intrinsics): + """ + Convert depth image to a pointcloud in camera frame. + + Args: + - depthmap: HxW or BxHxW torch tensor + - intrinsics: 3x3 or Bx3x3 torch tensor + + Returns: + pointmap in camera frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels. 
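+
+    Example (an illustrative sketch; the depth values and intrinsics below are
+    arbitrary placeholders):
+        depth = torch.ones(480, 640)
+        K = torch.tensor([[500.0, 0.0, 320.0],
+                          [0.0, 500.0, 240.0],
+                          [0.0, 0.0, 1.0]])
+        pts_cam, valid = depthmap_to_camera_frame(depth, K)
+        # pts_cam has shape (480, 640, 3) and valid is True wherever depth > 0.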
+ """ + # Add batch dimension if not present + if depthmap.dim() == 2: + depthmap = depthmap.unsqueeze(0) + intrinsics = intrinsics.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + batch_size, height, width = depthmap.shape + device = depthmap.device + + # Compute 3D point in camera frame associated with each pixel + x_grid, y_grid = torch.meshgrid( + torch.arange(width, device=device).float(), + torch.arange(height, device=device).float(), + indexing="xy", + ) + x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1) + y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1) + + fx = intrinsics[:, 0, 0].view(-1, 1, 1) + fy = intrinsics[:, 1, 1].view(-1, 1, 1) + cx = intrinsics[:, 0, 2].view(-1, 1, 1) + cy = intrinsics[:, 1, 2].view(-1, 1, 1) + + depth_z = depthmap + xx = (x_grid - cx) * depth_z / fx + yy = (y_grid - cy) * depth_z / fy + pts3d_cam = torch.stack((xx, yy, depth_z), dim=-1) + + # Compute mask of valid non-zero depth pixels + valid_mask = depthmap > 0.0 + + # Remove batch dimension if it was added + if squeeze_batch_dim: + pts3d_cam = pts3d_cam.squeeze(0) + valid_mask = valid_mask.squeeze(0) + + return pts3d_cam, valid_mask + + +def depthmap_to_world_frame(depthmap, intrinsics, camera_pose=None): + """ + Convert depth image to a pointcloud in world frame. + + Args: + - depthmap: HxW or BxHxW torch tensor + - intrinsics: 3x3 or Bx3x3 torch tensor + - camera_pose: 4x4 or Bx4x4 torch tensor + + Returns: + pointmap in world frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels. + """ + pts3d_cam, valid_mask = depthmap_to_camera_frame(depthmap, intrinsics) + + if camera_pose is not None: + # Add batch dimension if not present + if camera_pose.dim() == 2: + camera_pose = camera_pose.unsqueeze(0) + pts3d_cam = pts3d_cam.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Convert points from camera frame to world frame + pts3d_cam_homo = torch.cat( + [pts3d_cam, torch.ones_like(pts3d_cam[..., :1])], dim=-1 + ) + pts3d_world = ein.einsum( + camera_pose, pts3d_cam_homo, "b i k, b h w k -> b h w i" + ) + pts3d_world = pts3d_world[..., :3] + + # Remove batch dimension if it was added + if squeeze_batch_dim: + pts3d_world = pts3d_world.squeeze(0) + else: + pts3d_world = pts3d_cam + + return pts3d_world, valid_mask + + +def transform_pts3d(pts3d, transformation): + """ + Transform 3D points using a 4x4 transformation matrix. + + Args: + - pts3d: HxWx3 or BxHxWx3 torch tensor + - transformation: 4x4 or Bx4x4 torch tensor + + Returns: + transformed points (HxWx3 or BxHxWx3 tensor) + """ + # Add batch dimension if not present + if pts3d.dim() == 3: + pts3d = pts3d.unsqueeze(0) + transformation = transformation.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Convert points to homogeneous coordinates + pts3d_homo = torch.cat([pts3d, torch.ones_like(pts3d[..., :1])], dim=-1) + + # Transform points + transformed_pts3d = ein.einsum( + transformation, pts3d_homo, "b i k, b h w k -> b h w i" + ) + transformed_pts3d = transformed_pts3d[..., :3] + + # Remove batch dimension if it was added + if squeeze_batch_dim: + transformed_pts3d = transformed_pts3d.squeeze(0) + + return transformed_pts3d + + +def project_pts3d_to_image(pts3d, intrinsics, return_z_dim): + """ + Project 3D points to image plane (assumes pinhole camera model with no distortion). 
+ + Args: + - pts3d: HxWx3 or BxHxWx3 torch tensor + - intrinsics: 3x3 or Bx3x3 torch tensor + - return_z_dim: bool, whether to return the third dimension of the projected points + + Returns: + projected points (HxWx2) + """ + if pts3d.dim() == 3: + pts3d = pts3d.unsqueeze(0) + intrinsics = intrinsics.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Project points to image plane + projected_pts2d = ein.einsum(intrinsics, pts3d, "b i k, b h w k -> b h w i") + projected_pts2d[..., :2] /= projected_pts2d[..., 2].unsqueeze(-1).clamp(min=1e-6) + + # Remove the z dimension if not required + if not return_z_dim: + projected_pts2d = projected_pts2d[..., :2] + + # Remove batch dimension if it was added + if squeeze_batch_dim: + projected_pts2d = projected_pts2d.squeeze(0) + + return projected_pts2d + + +def get_rays_in_camera_frame(intrinsics, height, width, normalize_to_unit_sphere): + """ + Convert camera intrinsics to a raymap (ray origins + directions) in camera frame. + Note: Currently only supports pinhole camera model. + + Args: + - intrinsics: 3x3 or Bx3x3 torch tensor + - height: int + - width: int + - normalize_to_unit_sphere: bool + + Returns: + - ray_origins: (HxWx3 or BxHxWx3) tensor + - ray_directions: (HxWx3 or BxHxWx3) tensor + """ + # Add batch dimension if not present + if intrinsics.dim() == 2: + intrinsics = intrinsics.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + batch_size = intrinsics.shape[0] + device = intrinsics.device + + # Compute rays in camera frame associated with each pixel + x_grid, y_grid = torch.meshgrid( + torch.arange(width, device=device).float(), + torch.arange(height, device=device).float(), + indexing="xy", + ) + x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1) + y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1) + + fx = intrinsics[:, 0, 0].view(-1, 1, 1) + fy = intrinsics[:, 1, 1].view(-1, 1, 1) + cx = intrinsics[:, 0, 2].view(-1, 1, 1) + cy = intrinsics[:, 1, 2].view(-1, 1, 1) + + ray_origins = torch.zeros((batch_size, height, width, 3), device=device) + xx = (x_grid - cx) / fx + yy = (y_grid - cy) / fy + ray_directions = torch.stack((xx, yy, torch.ones_like(xx)), dim=-1) + + # Normalize ray directions to unit sphere if required (else rays will lie on unit plane) + if normalize_to_unit_sphere: + ray_directions = ray_directions / torch.norm( + ray_directions, dim=-1, keepdim=True + ) + + # Remove batch dimension if it was added + if squeeze_batch_dim: + ray_origins = ray_origins.squeeze(0) + ray_directions = ray_directions.squeeze(0) + + return ray_origins, ray_directions + + +def get_rays_in_world_frame( + intrinsics, height, width, normalize_to_unit_sphere, camera_pose=None +): + """ + Convert camera intrinsics & camera_pose (if provided) to a raymap (ray origins + directions) in camera or world frame (if camera_pose is provided). + Note: Currently only supports pinhole camera model. 
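+
+    Example (an illustrative sketch with an identity pose, so the world frame
+    coincides with the camera frame; the intrinsics are arbitrary placeholders):
+        K = torch.tensor([[500.0, 0.0, 320.0],
+                          [0.0, 500.0, 240.0],
+                          [0.0, 0.0, 1.0]])
+        origins, dirs = get_rays_in_world_frame(
+            K, height=480, width=640, normalize_to_unit_sphere=True,
+            camera_pose=torch.eye(4),
+        )
+        # origins is all zeros, dirs has unit norm, and both have shape (480, 640, 3).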
+ + Args: + - intrinsics: 3x3 or Bx3x3 torch tensor + - height: int + - width: int + - normalize_to_unit_sphere: bool + - camera_pose: 4x4 or Bx4x4 torch tensor + + Returns: + - ray_origins: (HxWx3 or BxHxWx3) tensor + - ray_directions: (HxWx3 or BxHxWx3) tensor + """ + # Get rays in camera frame + ray_origins, ray_directions = get_rays_in_camera_frame( + intrinsics, height, width, normalize_to_unit_sphere + ) + + if camera_pose is not None: + # Add batch dimension if not present + if camera_pose.dim() == 2: + camera_pose = camera_pose.unsqueeze(0) + ray_origins = ray_origins.unsqueeze(0) + ray_directions = ray_directions.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Convert rays from camera frame to world frame + ray_origins_homo = torch.cat( + [ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1 + ) + ray_directions_homo = torch.cat( + [ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1 + ) + ray_origins_world = ein.einsum( + camera_pose, ray_origins_homo, "b i k, b h w k -> b h w i" + ) + ray_directions_world = ein.einsum( + camera_pose, ray_directions_homo, "b i k, b h w k -> b h w i" + ) + ray_origins_world = ray_origins_world[..., :3] + ray_directions_world = ray_directions_world[..., :3] + + # Remove batch dimension if it was added + if squeeze_batch_dim: + ray_origins_world = ray_origins_world.squeeze(0) + ray_directions_world = ray_directions_world.squeeze(0) + else: + ray_origins_world = ray_origins + ray_directions_world = ray_directions + + return ray_origins_world, ray_directions_world + + +def recover_pinhole_intrinsics_from_ray_directions( + ray_directions, use_geometric_calculation=False +): + """ + Recover pinhole camera intrinsics from ray directions, supporting both batched and non-batched inputs. 
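+
+    Example (an illustrative round-trip sketch; the intrinsics are arbitrary
+    placeholders):
+        K = torch.tensor([[500.0, 0.0, 320.0],
+                          [0.0, 500.0, 240.0],
+                          [0.0, 0.0, 1.0]])
+        _, dirs = get_rays_in_camera_frame(K, 480, 640, normalize_to_unit_sphere=True)
+        K_est = recover_pinhole_intrinsics_from_ray_directions(dirs)
+        # K_est is a 3x3 intrinsics matrix that should closely match K, since the
+        # fit only depends on the dx/dz and dy/dz ratios of the ray directions.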
+ + Args: + ray_directions: Tensor of shape [H, W, 3] or [B, H, W, 3] containing unit normalized ray directions + + Returns: + Dictionary containing camera intrinsics (fx, fy, cx, cy) as tensors + """ + # Add batch dimension if not present + if ray_directions.dim() == 3: # [H, W, 3] + ray_directions = ray_directions.unsqueeze(0) # [1, H, W, 3] + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + batch_size, height, width, _ = ray_directions.shape + device = ray_directions.device + + # Create pixel coordinate grid + x_grid, y_grid = torch.meshgrid( + torch.arange(width, device=device).float(), + torch.arange(height, device=device).float(), + indexing="xy", + ) + + # Expand grid for all batches + x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1) # [B, H, W] + y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1) # [B, H, W] + + # Determine if high resolution or not + is_high_res = height * width > 1000000 + + if is_high_res or use_geometric_calculation: + # For high-resolution cases, use direct geometric calculation + # Define key points + center_h, center_w = height // 2, width // 2 + quarter_w, three_quarter_w = width // 4, 3 * width // 4 + quarter_h, three_quarter_h = height // 4, 3 * height // 4 + + # Get rays at key points + center_rays = ray_directions[:, center_h, center_w, :].clone() # [B, 3] + left_rays = ray_directions[:, center_h, quarter_w, :].clone() # [B, 3] + right_rays = ray_directions[:, center_h, three_quarter_w, :].clone() # [B, 3] + top_rays = ray_directions[:, quarter_h, center_w, :].clone() # [B, 3] + bottom_rays = ray_directions[:, three_quarter_h, center_w, :].clone() # [B, 3] + + # Normalize rays to have dz = 1 + center_rays = center_rays / center_rays[:, 2].unsqueeze(1) # [B, 3] + left_rays = left_rays / left_rays[:, 2].unsqueeze(1) # [B, 3] + right_rays = right_rays / right_rays[:, 2].unsqueeze(1) # [B, 3] + top_rays = top_rays / top_rays[:, 2].unsqueeze(1) # [B, 3] + bottom_rays = bottom_rays / bottom_rays[:, 2].unsqueeze(1) # [B, 3] + + # Calculate fx directly (vectorized across batch) + fx_left = (quarter_w - center_w) / (left_rays[:, 0] - center_rays[:, 0]) + fx_right = (three_quarter_w - center_w) / (right_rays[:, 0] - center_rays[:, 0]) + fx = (fx_left + fx_right) / 2 # Average for robustness + + # Calculate cx + cx = center_w - fx * center_rays[:, 0] + + # Calculate fy and cy + fy_top = (quarter_h - center_h) / (top_rays[:, 1] - center_rays[:, 1]) + fy_bottom = (three_quarter_h - center_h) / ( + bottom_rays[:, 1] - center_rays[:, 1] + ) + fy = (fy_top + fy_bottom) / 2 + + cy = center_h - fy * center_rays[:, 1] + else: + # For standard resolution, use regression with sampling for efficiency + # Sample a grid of points (but more dense than for high-res) + step_h = max(1, height // 50) + step_w = max(1, width // 50) + + h_indices = torch.arange(0, height, step_h, device=device) + w_indices = torch.arange(0, width, step_w, device=device) + + # Extract subset of coordinates + x_sampled = x_grid[:, h_indices[:, None], w_indices[None, :]] # [B, H', W'] + y_sampled = y_grid[:, h_indices[:, None], w_indices[None, :]] # [B, H', W'] + rays_sampled = ray_directions[ + :, h_indices[:, None], w_indices[None, :], : + ] # [B, H', W', 3] + + # Reshape for linear regression + x_flat = x_sampled.reshape(batch_size, -1) # [B, N] + y_flat = y_sampled.reshape(batch_size, -1) # [B, N] + + # Extract ray direction components + dx = rays_sampled[..., 0].reshape(batch_size, -1) # [B, N] + dy = rays_sampled[..., 1].reshape(batch_size, -1) # [B, N] + dz = 
rays_sampled[..., 2].reshape(batch_size, -1) # [B, N] + + # Compute ratios for linear regression + ratio_x = dx / dz # [B, N] + ratio_y = dy / dz # [B, N] + + # Since torch.linalg.lstsq doesn't support batched input, we'll use a different approach + # For x-direction: x = cx + fx * (dx/dz) + # We can solve this using normal equations: A^T A x = A^T b + # Create design matrices + ones = torch.ones_like(x_flat) # [B, N] + A_x = torch.stack([ones, ratio_x], dim=2) # [B, N, 2] + b_x = x_flat.unsqueeze(2) # [B, N, 1] + + # Compute A^T A and A^T b for each batch + ATA_x = torch.bmm(A_x.transpose(1, 2), A_x) # [B, 2, 2] + ATb_x = torch.bmm(A_x.transpose(1, 2), b_x) # [B, 2, 1] + + # Solve the system for each batch + solution_x = torch.linalg.solve(ATA_x, ATb_x).squeeze(2) # [B, 2] + cx, fx = solution_x[:, 0], solution_x[:, 1] + + # Repeat for y-direction + A_y = torch.stack([ones, ratio_y], dim=2) # [B, N, 2] + b_y = y_flat.unsqueeze(2) # [B, N, 1] + + ATA_y = torch.bmm(A_y.transpose(1, 2), A_y) # [B, 2, 2] + ATb_y = torch.bmm(A_y.transpose(1, 2), b_y) # [B, 2, 1] + + solution_y = torch.linalg.solve(ATA_y, ATb_y).squeeze(2) # [B, 2] + cy, fy = solution_y[:, 0], solution_y[:, 1] + + # Create intrinsics matrices + batch_size = fx.shape[0] + intrinsics = torch.zeros(batch_size, 3, 3, device=ray_directions.device) + + # Fill in the intrinsics matrices + intrinsics[:, 0, 0] = fx # focal length x + intrinsics[:, 1, 1] = fy # focal length y + intrinsics[:, 0, 2] = cx # principal point x + intrinsics[:, 1, 2] = cy # principal point y + intrinsics[:, 2, 2] = 1.0 # bottom-right element is always 1 + + # Remove batch dimension if it was added + if squeeze_batch_dim: + intrinsics = intrinsics.squeeze(0) + + return intrinsics + + +def transform_rays(ray_origins, ray_directions, transformation): + """ + Transform 6D rays (ray origins and ray directions) using a 4x4 transformation matrix. + + Args: + - ray_origins: HxWx3 or BxHxWx3 torch tensor + - ray_directions: HxWx3 or BxHxWx3 torch tensor + - transformation: 4x4 or Bx4x4 torch tensor + - normalize_to_unit_sphere: bool, whether to normalize the transformed ray directions to unit length + + Returns: + transformed ray_origins (HxWx3 or BxHxWx3 tensor) and ray_directions (HxWx3 or BxHxWx3 tensor) + """ + # Add batch dimension if not present + if ray_origins.dim() == 3: + ray_origins = ray_origins.unsqueeze(0) + ray_directions = ray_directions.unsqueeze(0) + transformation = transformation.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Transform ray origins and directions + ray_origins_homo = torch.cat( + [ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1 + ) + ray_directions_homo = torch.cat( + [ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1 + ) + transformed_ray_origins = ein.einsum( + transformation, ray_origins_homo, "b i k, b h w k -> b h w i" + ) + transformed_ray_directions = ein.einsum( + transformation, ray_directions_homo, "b i k, b h w k -> b h w i" + ) + transformed_ray_origins = transformed_ray_origins[..., :3] + transformed_ray_directions = transformed_ray_directions[..., :3] + + # Remove batch dimension if it was added + if squeeze_batch_dim: + transformed_ray_origins = transformed_ray_origins.squeeze(0) + transformed_ray_directions = transformed_ray_directions.squeeze(0) + + return transformed_ray_origins, transformed_ray_directions + + +def convert_z_depth_to_depth_along_ray(z_depth, intrinsics): + """ + Convert z-depth image to depth along camera rays. 
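+
+    Example (an illustrative sketch; values are arbitrary placeholders):
+        z = torch.full((480, 640), 2.0)
+        K = torch.tensor([[500.0, 0.0, 320.0],
+                          [0.0, 500.0, 240.0],
+                          [0.0, 0.0, 1.0]])
+        d = convert_z_depth_to_depth_along_ray(z, K)
+        # For a pixel whose unit-plane ray is (x', y', 1), d = z * sqrt(x'**2 + y'**2 + 1),
+        # so d >= 2.0 everywhere with equality at the principal point (320, 240).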
+ + Args: + - z_depth: HxW or BxHxW torch tensor + - intrinsics: 3x3 or Bx3x3 torch tensor + + Returns: + - depth_along_ray: HxW or BxHxW torch tensor + """ + # Add batch dimension if not present + if z_depth.dim() == 2: + z_depth = z_depth.unsqueeze(0) + intrinsics = intrinsics.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Get rays in camera frame + batch_size, height, width = z_depth.shape + _, ray_directions = get_rays_in_camera_frame( + intrinsics, height, width, normalize_to_unit_sphere=False + ) + + # Compute depth along ray + pts3d_cam = z_depth[..., None] * ray_directions + depth_along_ray = torch.norm(pts3d_cam, dim=-1) + + # Remove batch dimension if it was added + if squeeze_batch_dim: + depth_along_ray = depth_along_ray.squeeze(0) + + return depth_along_ray + + +def convert_raymap_z_depth_quats_to_pointmap(ray_origins, ray_directions, depth, quats): + """ + Convert raymap (ray origins + directions on unit plane), z-depth and + unit quaternions (representing rotation) to a pointmap in world frame. + + Args: + - ray_origins: (HxWx3 or BxHxWx3) torch tensor + - ray_directions: (HxWx3 or BxHxWx3) torch tensor + - depth: (HxWx1 or BxHxWx1) torch tensor + - quats: (HxWx4 or BxHxWx4) torch tensor (unit quaternions and notation is (x, y, z, w)) + + Returns: + - pointmap: (HxWx3 or BxHxWx3) torch tensor + """ + # Add batch dimension if not present + if ray_origins.dim() == 3: + ray_origins = ray_origins.unsqueeze(0) + ray_directions = ray_directions.unsqueeze(0) + depth = depth.unsqueeze(0) + quats = quats.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + batch_size, height, width, _ = depth.shape + device = depth.device + + # Normalize the quaternions to ensure they are unit quaternions + quats = quats / torch.norm(quats, dim=-1, keepdim=True) + + # Convert quaternions to pixel-wise rotation matrices + qx, qy, qz, qw = quats[..., 0], quats[..., 1], quats[..., 2], quats[..., 3] + rot_mat = ( + torch.stack( + [ + qw**2 + qx**2 - qy**2 - qz**2, + 2 * (qx * qy - qw * qz), + 2 * (qw * qy + qx * qz), + 2 * (qw * qz + qx * qy), + qw**2 - qx**2 + qy**2 - qz**2, + 2 * (qy * qz - qw * qx), + 2 * (qx * qz - qw * qy), + 2 * (qw * qx + qy * qz), + qw**2 - qx**2 - qy**2 + qz**2, + ], + dim=-1, + ) + .reshape(batch_size, height, width, 3, 3) + .to(device) + ) + + # Compute 3D points in local camera frame + pts3d_local = depth * ray_directions + + # Rotate the local points using the quaternions + rotated_pts3d_local = ein.einsum( + rot_mat, pts3d_local, "b h w i k, b h w k -> b h w i" + ) + + # Compute 3D point in world frame associated with each pixel + pts3d = ray_origins + rotated_pts3d_local + + # Remove batch dimension if it was added + if squeeze_batch_dim: + pts3d = pts3d.squeeze(0) + + return pts3d + + +def quaternion_to_rotation_matrix(quat): + """ + Convert a quaternion into a 3x3 rotation matrix. 
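+
+    Example (an illustrative sketch; the quaternion convention is (x, y, z, w)):
+        q = torch.tensor([0.0, 0.0, 0.0, 1.0])     # identity rotation
+        R = quaternion_to_rotation_matrix(q)        # 3x3 identity matrix
+        q_back = rotation_matrix_to_quaternion(R)   # recovers (0, 0, 0, 1)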
+ + Args: + - quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + + Returns: + - rot_matrix: 3x3 or Bx3x3 torch tensor + """ + if quat.dim() == 1: + quat = quat.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Ensure the quaternion is normalized + quat = quat / quat.norm(dim=1, keepdim=True) + x, y, z, w = quat.unbind(dim=1) + + # Compute the rotation matrix elements + xx = x * x + yy = y * y + zz = z * z + xy = x * y + xz = x * z + yz = y * z + wx = w * x + wy = w * y + wz = w * z + + # Construct the rotation matrix + rot_matrix = torch.stack( + [ + 1 - 2 * (yy + zz), + 2 * (xy - wz), + 2 * (xz + wy), + 2 * (xy + wz), + 1 - 2 * (xx + zz), + 2 * (yz - wx), + 2 * (xz - wy), + 2 * (yz + wx), + 1 - 2 * (xx + yy), + ], + dim=1, + ).view(-1, 3, 3) + + # Squeeze batch dimension if it was unsqueezed + if squeeze_batch_dim: + rot_matrix = rot_matrix.squeeze(0) + + return rot_matrix + + +def rotation_matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). 
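+
+    Example (an illustrative sketch): q and -q encode the same rotation, so
+        standardize_quaternion(torch.tensor([0.5, -0.5, 0.5, -0.5]))
+    returns tensor([-0.5, 0.5, -0.5, 0.5]), i.e. the representative whose real
+    (w) part is non-negative.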
+ """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) + + +def quaternion_inverse(quat): + """ + Compute the inverse of a quaternion. + + Args: + - quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + + Returns: + - inv_quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + """ + # Unsqueeze batch dimension if not present + if quat.dim() == 1: + quat = quat.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Compute the inverse + quat_conj = quat.clone() + quat_conj[:, :3] = -quat_conj[:, :3] + quat_norm = torch.sum(quat * quat, dim=1, keepdim=True) + inv_quat = quat_conj / quat_norm + + # Squeeze batch dimension if it was unsqueezed + if squeeze_batch_dim: + inv_quat = inv_quat.squeeze(0) + + return inv_quat + + +def quaternion_multiply(q1, q2): + """ + Multiply two quaternions. + + Args: + - q1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + - q2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + + Returns: + - qm: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + """ + # Unsqueeze batch dimension if not present + if q1.dim() == 1: + q1 = q1.unsqueeze(0) + q2 = q2.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Unbind the quaternions + x1, y1, z1, w1 = q1.unbind(dim=1) + x2, y2, z2, w2 = q2.unbind(dim=1) + + # Compute the product + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 + z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + + # Stack the components + qm = torch.stack([x, y, z, w], dim=1) + + # Squeeze batch dimension if it was unsqueezed + if squeeze_batch_dim: + qm = qm.squeeze(0) + + return qm + + +def transform_pose_using_quats_and_trans_2_to_1(quats1, trans1, quats2, trans2): + """ + Transform quats and translation of pose2 from absolute frame (pose2 to world) to relative frame (pose2 to pose1). + + Args: + - quats1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + - trans1: 3 or Bx3 torch tensor + - quats2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + - trans2: 3 or Bx3 torch tensor + + Returns: + - quats: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + - trans: 3 or Bx3 torch tensor + """ + # Unsqueeze batch dimension if not present + if quats1.dim() == 1: + quats1 = quats1.unsqueeze(0) + trans1 = trans1.unsqueeze(0) + quats2 = quats2.unsqueeze(0) + trans2 = trans2.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Compute the inverse of view1's pose + inv_quats1 = quaternion_inverse(quats1) + R1_inv = quaternion_to_rotation_matrix(inv_quats1) + t1_inv = -1 * ein.einsum(R1_inv, trans1, "b i j, b j -> b i") + + # Transform view2's pose to view1's frame + quats = quaternion_multiply(inv_quats1, quats2) + trans = ein.einsum(R1_inv, trans2, "b i j, b j -> b i") + t1_inv + + # Squeeze batch dimension if it was unsqueezed + if squeeze_batch_dim: + quats = quats.squeeze(0) + trans = trans.squeeze(0) + + return quats, trans + + +def convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + ray_directions, depth_along_ray, pose_trans, pose_quats +): + """ + Convert ray directions, depth along ray, pose translation, and + unit quaternions (representing pose rotation) to a pointmap in world frame. 
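+
+    Per pixel this is equivalent to (as implemented below):
+
+        pts_world = R(pose_quats) @ (depth_along_ray * ray_directions) + pose_trans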
+ + Args: + - ray_directions: (HxWx3 or BxHxWx3) torch tensor + - depth_along_ray: (HxWx1 or BxHxWx1) torch tensor + - pose_trans: (3 or Bx3) torch tensor + - pose_quats: (4 or Bx4) torch tensor (unit quaternions and notation is (x, y, z, w)) + + Returns: + - pointmap: (HxWx3 or BxHxWx3) torch tensor + """ + # Add batch dimension if not present + if ray_directions.dim() == 3: + ray_directions = ray_directions.unsqueeze(0) + depth_along_ray = depth_along_ray.unsqueeze(0) + pose_trans = pose_trans.unsqueeze(0) + pose_quats = pose_quats.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + batch_size, height, width, _ = depth_along_ray.shape + device = depth_along_ray.device + + # Normalize the quaternions to ensure they are unit quaternions + pose_quats = pose_quats / torch.norm(pose_quats, dim=-1, keepdim=True) + + # Convert quaternions to rotation matrices (B x 3 x 3) + rot_mat = quaternion_to_rotation_matrix(pose_quats) + + # Get pose matrix (B x 4 x 4) + pose_mat = torch.eye(4, device=device).unsqueeze(0).repeat(batch_size, 1, 1) + pose_mat[:, :3, :3] = rot_mat + pose_mat[:, :3, 3] = pose_trans + + # Compute 3D points in local camera frame + pts3d_local = depth_along_ray * ray_directions + + # Compute 3D points in world frame + pts3d_homo = torch.cat([pts3d_local, torch.ones_like(pts3d_local[..., :1])], dim=-1) + pts3d_world = ein.einsum(pose_mat, pts3d_homo, "b i k, b h w k -> b h w i") + pts3d_world = pts3d_world[..., :3] + + # Remove batch dimension if it was added + if squeeze_batch_dim: + pts3d_world = pts3d_world.squeeze(0) + + return pts3d_world + + +def xy_grid( + W, + H, + device=None, + origin=(0, 0), + unsqueeze=None, + cat_dim=-1, + homogeneous=False, + **arange_kw, +): + """ + Generate a coordinate grid of shape (H,W,2) or (H,W,3) if homogeneous=True. + + Args: + W (int): Width of the grid + H (int): Height of the grid + device (torch.device, optional): Device to place the grid on. If None, uses numpy arrays + origin (tuple, optional): Origin coordinates (x,y) for the grid. Default is (0,0) + unsqueeze (int, optional): Dimension to unsqueeze in the output tensors + cat_dim (int, optional): Dimension to concatenate the x,y coordinates. If None, returns tuple + homogeneous (bool, optional): If True, adds a third dimension of ones to make homogeneous coordinates + **arange_kw: Additional keyword arguments passed to np.arange or torch.arange + + Returns: + numpy.ndarray or torch.Tensor: Coordinate grid where: + - output[j,i,0] = i + origin[0] (x-coordinate) + - output[j,i,1] = j + origin[1] (y-coordinate) + - output[j,i,2] = 1 (if homogeneous=True) + """ + if device is None: + # numpy + arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones + else: + # torch + def arange(*a, **kw): + return torch.arange(*a, device=device, **kw) + + meshgrid, stack = torch.meshgrid, torch.stack + + def ones(*a): + return torch.ones(*a, device=device) + + tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)] + grid = meshgrid(tw, th, indexing="xy") + if homogeneous: + grid = grid + (ones((H, W)),) + if unsqueeze is not None: + grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze)) + if cat_dim is not None: + grid = stack(grid, cat_dim) + + return grid + + +def geotrf(Trf, pts, ncol=None, norm=False): + """ + Apply a geometric transformation to a set of 3-D points. 
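+
+    Example (illustrative sketch with a pure-translation 4x4 transform):
+
+        >>> Trf = torch.eye(4)
+        >>> Trf[:3, 3] = torch.tensor([1.0, 2.0, 3.0])
+        >>> geotrf(Trf, torch.zeros(1, 3)).tolist()
+        [[1.0, 2.0, 3.0]]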
+ + Args: + Trf: 3x3 or 4x4 projection matrix (typically a Homography) or batch of matrices + with shape (B, 3, 3) or (B, 4, 4) + pts: numpy/torch/tuple of coordinates with shape (..., 2) or (..., 3) + ncol: int, number of columns of the result (2 or 3) + norm: float, if not 0, the result is projected on the z=norm plane + (homogeneous normalization) + + Returns: + Array or tensor of projected points with the same type as input and shape (..., ncol) + """ + assert Trf.ndim >= 2 + if isinstance(Trf, np.ndarray): + pts = np.asarray(pts) + elif isinstance(Trf, torch.Tensor): + pts = torch.as_tensor(pts, dtype=Trf.dtype) + + # Adapt shape if necessary + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + # Optimized code + if ( + isinstance(Trf, torch.Tensor) + and isinstance(pts, torch.Tensor) + and Trf.ndim == 3 + and pts.ndim == 4 + ): + d = pts.shape[3] + if Trf.shape[-1] == d: + pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) + elif Trf.shape[-1] == d + 1: + pts = ( + torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + + Trf[:, None, None, :d, d] + ) + else: + raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}") + else: + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], "batch size does not match" + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] # DONT DO /=, it will lead to a bug + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + + return res + + +def inv(mat): + """ + Invert a torch or numpy matrix + """ + if isinstance(mat, torch.Tensor): + return torch.linalg.inv(mat) + if isinstance(mat, np.ndarray): + return np.linalg.inv(mat) + raise ValueError(f"bad matrix type = {type(mat)}") + + +def closed_form_pose_inverse( + pose_matrices, rotation_matrices=None, translation_vectors=None +): + """ + Compute the inverse of each 4x4 (or 3x4) SE3 pose matrices in a batch. + + If `rotation_matrices` and `translation_vectors` are provided, they must correspond to the rotation and translation + components of `pose_matrices`. Otherwise, they will be extracted from `pose_matrices`. + + Args: + pose_matrices: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. + rotation_matrices (optional): Nx3x3 array or tensor of rotation matrices. + translation_vectors (optional): Nx3x1 array or tensor of translation vectors. + + Returns: + Inverted SE3 matrices with the same type and device as input `pose_matrices`. + + Shapes: + pose_matrices: (N, 4, 4) + rotation_matrices: (N, 3, 3) + translation_vectors: (N, 3, 1) + """ + # Check if pose_matrices is a numpy array or a torch tensor + is_numpy = isinstance(pose_matrices, np.ndarray) + + # Validate shapes + if pose_matrices.shape[-2:] != (4, 4) and pose_matrices.shape[-2:] != (3, 4): + raise ValueError( + f"pose_matrices must be of shape (N,4,4), got {pose_matrices.shape}." 
+    )
+
+    # Extract rotation_matrices and translation_vectors if not provided
+    if rotation_matrices is None:
+        rotation_matrices = pose_matrices[:, :3, :3]
+    if translation_vectors is None:
+        translation_vectors = pose_matrices[:, :3, 3:]
+
+    # Compute the inverse of input SE3 matrices
+    if is_numpy:
+        rotation_transposed = np.transpose(rotation_matrices, (0, 2, 1))
+        new_translation = -np.matmul(rotation_transposed, translation_vectors)
+        inverted_matrix = np.tile(np.eye(4), (len(rotation_matrices), 1, 1))
+    else:
+        rotation_transposed = rotation_matrices.transpose(1, 2)
+        new_translation = -torch.bmm(rotation_transposed, translation_vectors)
+        inverted_matrix = torch.eye(4, 4)[None].repeat(len(rotation_matrices), 1, 1)
+        inverted_matrix = inverted_matrix.to(rotation_matrices.dtype).to(
+            rotation_matrices.device
+        )
+    inverted_matrix[:, :3, :3] = rotation_transposed
+    inverted_matrix[:, :3, 3:] = new_translation
+
+    return inverted_matrix
+
+
+def relative_pose_transformation(trans_01, trans_02):
+    r"""
+    Function that computes the relative homogeneous transformation from a
+    reference transformation :math:`T_1^{0} = \begin{bmatrix} R_1 & t_1 \\
+    \mathbf{0} & 1 \end{bmatrix}` to destination :math:`T_2^{0} =
+    \begin{bmatrix} R_2 & t_2 \\ \mathbf{0} & 1 \end{bmatrix}`.
+
+    The relative transformation (mapping frame 2 to frame 1) is computed as follows:
+
+    .. math::
+
+        T_2^{1} = (T_1^{0})^{-1} \cdot T_2^{0}
+
+    Arguments:
+        trans_01 (torch.Tensor): reference transformation tensor of shape
+            :math:`(N, 4, 4)` or :math:`(4, 4)`.
+        trans_02 (torch.Tensor): destination transformation tensor of shape
+            :math:`(N, 4, 4)` or :math:`(4, 4)`.
+
+    Shape:
+        - Output: :math:`(N, 4, 4)` or :math:`(4, 4)`.
+
+    Returns:
+        torch.Tensor: the relative transformation between the transformations.
+
+    Example::
+        >>> trans_01 = torch.eye(4)  # 4x4
+        >>> trans_02 = torch.eye(4)  # 4x4
+        >>> trans_12 = relative_pose_transformation(trans_01, trans_02)  # 4x4
+    """
+    if not torch.is_tensor(trans_01):
+        raise TypeError(
+            "Input trans_01 type is not a torch.Tensor. Got {}".format(type(trans_01))
+        )
+    if not torch.is_tensor(trans_02):
+        raise TypeError(
+            "Input trans_02 type is not a torch.Tensor. Got {}".format(type(trans_02))
+        )
+    if trans_01.dim() not in (2, 3) or trans_01.shape[-2:] != (4, 4):
+        raise ValueError(
+            "Input must be of shape Nx4x4 or 4x4. Got {}".format(trans_01.shape)
+        )
+    if trans_02.dim() not in (2, 3) or trans_02.shape[-2:] != (4, 4):
+        raise ValueError(
+            "Input must be of shape Nx4x4 or 4x4. Got {}".format(trans_02.shape)
+        )
+    if not trans_01.dim() == trans_02.dim():
+        raise ValueError(
+            "Input number of dims must match.
Got {} and {}".format( + trans_01.dim(), trans_02.dim() + ) + ) + + # Convert to Nx4x4 if inputs are 4x4 + squeeze_batch_dim = False + if trans_01.dim() == 2: + trans_01 = trans_01.unsqueeze(0) + trans_02 = trans_02.unsqueeze(0) + squeeze_batch_dim = True + + # Compute inverse of trans_01 using closed form + trans_10 = closed_form_pose_inverse(trans_01) + + # Compose transformations using matrix multiplication + trans_12 = torch.matmul(trans_10, trans_02) + + # Remove batch dimension if it was added + if squeeze_batch_dim: + trans_12 = trans_12.squeeze(0) + + return trans_12 + + +def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_): + """ + Args: + - depthmap (BxHxW array): + - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W] + Returns: + pointmap of absolute coordinates (BxHxWx3 array) + """ + + if len(depth.shape) == 4: + B, H, W, n = depth.shape + else: + B, H, W = depth.shape + n = None + + if len(pseudo_focal.shape) == 3: # [B,H,W] + pseudo_focalx = pseudo_focaly = pseudo_focal + elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W] + pseudo_focalx = pseudo_focal[:, 0] + if pseudo_focal.shape[1] == 2: + pseudo_focaly = pseudo_focal[:, 1] + else: + pseudo_focaly = pseudo_focalx + else: + raise NotImplementedError("Error, unknown input focal shape format.") + + assert pseudo_focalx.shape == depth.shape[:3] + assert pseudo_focaly.shape == depth.shape[:3] + grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None] + + # set principal point + if pp is None: + grid_x = grid_x - (W - 1) / 2 + grid_y = grid_y - (H - 1) / 2 + else: + grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None] + grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None] + + if n is None: + pts3d = torch.empty((B, H, W, 3), device=depth.device) + pts3d[..., 0] = depth * grid_x / pseudo_focalx + pts3d[..., 1] = depth * grid_y / pseudo_focaly + pts3d[..., 2] = depth + else: + pts3d = torch.empty((B, H, W, 3, n), device=depth.device) + pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None] + pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None] + pts3d[..., 2, :] = depth + return pts3d + + +def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: there are no skew terms + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + u, v = np.meshgrid(np.arange(W), np.arange(H)) + z_cam = depthmap + x_cam = (u - cu) * z_cam / fu + y_cam = (v - cv) * z_cam / fv + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + # Mask for valid coordinates + valid_mask = depthmap > 0.0 + + return X_cam, valid_mask + + +def depthmap_to_absolute_camera_coordinates( + depthmap, camera_intrinsics, camera_pose, **kw +): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. 
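+
+    Example (illustrative sketch with unit depth, identity intrinsics and identity pose):
+
+        >>> depth = np.ones((2, 2), dtype=np.float32)
+        >>> pts, valid = depthmap_to_absolute_camera_coordinates(depth, np.eye(3), np.eye(4))
+        >>> pts.shape, bool(valid.all())
+        ((2, 2, 3), True)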
+ """ + X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics) + + X_world = X_cam # default + if camera_pose is not None: + # R_cam2world = np.float32(camera_params["R_cam2world"]) + # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze() + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + # Express in absolute coordinates (invalid depth values) + X_world = ( + np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :] + ) + + return X_world, valid_mask + + +def get_absolute_pointmaps_and_rays_info( + depthmap, camera_intrinsics, camera_pose, **kw +): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), + a mask specifying valid pixels, + ray origins of absolute coordinates (HxWx3 array), + ray directions of absolute coordinates (HxWx3 array), + depth along ray (HxWx1 array), + ray directions of camera/local coordinates (HxWx3 array), + pointmap of camera/local coordinates (HxWx3 array). + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: pinhole & there are no skew terms + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + # Get the rays on the unit plane + u, v = np.meshgrid(np.arange(W), np.arange(H)) + x_cam = (u - cu) / fu + y_cam = (v - cv) / fv + z_cam = np.ones_like(x_cam) + ray_dirs_cam_on_unit_plane = np.stack((x_cam, y_cam, z_cam), axis=-1).astype( + np.float32 + ) + + # Compute the 3d points in the local camera coordinate system + pts_cam = depthmap[..., None] * ray_dirs_cam_on_unit_plane + + # Get the depth along the ray and compute the ray directions on the unit sphere + depth_along_ray = np.linalg.norm(pts_cam, axis=-1, keepdims=True) + ray_directions_cam = ray_dirs_cam_on_unit_plane / np.linalg.norm( + ray_dirs_cam_on_unit_plane, axis=-1, keepdims=True + ) + + # Mask for valid coordinates + valid_mask = depthmap > 0.0 + + # Get the ray origins in absolute coordinates and the ray directions in absolute coordinates + ray_origins_world = np.zeros_like(ray_directions_cam) + ray_directions_world = ray_directions_cam + pts_world = pts_cam + if camera_pose is not None: + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + # Express in absolute coordinates + ray_origins_world = ray_origins_world + t_cam2world[None, None, :] + ray_directions_world = np.einsum( + "ik, vuk -> vui", R_cam2world, ray_directions_cam + ) + pts_world = ray_origins_world + ray_directions_world * depth_along_ray + + return ( + pts_world, + valid_mask, + ray_origins_world, + ray_directions_world, + depth_along_ray, + ray_directions_cam, + pts_cam, + ) + + +def adjust_camera_params_for_rotation(camera_params, original_size, k): + """ + Adjust camera parameters for rotation. + + Args: + camera_params: Camera parameters [fx, fy, cx, cy, ...] 
+ original_size: Original image size as (width, height) + k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise) + + Returns: + Adjusted camera parameters + """ + fx, fy, cx, cy = camera_params[:4] + width, height = original_size + + if k % 4 == 1: # 90 degrees counter-clockwise + new_fx, new_fy = fy, fx + new_cx, new_cy = height - cy, cx + elif k % 4 == 2: # 180 degrees + new_fx, new_fy = fx, fy + new_cx, new_cy = width - cx, height - cy + elif k % 4 == 3: # 90 degrees clockwise (270 counter-clockwise) + new_fx, new_fy = fy, fx + new_cx, new_cy = cy, width - cx + else: # No rotation + return camera_params + + adjusted_params = [new_fx, new_fy, new_cx, new_cy] + if len(camera_params) > 4: + adjusted_params.extend(camera_params[4:]) + + return adjusted_params + + +def adjust_pose_for_rotation(pose, k): + """ + Adjust camera pose for rotation. + + Args: + pose: 4x4 camera pose matrix (camera-to-world, OpenCV convention - X right, Y down, Z forward) + k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise) + + Returns: + Adjusted 4x4 camera pose matrix + """ + # Create rotation matrices for different rotations + if k % 4 == 1: # 90 degrees counter-clockwise + rot_transform = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]]) + elif k % 4 == 2: # 180 degrees + rot_transform = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]]) + elif k % 4 == 3: # 90 degrees clockwise (270 counter-clockwise) + rot_transform = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) + else: # No rotation + return pose + + # Apply the transformation to the pose + adjusted_pose = pose + adjusted_pose[:3, :3] = adjusted_pose[:3, :3] @ rot_transform.T + + return adjusted_pose + + +def crop_to_aspect_ratio(image, depth, camera_params, target_ratio=1.5): + """ + Crop image and depth to the largest possible target aspect ratio while + keeping the left side if aspect ratio is wider and the bottom of image if the aspect ratio is taller. + + Args: + image: PIL image + depth: Depth map as numpy array + camera_params: Camera parameters [fx, fy, cx, cy, ...] + target_ratio: Target width/height ratio + + Returns: + Cropped image, cropped depth, adjusted camera parameters + """ + width, height = image.size + fx, fy, cx, cy = camera_params[:4] + current_ratio = width / height + + if abs(current_ratio - target_ratio) < 1e-6: + # Already at target ratio + return image, depth, camera_params + + if current_ratio > target_ratio: + # Image is wider than target ratio, crop width + new_width = int(height * target_ratio) + left = 0 + right = new_width + + # Crop image + cropped_image = image.crop((left, 0, right, height)) + + # Crop depth + if len(depth.shape) == 3: + cropped_depth = depth[:, left:right, :] + else: + cropped_depth = depth[:, left:right] + + # Adjust camera parameters + new_cx = cx - left + adjusted_params = [fx, fy, new_cx, cy] + list(camera_params[4:]) + + else: + # Image is taller than target ratio, crop height + new_height = int(width / target_ratio) + top = max(0, height - new_height) + bottom = height + + # Crop image + cropped_image = image.crop((0, top, width, bottom)) + + # Crop depth + if len(depth.shape) == 3: + cropped_depth = depth[top:bottom, :, :] + else: + cropped_depth = depth[top:bottom, :] + + # Adjust camera parameters + new_cy = cy - top + adjusted_params = [fx, fy, cx, new_cy] + list(camera_params[4:]) + + return cropped_image, cropped_depth, adjusted_params + + +def colmap_to_opencv_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. 
+ Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] -= 0.5 + K[1, 2] -= 0.5 + + return K + + +def opencv_to_colmap_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + + return K + + +def normalize_depth_using_non_zero_pixels(depth, return_norm_factor=False): + """ + Normalize the depth by the average depth of non-zero depth pixels. + + Args: + depth (torch.Tensor): Depth tensor of size [B, H, W, 1]. + Returns: + normalized_depth (torch.Tensor): Normalized depth tensor. + norm_factor (torch.Tensor): Norm factor tensor of size B. + """ + assert depth.ndim == 4 and depth.shape[3] == 1 + # Calculate the sum and count of non-zero depth pixels for each batch + valid_depth_mask = depth > 0 + valid_sum = torch.sum(depth * valid_depth_mask, dim=(1, 2, 3)) + valid_count = torch.sum(valid_depth_mask, dim=(1, 2, 3)) + + # Calculate the norm factor + norm_factor = valid_sum / (valid_count + 1e-8) + while norm_factor.ndim < depth.ndim: + norm_factor.unsqueeze_(-1) + + # Normalize the depth by the norm factor + norm_factor = norm_factor.clip(min=1e-8) + normalized_depth = depth / norm_factor + + # Create the output tuple + output = ( + (normalized_depth, norm_factor.squeeze(-1).squeeze(-1).squeeze(-1)) + if return_norm_factor + else normalized_depth + ) + + return output + + +def normalize_pose_translations(pose_translations, return_norm_factor=False): + """ + Normalize the pose translations by the average norm of the non-zero pose translations. + + Args: + pose_translations (torch.Tensor): Pose translations tensor of size [B, V, 3]. B is the batch size, V is the number of views. + Returns: + normalized_pose_translations (torch.Tensor): Normalized pose translations tensor of size [B, V, 3]. + norm_factor (torch.Tensor): Norm factor tensor of size B. + """ + assert pose_translations.ndim == 3 and pose_translations.shape[2] == 3 + # Compute distance of all pose translations to origin + pose_translations_dis = pose_translations.norm(dim=-1) # [B, V] + non_zero_pose_translations_dis = pose_translations_dis > 0 # [B, V] + + # Calculate the average norm of the translations across all views (considering only views with non-zero translations) + sum_of_all_views_pose_translations = pose_translations_dis.sum(dim=1) # [B] + count_of_all_views_with_non_zero_pose_translations = ( + non_zero_pose_translations_dis.sum(dim=1) + ) # [B] + norm_factor = sum_of_all_views_pose_translations / ( + count_of_all_views_with_non_zero_pose_translations + 1e-8 + ) # [B] + + # Normalize the pose translations by the norm factor + norm_factor = norm_factor.clip(min=1e-8) + normalized_pose_translations = pose_translations / norm_factor.unsqueeze( + -1 + ).unsqueeze(-1) + + # Create the output tuple + output = ( + (normalized_pose_translations, norm_factor) + if return_norm_factor + else normalized_pose_translations + ) + + return output + + +def normalize_multiple_pointclouds( + pts_list, valid_masks=None, norm_mode="avg_dis", ret_factor=False +): + """ + Normalize multiple point clouds using a joint normalization strategy. 
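+
+    With the default "avg_dis" mode, a single scale factor per batch element is
+    computed as the mean distance-to-origin over all valid points pooled from every
+    input cloud, and each cloud is divided by that shared factor, so the relative
+    scale between the clouds is preserved. The "log1p" and "warp-log1p" variants
+    average log(1 + distance) instead (the latter additionally warps the input
+    points by the log factor).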
+ + Args: + pts_list: List of point clouds, each with shape (..., H, W, 3) or (B, H, W, 3) + valid_masks: Optional list of masks indicating valid points in each point cloud + norm_mode: String in format "{norm}_{dis}" where: + - norm: Normalization strategy (currently only "avg" is supported) + - dis: Distance transformation ("dis" for raw distance, "log1p" for log(1+distance), + "warp-log1p" to warp points using log distance) + ret_factor: If True, return the normalization factor as the last element in the result list + + Returns: + List of normalized point clouds with the same shapes as inputs. + If ret_factor is True, the last element is the normalization factor. + """ + assert all(pts.ndim >= 3 and pts.shape[-1] == 3 for pts in pts_list) + if valid_masks is not None: + assert len(pts_list) == len(valid_masks) + + norm_mode, dis_mode = norm_mode.split("_") + + # Gather all points together (joint normalization) + nan_pts_list = [ + invalid_to_zeros(pts, valid_masks[i], ndim=3) + if valid_masks + else invalid_to_zeros(pts, None, ndim=3) + for i, pts in enumerate(pts_list) + ] + all_pts = torch.cat([nan_pts for nan_pts, _ in nan_pts_list], dim=1) + nnz_list = [nnz for _, nnz in nan_pts_list] + + # Compute distance to origin + all_dis = all_pts.norm(dim=-1) + if dis_mode == "dis": + pass # do nothing + elif dis_mode == "log1p": + all_dis = torch.log1p(all_dis) + elif dis_mode == "warp-log1p": + # Warp input points before normalizing them + log_dis = torch.log1p(all_dis) + warp_factor = log_dis / all_dis.clip(min=1e-8) + for i, pts in enumerate(pts_list): + H, W = pts.shape[1:-1] + pts_list[i] = pts * warp_factor[:, i * (H * W) : (i + 1) * (H * W)].view( + -1, H, W, 1 + ) + all_dis = log_dis + else: + raise ValueError(f"bad {dis_mode=}") + + # Compute normalization factor + norm_factor = all_dis.sum(dim=1) / (sum(nnz_list) + 1e-8) + norm_factor = norm_factor.clip(min=1e-8) + while norm_factor.ndim < pts_list[0].ndim: + norm_factor.unsqueeze_(-1) + + # Normalize points + res = [pts / norm_factor for pts in pts_list] + if ret_factor: + res.append(norm_factor) + + return res + + +def apply_log_to_norm(input_data): + """ + Normalize the input data and apply a logarithmic transformation based on the normalization factor. + + Args: + input_data (torch.Tensor): The input tensor to be normalized and transformed. + + Returns: + torch.Tensor: The transformed tensor after normalization and logarithmic scaling. + """ + org_d = input_data.norm(dim=-1, keepdim=True) + input_data = input_data / org_d.clip(min=1e-8) + input_data = input_data * torch.log1p(org_d) + return input_data + + +def angle_diff_vec3(v1, v2, eps=1e-12): + """ + Compute angle difference between 3D vectors. + + Args: + v1: torch.Tensor of shape (..., 3) + v2: torch.Tensor of shape (..., 3) + eps: Small epsilon value for numerical stability + + Returns: + torch.Tensor: Angle differences in radians + """ + cross_norm = torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps + dot_prod = (v1 * v2).sum(dim=-1) + return torch.atan2(cross_norm, dot_prod) + + +def angle_diff_vec3_numpy(v1: np.ndarray, v2: np.ndarray, eps: float = 1e-12): + """ + Compute angle difference between 3D vectors using NumPy. + + Args: + v1 (np.ndarray): First vector of shape (..., 3) + v2 (np.ndarray): Second vector of shape (..., 3) + eps (float, optional): Small epsilon value for numerical stability. Defaults to 1e-12. 
+ + Returns: + np.ndarray: Angle differences in radians + """ + return np.arctan2( + np.linalg.norm(np.cross(v1, v2, axis=-1), axis=-1) + eps, (v1 * v2).sum(axis=-1) + ) + + +@no_warnings(category=RuntimeWarning) +def points_to_normals( + point: np.ndarray, mask: np.ndarray = None, edge_threshold: float = None +) -> np.ndarray: + """ + Calculate normal map from point map. Value range is [-1, 1]. + + Args: + point (np.ndarray): shape (height, width, 3), point map + mask (optional, np.ndarray): shape (height, width), dtype=bool. Mask of valid depth pixels. Defaults to None. + edge_threshold (optional, float): threshold for the angle (in degrees) between the normal and the view direction. Defaults to None. + + Returns: + normal (np.ndarray): shape (height, width, 3), normal map. + """ + height, width = point.shape[-3:-1] + has_mask = mask is not None + + if mask is None: + mask = np.ones_like(point[..., 0], dtype=bool) + mask_pad = np.zeros((height + 2, width + 2), dtype=bool) + mask_pad[1:-1, 1:-1] = mask + mask = mask_pad + + pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype) + pts[1:-1, 1:-1, :] = point + up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :] + left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :] + down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :] + right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :] + normal = np.stack( + [ + np.cross(up, left, axis=-1), + np.cross(left, down, axis=-1), + np.cross(down, right, axis=-1), + np.cross(right, up, axis=-1), + ] + ) + normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12) + + valid = ( + np.stack( + [ + mask[:-2, 1:-1] & mask[1:-1, :-2], + mask[1:-1, :-2] & mask[2:, 1:-1], + mask[2:, 1:-1] & mask[1:-1, 2:], + mask[1:-1, 2:] & mask[:-2, 1:-1], + ] + ) + & mask[None, 1:-1, 1:-1] + ) + if edge_threshold is not None: + view_angle = angle_diff_vec3_numpy(pts[None, 1:-1, 1:-1, :], normal) + view_angle = np.minimum(view_angle, np.pi - view_angle) + valid = valid & (view_angle < np.deg2rad(edge_threshold)) + + normal = (normal * valid[..., None]).sum(axis=0) + normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12) + + if has_mask: + normal_mask = valid.any(axis=0) + normal = np.where(normal_mask[..., None], normal, 0) + return normal, normal_mask + else: + return normal + + +def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1): + """ + Create a sliding window view of the input array along a specified axis. + + This function creates a memory-efficient view of the input array with sliding windows + of the specified size and stride. The window dimension is appended to the end of the + output array's shape. This is useful for operations like convolution, pooling, or + any analysis that requires examining local neighborhoods in the data. + + Args: + x (np.ndarray): Input array with shape (..., axis_size, ...) + window_size (int): Size of the sliding window + stride (int): Stride of the sliding window (step size between consecutive windows) + axis (int, optional): Axis to perform sliding window over. 
Defaults to -1 (last axis) + + Returns: + np.ndarray: View of the input array with shape (..., n_windows, ..., window_size), + where n_windows = (axis_size - window_size + 1) // stride + + Raises: + AssertionError: If window_size is larger than the size of the specified axis + + Example: + >>> x = np.array([1, 2, 3, 4, 5, 6]) + >>> sliding_window_1d(x, window_size=3, stride=2) + array([[1, 2, 3], + [3, 4, 5]]) + """ + assert x.shape[axis] >= window_size, ( + f"kernel_size ({window_size}) is larger than axis_size ({x.shape[axis]})" + ) + axis = axis % x.ndim + shape = ( + *x.shape[:axis], + (x.shape[axis] - window_size + 1) // stride, + *x.shape[axis + 1 :], + window_size, + ) + strides = ( + *x.strides[:axis], + stride * x.strides[axis], + *x.strides[axis + 1 :], + x.strides[axis], + ) + x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides) + return x_sliding + + +def sliding_window_nd( + x: np.ndarray, + window_size: Tuple[int, ...], + stride: Tuple[int, ...], + axis: Tuple[int, ...], +) -> np.ndarray: + """ + Create sliding windows along multiple dimensions of the input array. + + This function applies sliding_window_1d sequentially along multiple axes to create + N-dimensional sliding windows. This is useful for operations that need to examine + local neighborhoods in multiple dimensions simultaneously. + + Args: + x (np.ndarray): Input array + window_size (Tuple[int, ...]): Size of the sliding window for each axis + stride (Tuple[int, ...]): Stride of the sliding window for each axis + axis (Tuple[int, ...]): Axes to perform sliding window over + + Returns: + np.ndarray: Array with sliding windows along the specified dimensions. + The window dimensions are appended to the end of the shape. + + Note: + The length of window_size, stride, and axis tuples must be equal. + + Example: + >>> x = np.random.rand(10, 10) + >>> windows = sliding_window_nd(x, window_size=(3, 3), stride=(2, 2), axis=(-2, -1)) + >>> # Creates 3x3 sliding windows with stride 2 in both dimensions + """ + axis = [axis[i] % x.ndim for i in range(len(axis))] + for i in range(len(axis)): + x = sliding_window_1d(x, window_size[i], stride[i], axis[i]) + return x + + +def sliding_window_2d( + x: np.ndarray, + window_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + axis: Tuple[int, int] = (-2, -1), +) -> np.ndarray: + """ + Create 2D sliding windows over the input array. + + Convenience function for creating 2D sliding windows, commonly used for image + processing operations like convolution, pooling, or patch extraction. + + Args: + x (np.ndarray): Input array + window_size (Union[int, Tuple[int, int]]): Size of the 2D sliding window. + If int, same size is used for both dimensions. + stride (Union[int, Tuple[int, int]]): Stride of the 2D sliding window. + If int, same stride is used for both dimensions. + axis (Tuple[int, int], optional): Two axes to perform sliding window over. + Defaults to (-2, -1) (last two dimensions). + + Returns: + np.ndarray: Array with 2D sliding windows. The window dimensions (height, width) + are appended to the end of the shape. 
+ + Example: + >>> image = np.random.rand(100, 100) + >>> patches = sliding_window_2d(image, window_size=8, stride=4) + >>> # Creates 8x8 patches with stride 4 from the image + """ + if isinstance(window_size, int): + window_size = (window_size, window_size) + if isinstance(stride, int): + stride = (stride, stride) + return sliding_window_nd(x, window_size, stride, axis) + + +def max_pool_1d( + x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1 +): + """ + Perform 1D max pooling on the input array. + + Max pooling reduces the dimensionality of the input by taking the maximum value + within each sliding window. This is commonly used in neural networks and signal + processing for downsampling and feature extraction. + + Args: + x (np.ndarray): Input array + kernel_size (int): Size of the pooling kernel + stride (int): Stride of the pooling operation + padding (int, optional): Amount of padding to add on both sides. Defaults to 0. + axis (int, optional): Axis to perform max pooling over. Defaults to -1. + + Returns: + np.ndarray: Max pooled array with reduced size along the specified axis + + Note: + - For floating point arrays, padding is done with np.nan values + - For integer arrays, padding is done with the minimum value of the dtype + - np.nanmax is used to handle NaN values in the computation + + Example: + >>> x = np.array([1, 3, 2, 4, 5, 1, 2]) + >>> max_pool_1d(x, kernel_size=3, stride=2) + array([3, 5, 2]) + """ + axis = axis % x.ndim + if padding > 0: + fill_value = np.nan if x.dtype.kind == "f" else np.iinfo(x.dtype).min + padding_arr = np.full( + (*x.shape[:axis], padding, *x.shape[axis + 1 :]), + fill_value=fill_value, + dtype=x.dtype, + ) + x = np.concatenate([padding_arr, x, padding_arr], axis=axis) + a_sliding = sliding_window_1d(x, kernel_size, stride, axis) + max_pool = np.nanmax(a_sliding, axis=-1) + return max_pool + + +def max_pool_nd( + x: np.ndarray, + kernel_size: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Tuple[int, ...], + axis: Tuple[int, ...], +) -> np.ndarray: + """ + Perform N-dimensional max pooling on the input array. + + This function applies max_pool_1d sequentially along multiple axes to perform + multi-dimensional max pooling. This is useful for downsampling multi-dimensional + data while preserving the most important features. + + Args: + x (np.ndarray): Input array + kernel_size (Tuple[int, ...]): Size of the pooling kernel for each axis + stride (Tuple[int, ...]): Stride of the pooling operation for each axis + padding (Tuple[int, ...]): Amount of padding for each axis + axis (Tuple[int, ...]): Axes to perform max pooling over + + Returns: + np.ndarray: Max pooled array with reduced size along the specified axes + + Note: + The length of kernel_size, stride, padding, and axis tuples must be equal. + Max pooling is applied sequentially along each axis in the order specified. + + Example: + >>> x = np.random.rand(10, 10, 10) + >>> pooled = max_pool_nd(x, kernel_size=(2, 2, 2), stride=(2, 2, 2), + ... padding=(0, 0, 0), axis=(-3, -2, -1)) + >>> # Reduces each dimension by half with 2x2x2 max pooling + """ + for i in range(len(axis)): + x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i]) + return x + + +def max_pool_2d( + x: np.ndarray, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + padding: Union[int, Tuple[int, int]], + axis: Tuple[int, int] = (-2, -1), +): + """ + Perform 2D max pooling on the input array. 
+ + Convenience function for 2D max pooling, commonly used in computer vision + and image processing for downsampling images while preserving important features. + + Args: + x (np.ndarray): Input array + kernel_size (Union[int, Tuple[int, int]]): Size of the 2D pooling kernel. + If int, same size is used for both dimensions. + stride (Union[int, Tuple[int, int]]): Stride of the 2D pooling operation. + If int, same stride is used for both dimensions. + padding (Union[int, Tuple[int, int]]): Amount of padding for both dimensions. + If int, same padding is used for both dimensions. + axis (Tuple[int, int], optional): Two axes to perform max pooling over. + Defaults to (-2, -1) (last two dimensions). + + Returns: + np.ndarray: 2D max pooled array with reduced size along the specified axes + + Example: + >>> image = np.random.rand(64, 64) + >>> pooled = max_pool_2d(image, kernel_size=2, stride=2, padding=0) + >>> # Reduces image size from 64x64 to 32x32 with 2x2 max pooling + """ + if isinstance(kernel_size, Number): + kernel_size = (kernel_size, kernel_size) + if isinstance(stride, Number): + stride = (stride, stride) + if isinstance(padding, Number): + padding = (padding, padding) + axis = tuple(axis) + return max_pool_nd(x, kernel_size, stride, padding, axis) + + +@no_warnings(category=RuntimeWarning) +def depth_edge( + depth: np.ndarray, + atol: float = None, + rtol: float = None, + kernel_size: int = 3, + mask: np.ndarray = None, +) -> np.ndarray: + """ + Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth. + + Args: + depth (np.ndarray): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool + """ + if mask is None: + diff = max_pool_2d( + depth, kernel_size, stride=1, padding=kernel_size // 2 + ) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + else: + diff = max_pool_2d( + np.where(mask, depth, -np.inf), + kernel_size, + stride=1, + padding=kernel_size // 2, + ) + max_pool_2d( + np.where(mask, -depth, -np.inf), + kernel_size, + stride=1, + padding=kernel_size // 2, + ) + + edge = np.zeros_like(depth, dtype=bool) + if atol is not None: + edge |= diff > atol + + if rtol is not None: + edge |= diff / depth > rtol + return edge + + +def depth_aliasing( + depth: np.ndarray, + atol: float = None, + rtol: float = None, + kernel_size: int = 3, + mask: np.ndarray = None, +) -> np.ndarray: + """ + Compute the map that indicates the aliasing of x depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors. 
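+
+    Concretely, for each pixel the smaller of (local max - depth) and (depth - local min)
+    over a kernel_size x kernel_size window is compared against the tolerances; the pixel
+    is flagged when this gap exceeds atol or rtol * depth.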
+ Args: + depth (np.ndarray): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool + """ + if mask is None: + diff_max = ( + max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth + ) + diff_min = ( + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth + ) + else: + diff_max = ( + max_pool_2d( + np.where(mask, depth, -np.inf), + kernel_size, + stride=1, + padding=kernel_size // 2, + ) + - depth + ) + diff_min = ( + max_pool_2d( + np.where(mask, -depth, -np.inf), + kernel_size, + stride=1, + padding=kernel_size // 2, + ) + + depth + ) + diff = np.minimum(diff_max, diff_min) + + edge = np.zeros_like(depth, dtype=bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= diff / depth > rtol + return edge + + +@no_warnings(category=RuntimeWarning) +def normals_edge( + normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None +) -> np.ndarray: + """ + Compute the edge mask from normal map. + + Args: + normal (np.ndarray): shape (..., height, width, 3), normal map + tol (float): tolerance in degrees + + Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool + """ + assert normals.ndim >= 3 and normals.shape[-1] == 3, ( + "normal should be of shape (..., height, width, 3)" + ) + normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12) + + padding = kernel_size // 2 + normals_window = sliding_window_2d( + np.pad( + normals, + ( + *([(0, 0)] * (normals.ndim - 3)), + (padding, padding), + (padding, padding), + (0, 0), + ), + mode="edge", + ), + window_size=kernel_size, + stride=1, + axis=(-3, -2), + ) + if mask is None: + angle_diff = np.arccos( + (normals[..., None, None] * normals_window).sum(axis=-3) + ).max(axis=(-2, -1)) + else: + mask_window = sliding_window_2d( + np.pad( + mask, + (*([(0, 0)] * (mask.ndim - 3)), (padding, padding), (padding, padding)), + mode="edge", + ), + window_size=kernel_size, + stride=1, + axis=(-3, -2), + ) + angle_diff = np.where( + mask_window, + np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)), + 0, + ).max(axis=(-2, -1)) + + angle_diff = max_pool_2d( + angle_diff, kernel_size, stride=1, padding=kernel_size // 2 + ) + edge = angle_diff > np.deg2rad(tol) + return edge diff --git a/mapanything/utils/hf_utils/__init__.py b/mapanything/utils/hf_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/utils/hf_utils/__pycache__/__init__.cpython-312.pyc b/mapanything/utils/hf_utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e839fbfdc2137fc647b044105a4e56a4965cec20 Binary files /dev/null and b/mapanything/utils/hf_utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/mapanything/utils/hf_utils/__pycache__/css_and_html.cpython-312.pyc b/mapanything/utils/hf_utils/__pycache__/css_and_html.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74dee2b9b9ccad0f64d79dd3fed215203f2efd90 Binary files /dev/null and b/mapanything/utils/hf_utils/__pycache__/css_and_html.cpython-312.pyc differ diff --git a/mapanything/utils/hf_utils/__pycache__/hf_helpers.cpython-312.pyc b/mapanything/utils/hf_utils/__pycache__/hf_helpers.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..015b2d2f6c20f851c92b868b5f399bfdfd13d7ee Binary files /dev/null and b/mapanything/utils/hf_utils/__pycache__/hf_helpers.cpython-312.pyc differ diff --git a/mapanything/utils/hf_utils/__pycache__/viz.cpython-312.pyc b/mapanything/utils/hf_utils/__pycache__/viz.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4475b84add90b82d5fd678fef9d968906d2f4caf Binary files /dev/null and b/mapanything/utils/hf_utils/__pycache__/viz.cpython-312.pyc differ diff --git a/mapanything/utils/hf_utils/css_and_html.py b/mapanything/utils/hf_utils/css_and_html.py new file mode 100644 index 0000000000000000000000000000000000000000..a303fb1a6061800dce0677bddc2eca4e8516dd35 --- /dev/null +++ b/mapanything/utils/hf_utils/css_and_html.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +CSS and HTML content for the MapAnything Gradio application. +This module contains all the CSS styles and HTML content blocks +used in the Gradio interface. +""" + +# CSS Styles for the Gradio interface +GRADIO_CSS = """ +.custom-log * { + font-style: italic; + font-size: 22px !important; + background-image: linear-gradient(120deg, #ffb366 0%, #ffa366 60%, #ff9966 100%); + -webkit-background-clip: text; + background-clip: text; + font-weight: bold !important; + color: transparent !important; + text-align: center !important; +} + +.example-log * { + font-style: italic; + font-size: 16px !important; + background-image: linear-gradient(120deg, #ffb366 0%, #ffa366 60%, #ff9966 100%); + -webkit-background-clip: text; + background-clip: text; + color: transparent !important; +} + +#my_radio .wrap { + display: flex; + flex-wrap: nowrap; + justify-content: center; + align-items: center; +} + +#my_radio .wrap label { + display: flex; + width: 50%; + justify-content: center; + align-items: center; + margin: 0; + padding: 10px 0; + box-sizing: border-box; +} + +/* Align navigation buttons with dropdown bottom */ +.navigation-row { + display: flex !important; + align-items: flex-end !important; + gap: 8px !important; +} + +.navigation-row > div:nth-child(1), +.navigation-row > div:nth-child(3) { + align-self: flex-end !important; +} + +.navigation-row > div:nth-child(2) { + flex: 1 !important; +} + +/* Make thumbnails clickable with pointer cursor */ +.clickable-thumbnail img { + cursor: pointer !important; +} + +.clickable-thumbnail:hover img { + cursor: pointer !important; + opacity: 0.8; + transition: opacity 0.3s ease; +} + +/* Make thumbnail containers narrower horizontally */ +.clickable-thumbnail { + padding: 5px 2px !important; + margin: 0 2px !important; +} + +.clickable-thumbnail .image-container { + margin: 0 !important; + padding: 0 !important; +} + +.scene-info { + text-align: center !important; + padding: 5px 2px !important; + margin: 0 !important; +} +""" + + +def get_header_html(logo_base64=None): + """ + Generate the main header HTML with logo and title. + + Args: + logo_base64 (str, optional): Base64 encoded logo image + + Returns: + str: HTML string for the header + """ + logo_style = "display: none;" if not logo_base64 else "" + logo_src = logo_base64 or "" + + return f""" +
+    <div style="text-align: center;">
+        <!-- Minimal reconstruction: the original header markup, inline styles and
+             link URLs are unavailable; only the recoverable text is kept. -->
+        <img src="{logo_src}" alt="WAI Logo" style="{logo_style}">
+        <h1>MapAnything: Metric 3D Scene Reconstruction</h1>
+        <p>
+            🌟 GitHub Repository |
+            🚀 Project Page
+        </p>
+    </div>
+ """ + + +def get_description_html(): + """ + Generate the main description and getting started HTML. + + Returns: + str: HTML string for the description + """ + return """ +
+    <div>
+        <!-- Minimal reconstruction: the original markup, classes and inline styles of
+             this block are unavailable; only the recoverable text is kept. -->
+        <p>
+            Upload a video or a set of images to create a 3D reconstruction of a scene or
+            object. MapAnything takes these images and generates 3D point clouds directly
+            from multi-view images.
+        </p>
+        <p>
+            This demo demonstrates the use of image inputs only. However, MapAnything is
+            extremely flexible and supports any combination of inputs (images, calibration,
+            poses & depth). For trying out memory-efficient inference or additional inputs
+            like cameras & depth, please check out the code in our GitHub repo.
+        </p>
+        <h3>Getting Started:</h3>
+        <ol>
+            <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images"
+                buttons on the left to provide your input. Videos will be automatically split
+                into individual frames (one frame per second).</li>
+            <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on
+                the left.</li>
+            <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D
+                reconstruction process.</li>
+            <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on
+                the right. You can rotate, pan, and zoom to explore the model, and download the
+                GLB file. Note that the visualization of 3D points may be slow for a large
+                number of input images.</li>
+            <li><strong>Adjust Reconstruction & Visualization (Optional):</strong> You can
+                fine-tune the visualization using the options below the viewer (click to expand):
+                <ul>
+                    <li><strong>Show Camera:</strong> Toggle the display of estimated camera positions.</li>
+                    <li><strong>Show Mesh:</strong> Use meshes for the prediction visualization.</li>
+                    <li><strong>Show Points from Frame:</strong> Select specific frames to display in the viewer.</li>
+                    <li><strong>Filter Black Background:</strong> Remove black background pixels.</li>
+                    <li><strong>Filter White Background:</strong> Remove white background pixels.</li>
+                </ul>
+            </li>
+        </ol>
+        <p>
+            Please note: the inference time depends on the number of input images (e.g., less
+            than 1 second for up to 50 views). However, downloading model weights and
+            visualizing 3D points may take tens of seconds. Please be patient or, for faster
+            visualization, use a local machine to run our demo from our GitHub repository.
+        </p>
+    </div>
+    """
+
+
+def get_acknowledgements_html():
+    """
+    Generate the acknowledgements section HTML.
+
+    Returns:
+        str: HTML string for the acknowledgements
+    """
+    return """
+    <div>
+        <h3>Acknowledgements</h3>
+        <p>This site builds upon code from:</p>
+        <!-- Placeholder: the original list of acknowledged projects (names and links)
+             is unavailable and is intentionally not guessed here. -->
+        <p>
+            We extend our gratitude to these projects for their valuable contributions to
+            the research community.
+        </p>
+    </div>
+ """ + + +def get_gradio_theme(): + """ + Get the configured Gradio theme. + + Returns: + gr.themes.Base: Configured Gradio theme + """ + import gradio as gr + + return gr.themes.Base( + primary_hue=gr.themes.Color( + c100="#ffedd5", + c200="#ffddb3", + c300="rgba(242.78125, 182.89427563548466, 120.32579495614034, 1)", + c400="#fb923c", + c50="#fff7ed", + c500="#f97316", + c600="#ea580c", + c700="#c2410c", + c800="#9a3412", + c900="#7c2d12", + c950="#6c2e12", + ), + secondary_hue="amber", + ) + + +# Measure tab instructions HTML +MEASURE_INSTRUCTIONS_HTML = """ +### Click on the image to measure the distance between two points. +""" diff --git a/mapanything/utils/hf_utils/hf_helpers.py b/mapanything/utils/hf_utils/hf_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..42932fc2b637ecea3e926f93c32b4278e5d04a31 --- /dev/null +++ b/mapanything/utils/hf_utils/hf_helpers.py @@ -0,0 +1,256 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Helper functions for HuggingFace integration and model initialization. +""" + +import json +import os + + +def load_hf_token(): + """Load HuggingFace access token from local file""" + # Also try environment variable + # see https://huggingface.co/docs/hub/spaces-overview#managing-secrets on options + token = ( + os.getenv("HF_TOKEN") + or os.getenv("HUGGING_FACE_HUB_TOKEN") + or os.getenv("HUGGING_FACE_MODEL_TOKEN") + ) + if token: + print("Loaded HuggingFace token from environment variable") + return token + + print( + "Warning: No HuggingFace token found. Model loading may fail for private repositories." + ) + return None + + +def init_hydra_config(config_path, overrides=None): + """Initialize Hydra config""" + import hydra + + config_dir = os.path.dirname(config_path) + config_name = os.path.basename(config_path).split(".")[0] + relative_path = os.path.relpath(config_dir, os.path.dirname(__file__)) + hydra.core.global_hydra.GlobalHydra.instance().clear() + hydra.initialize(version_base=None, config_path=relative_path) + if overrides is not None: + cfg = hydra.compose(config_name=config_name, overrides=overrides) + else: + cfg = hydra.compose(config_name=config_name) + return cfg + + +def initialize_mapanything_model(high_level_config, device): + """ + Initialize MapAnything model with three-tier fallback approach: + 1. Try HuggingFace from_pretrained() + 2. Download HF config + use local model factory + load HF weights + 3. 
Pure local configuration fallback + + Args: + high_level_config (dict): Configuration dictionary containing model settings + device (torch.device): Device to load the model on + + Returns: + torch.nn.Module: Initialized MapAnything model + """ + import torch + from huggingface_hub import hf_hub_download + + from mapanything.models import init_model, MapAnything + + print("Initializing MapAnything model...") + + # Initialize Hydra config and create model from configuration + cfg = init_hydra_config( + high_level_config["path"], overrides=high_level_config["config_overrides"] + ) + + # Try using from_pretrained first + try: + print("Loading MapAnything model from_pretrained...") + model = MapAnything.from_pretrained(high_level_config["hf_model_name"]).to( + device + ) + print("Loading MapAnything model from_pretrained succeeded...") + return model + except Exception as e: + print(f"from_pretrained failed: {e}") + print("Falling back to local configuration approach using hf_hub_download...") + + # Create model from local configuration instead of using from_pretrained + # Try to download and use the config from HuggingFace Hub + try: + print("Downloading model configuration from HuggingFace Hub...") + config_path = hf_hub_download( + repo_id=high_level_config["hf_model_name"], + filename=high_level_config["config_name"], + token=load_hf_token(), + ) + + # Load the config from the downloaded file + with open(config_path, "r") as f: + downloaded_config = json.load(f) + + print("Using downloaded configuration for model initialization") + model = init_model( + model_str=downloaded_config.get( + "model_str", high_level_config["model_str"] + ), + model_config=downloaded_config.get( + "model_config", cfg.model.model_config + ), + torch_hub_force_reload=high_level_config.get( + "torch_hub_force_reload", False + ), + ) + except Exception as config_e: + print(f"Failed to download/use HuggingFace config: {config_e}") + print("Falling back to local configuration...") + # Fall back to local configuration as before + model = init_model( + model_str=cfg.model.model_str, + model_config=cfg.model.model_config, + torch_hub_force_reload=high_level_config.get( + "torch_hub_force_reload", False + ), + ) + + # Load the pretrained weights from HuggingFace Hub + try: + # First, let's see what files are available in the repository + try: + checkpoint_filename = high_level_config["checkpoint_name"] + # Download the model weights + checkpoint_path = hf_hub_download( + repo_id=high_level_config["hf_model_name"], + filename=checkpoint_filename, + token=load_hf_token(), + ) + + # Load the weights + print("start loading checkpoint") + if checkpoint_filename.endswith(".safetensors"): + from safetensors.torch import load_file + + checkpoint = load_file(checkpoint_path) + else: + checkpoint = torch.load( + checkpoint_path, map_location="cpu", weights_only=False + ) + + print("start loading state_dict") + if "model" in checkpoint: + model.load_state_dict(checkpoint["model"], strict=False) + elif "state_dict" in checkpoint: + model.load_state_dict(checkpoint["state_dict"], strict=False) + else: + model.load_state_dict(checkpoint, strict=False) + + print( + f"Successfully loaded pretrained weights from HuggingFace Hub ({checkpoint_filename})" + ) + + except Exception as inner_e: + print(f"Error listing repository files or loading weights: {inner_e}") + raise inner_e + + except Exception as e: + print(f"Warning: Could not load pretrained weights: {e}") + print("Proceeding with randomly initialized model...") + + model = 
model.to(device) + return model + + +def initialize_mapanything_local(local_config, device): + """Initialize a MapAnything model entirely from local resources. + + Args: + local_config (dict): + - path (str): Path to the Hydra config (for example ``configs/train.yaml``). + - checkpoint_path (str): Local path to the pretrained checkpoint. + - config_overrides (list[str], optional): Hydra override strings. + - config_json_path (str, optional): JSON file containing ``model_str``/``model_config`` overrides. + - model_str (str, optional): Model alias if not provided by the JSON/config (defaults to Hydra config value). + - torch_hub_force_reload (bool, optional): Forwarded to ``init_model``. + - strict (bool, optional): ``load_state_dict`` strict flag, defaults to False so older checkpoints remain compatible. + device (torch.device | str): Target device that will host the model. + + Returns: + torch.nn.Module: MapAnything model moved to ``device`` and switched to ``eval()``. + + Raises: + FileNotFoundError: Raised when the JSON config or checkpoint cannot be found. + """ + + if "path" not in local_config or "checkpoint_path" not in local_config: + raise ValueError("local_config must provide both 'path' and 'checkpoint_path'") + + import torch + + from mapanything.models import init_model + + config_overrides = local_config.get("config_overrides") + cfg = init_hydra_config(local_config["path"], overrides=config_overrides) + + model_config_json = None + config_json_path = local_config.get("config_json_path") + if config_json_path: + if not os.path.exists(config_json_path): + raise FileNotFoundError(f"Config JSON not found: {config_json_path}") + with open(config_json_path, "r") as f: + model_config_json = json.load(f) + + model_str = None + model_config = None + if model_config_json: + model_str = model_config_json.get("model_str") + model_config = model_config_json.get("model_config") + + if model_str is None: + model_str = local_config.get("model_str", cfg.model.model_str) + + if model_config is None: + model_config = local_config.get("model_config", cfg.model.model_config) + + torch_hub_force_reload = local_config.get("torch_hub_force_reload", False) + + model = init_model( + model_str=model_str, + model_config=model_config, + torch_hub_force_reload=torch_hub_force_reload, + ) + + checkpoint_path = local_config["checkpoint_path"] + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + if checkpoint_path.endswith(".safetensors"): + from safetensors.torch import load_file as load_safetensors + + checkpoint = load_safetensors(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + + strict = local_config.get("strict", False) + if isinstance(checkpoint, dict): + if "model" in checkpoint: + state_dict = checkpoint["model"] + elif "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + else: + state_dict = checkpoint + + model.load_state_dict(state_dict, strict=strict) + + model = model.to(device).eval() + return model diff --git a/mapanything/utils/hf_utils/viz.py b/mapanything/utils/hf_utils/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf8292a2cb2072e140401184c6d40db3be5944a --- /dev/null +++ b/mapanything/utils/hf_utils/viz.py @@ -0,0 +1,706 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
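# --- Editorial usage sketch (not part of the original patch) ----------------
# A minimal example of driving the fully local initializer defined above.
# The YAML/checkpoint paths are hypothetical placeholders; only the dictionary
# keys follow the documented `local_config` contract.
import torch

from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_local

local_config = {
    "path": "configs/train.yaml",                           # Hydra config (placeholder path)
    "checkpoint_path": "weights/mapanything.safetensors",   # local weights (placeholder path)
    "strict": False,                                        # tolerate missing/extra keys in older checkpoints
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = initialize_mapanything_local(local_config, device)  # returned on `device`, in eval() mode
# -----------------------------------------------------------------------------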
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Utility functions for Gradio demo visualizations +""" + +import copy +import os +from typing import Tuple + +import cv2 +import matplotlib +import numpy as np +import requests +import trimesh +from scipy.spatial.transform import Rotation + + +def remove_unreferenced_vertices( + faces: np.ndarray, *vertice_attrs, return_indices: bool = False +) -> Tuple[np.ndarray, ...]: + """ + Remove unreferenced vertices of a mesh. + Unreferenced vertices are removed, and the face indices are updated accordingly. + + Args: + faces (np.ndarray): [T, P] face indices + *vertice_attrs: vertex attributes + + Returns: + faces (np.ndarray): [T, P] face indices + *vertice_attrs: vertex attributes + indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None. + """ + P = faces.shape[-1] + fewer_indices, inv_map = np.unique(faces, return_inverse=True) + faces = inv_map.astype(np.int32).reshape(-1, P) + ret = [faces] + for attr in vertice_attrs: + ret.append(attr[fewer_indices]) + if return_indices: + ret.append(fewer_indices) + return tuple(ret) + + +def triangulate( + faces: np.ndarray, vertices: np.ndarray = None, backslash: np.ndarray = None +) -> np.ndarray: + """ + Triangulate a polygonal mesh. + + Args: + faces (np.ndarray): [L, P] polygonal faces + vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices. + If given, the triangulation is performed according to the distance + between vertices. Defaults to None. + backslash (np.ndarray, optional): [L] boolean array indicating + how to triangulate the quad faces. Defaults to None. + + Returns: + (np.ndarray): [L * (P - 2), 3] triangular faces + """ + if faces.shape[-1] == 3: + return faces + P = faces.shape[-1] + if vertices is not None: + assert faces.shape[-1] == 4, "now only support quad mesh" + if backslash is None: + backslash = np.linalg.norm( + vertices[faces[:, 0]] - vertices[faces[:, 2]], axis=-1 + ) < np.linalg.norm(vertices[faces[:, 1]] - vertices[faces[:, 3]], axis=-1) + if backslash is None: + loop_indice = np.stack( + [ + np.zeros(P - 2, dtype=int), + np.arange(1, P - 1, 1, dtype=int), + np.arange(2, P, 1, dtype=int), + ], + axis=1, + ) + return faces[:, loop_indice].reshape((-1, 3)) + else: + assert faces.shape[-1] == 4, "now only support quad mesh" + faces = np.where( + backslash[:, None], + faces[:, [0, 1, 2, 0, 2, 3]], + faces[:, [0, 1, 3, 3, 1, 2]], + ).reshape((-1, 3)) + return faces + + +def image_mesh( + *image_attrs: np.ndarray, + mask: np.ndarray = None, + tri: bool = False, + return_indices: bool = False, +) -> Tuple[np.ndarray, ...]: + """ + Get a mesh regarding image pixel uv coordinates as vertices and image grid as faces. + + Args: + *image_attrs (np.ndarray): image attributes in shape (height, width, [channels]) + mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None. + + Returns: + faces (np.ndarray): faces connecting neighboring pixels. 
shape (T, 4) if tri is False, else (T, 3)
+        *vertex_attrs (np.ndarray): vertex attributes in corresponding order with input image_attrs
+        indices (np.ndarray, optional): indices of vertices in the original mesh
+    """
+    assert (len(image_attrs) > 0) or (mask is not None), (
+        "At least one of image_attrs or mask should be provided"
+    )
+    height, width = image_attrs[0].shape[:2] if mask is None else mask.shape
+    assert all(img.shape[:2] == (height, width) for img in image_attrs), (
+        "All image_attrs should have the same shape"
+    )
+
+    row_faces = np.stack(
+        [
+            np.arange(0, width - 1, dtype=np.int32),
+            np.arange(width, 2 * width - 1, dtype=np.int32),
+            np.arange(1 + width, 2 * width, dtype=np.int32),
+            np.arange(1, width, dtype=np.int32),
+        ],
+        axis=1,
+    )
+    faces = (
+        np.arange(0, (height - 1) * width, width, dtype=np.int32)[:, None, None]
+        + row_faces[None, :, :]
+    ).reshape((-1, 4))
+    if mask is None:
+        if tri:
+            faces = triangulate(faces)
+        ret = [faces, *(img.reshape(-1, *img.shape[2:]) for img in image_attrs)]
+        if return_indices:
+            ret.append(np.arange(height * width, dtype=np.int32))
+        return tuple(ret)
+    else:
+        quad_mask = (
+            mask[:-1, :-1] & mask[1:, :-1] & mask[1:, 1:] & mask[:-1, 1:]
+        ).ravel()
+        faces = faces[quad_mask]
+        if tri:
+            faces = triangulate(faces)
+        return remove_unreferenced_vertices(
+            faces,
+            *(x.reshape(-1, *x.shape[2:]) for x in image_attrs),
+            return_indices=return_indices,
+        )
+
+
+def predictions_to_glb(
+    predictions,
+    filter_by_frames="all",
+    mask_black_bg=False,
+    mask_white_bg=False,
+    show_cam=True,
+    mask_ambiguous=False,
+    as_mesh=True,
+    conf_percentile=None,
+) -> trimesh.Scene:
+    """
+    Converts MapAnything predictions to a 3D scene represented as a GLB file.
+
+    Args:
+        predictions (dict): Dictionary containing model predictions with keys:
+            - world_points: 3D point coordinates (S, H, W, 3)
+            - images: Input images (S, H, W, 3)
+            - extrinsic: Camera extrinsic matrices (S, 3, 4)
+        filter_by_frames (str): Frame filter specification (default: "all")
+        mask_black_bg (bool): Mask out black background pixels (default: False)
+        mask_white_bg (bool): Mask out white background pixels (default: False)
+        show_cam (bool): Include camera visualization (default: True)
+        mask_ambiguous (bool): Apply final mask to filter ambiguous predictions (default: False)
+        as_mesh (bool): Represent the data as a mesh instead of point cloud (default: True)
+        conf_percentile (float, optional): Percentile used to threshold the confidence map; points below it are dropped (default: None)
+
+    Returns:
+        trimesh.Scene: Processed 3D scene containing point cloud/mesh and cameras
+
+    Raises:
+        ValueError: If input predictions structure is invalid
+    """
+    if not isinstance(predictions, dict):
+        raise ValueError("predictions must be a dictionary")
+
+    print("Building GLB scene")
+    selected_frame_idx = None
+    if filter_by_frames != "all" and filter_by_frames != "All":
+        try:
+            # Extract the index part before the colon
+            selected_frame_idx = int(filter_by_frames.split(":")[0])
+        except (ValueError, IndexError):
+            pass
+
+    # Always use Pointmap Branch
+    print("Using Pointmap Branch")
+    if "world_points" not in predictions:
+        raise ValueError(
+            "world_points not found in predictions. Pointmap Branch requires 'world_points' key. "
+            "Depthmap and Camera branches have been removed."
+ ) + + pred_world_points = predictions["world_points"] + + # Get images from predictions + images = predictions["images"] + # Use extrinsic matrices instead of pred_extrinsic_list + camera_matrices = predictions["extrinsic"] + + if selected_frame_idx is not None: + pred_world_points = pred_world_points[selected_frame_idx][None] + images = images[selected_frame_idx][None] + camera_matrices = camera_matrices[selected_frame_idx][None] + + vertices_3d = pred_world_points.reshape(-1, 3) + # Handle different image formats - check if images need transposing + if images.ndim == 4 and images.shape[1] == 3: # NCHW format + colors_rgb = np.transpose(images, (0, 2, 3, 1)) + else: # Assume already in NHWC format + colors_rgb = images + colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8) + + # Create mask for filtering + mask = np.ones(len(vertices_3d), dtype=bool) + final_mask = predictions["final_mask"].reshape(-1) + + # Confidence masking + if conf_percentile is not None and "conf" in predictions: + # print ("Applying confidence masking...") + conf = predictions["conf"].reshape(-1) + threshold = np.percentile(conf, conf_percentile) + # print (f"Confidence threshold at {conf_percentile} percentile: {threshold}") + conf_mask = conf >= threshold + mask = mask & conf_mask + + if mask_black_bg: + black_bg_mask = colors_rgb.sum(axis=1) >= 16 + mask = mask & black_bg_mask + + if mask_white_bg: + # Filter out white background pixels (RGB values close to white) + # Consider pixels white if all RGB values are above 240 + white_bg_mask = ( + (colors_rgb[:, 0] > 240) + & (colors_rgb[:, 1] > 240) + & (colors_rgb[:, 2] > 240) + ) + mask = mask & ~white_bg_mask + + # Use final_mask when mask_ambiguous is checked + if mask_ambiguous: + mask = mask & final_mask + + vertices_3d = vertices_3d[mask].copy() + colors_rgb = colors_rgb[mask].copy() + + if vertices_3d is None or np.asarray(vertices_3d).size == 0: + vertices_3d = np.array([[1, 0, 0]]) + colors_rgb = np.array([[255, 255, 255]]) + scene_scale = 1 + else: + # Calculate the 5th and 95th percentiles along each axis + lower_percentile = np.percentile(vertices_3d, 5, axis=0) + upper_percentile = np.percentile(vertices_3d, 95, axis=0) + + # Calculate the diagonal length of the percentile bounding box + scene_scale = np.linalg.norm(upper_percentile - lower_percentile) + + colormap = matplotlib.colormaps.get_cmap("gist_rainbow") + + # Initialize a 3D scene + scene_3d = trimesh.Scene() + + # Add point cloud data to the scene + if as_mesh: + # Create mesh from pointcloud + # try: + if selected_frame_idx is not None: + # Single frame case - we can create a proper mesh + H, W = pred_world_points.shape[1:3] + + # Get original unfiltered data for mesh creation + original_points = pred_world_points.reshape(H, W, 3) + + # Reshape original image data properly + if images.ndim == 4 and images.shape[1] == 3: # NCHW format + original_image_colors = np.transpose(images[0], (1, 2, 0)) + else: # Assume already in HWC format + original_image_colors = images[0] + original_image_colors *= 255 + # Get original final mask + original_final_mask = predictions["final_mask"][selected_frame_idx].reshape( + H, W + ) + + # Create mask based on final mask + mask = original_final_mask + + # Confidence masking + if conf_percentile is not None and "conf" in predictions: + # print ("Applying confidence masking...") + conf = predictions["conf"][selected_frame_idx].reshape(-1) + threshold = np.percentile(conf, conf_percentile) + # print (f"Confidence threshold at {conf_percentile} 
percentile: {threshold}") + conf_mask = conf >= threshold + mask = mask & conf_mask.reshape(H, W) + + # Additional background masks if needed + if mask_black_bg: + black_bg_mask = original_image_colors.sum(axis=2) >= 16 + mask = mask & black_bg_mask + + if mask_white_bg: + white_bg_mask = ~( + (original_image_colors[:, :, 0] > 240) + & (original_image_colors[:, :, 1] > 240) + & (original_image_colors[:, :, 2] > 240) + ) + mask = mask & white_bg_mask + + # Check if normals are available in predictions + vertex_normals = None + if "normal" in predictions and predictions["normal"] is not None: + # Get normals for the selected frame + frame_normals = ( + predictions["normal"][selected_frame_idx] + if selected_frame_idx is not None + else predictions["normal"][0] + ) + + # Create faces and vertices using image_mesh with normals support + faces, vertices, vertex_colors, vertex_normals = image_mesh( + original_points * np.array([1, -1, 1], dtype=np.float32), + original_image_colors / 255.0, + frame_normals * np.array([1, -1, 1], dtype=np.float32), + mask=mask, + tri=True, + return_indices=False, + ) + + # Apply coordinate transformations to normals + vertex_normals = vertex_normals * np.array([1, -1, 1], dtype=np.float32) + else: + # Create faces and vertices using image_mesh without normals + faces, vertices, vertex_colors = image_mesh( + original_points * np.array([1, -1, 1], dtype=np.float32), + original_image_colors / 255.0, + mask=mask, + tri=True, + return_indices=False, + ) + + # vertices = vertices * np.array([1, -1, 1], dtype=np.float32) + + # Create trimesh object with optional normals + mesh_data = trimesh.Trimesh( + vertices=vertices * np.array([1, -1, 1], dtype=np.float32), + faces=faces, + vertex_colors=(vertex_colors * 255).astype(np.uint8), + vertex_normals=(vertex_normals if vertex_normals is not None else None), + process=False, + ) + scene_3d.add_geometry(mesh_data) + + else: + # Multi-frame case - create separate meshes for each frame + print("Creating mesh for multi-frame data...") + + for frame_idx in range(pred_world_points.shape[0]): + H, W = pred_world_points.shape[1:3] + + # Get data for this frame + frame_points = pred_world_points[frame_idx] + frame_final_mask = predictions["final_mask"][frame_idx] + + # Get frame image + if images.ndim == 4 and images.shape[1] == 3: # NCHW format + frame_image = np.transpose(images[frame_idx], (1, 2, 0)) + else: # Assume already in HWC format + frame_image = images[frame_idx] + frame_image *= 255 + # Create mask for this frame using final_mask + mask = frame_final_mask + + # Additional background masks if needed + if mask_black_bg: + black_bg_mask = frame_image.sum(axis=2) >= 16 + mask = mask & black_bg_mask + + if mask_white_bg: + white_bg_mask = ~( + (frame_image[:, :, 0] > 240) + & (frame_image[:, :, 1] > 240) + & (frame_image[:, :, 2] > 240) + ) + mask = mask & white_bg_mask + if conf_percentile is not None and "conf" in predictions: + # print ("Applying confidence masking...") + conf = predictions["conf"][frame_idx].reshape(-1) + threshold = np.percentile(conf, conf_percentile) + # print (f"Confidence threshold at {conf_percentile} percentile: {threshold}") + conf_mask = conf >= threshold + mask = mask & conf_mask.reshape(H, W) + # Create mesh for this frame + faces, vertices, vertex_colors = image_mesh( + frame_points * np.array([1, -1, 1], dtype=np.float32), + frame_image / 255.0, + mask=mask, + tri=True, + return_indices=False, + ) + + vertices = vertices * np.array([1, -1, 1], dtype=np.float32) + # Create trimesh object for 
this frame + frame_mesh = trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_colors=(vertex_colors * 255).astype(np.uint8), + process=False, + ) + scene_3d.add_geometry(frame_mesh) + else: + point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb) + scene_3d.add_geometry(point_cloud_data) + + # Prepare 4x4 matrices for camera extrinsics + num_cameras = len(camera_matrices) + + if show_cam: + # Add camera models to the scene + for i in range(num_cameras): + world_to_camera = camera_matrices[i] + rgba_color = colormap(i / num_cameras) + current_color = tuple(int(255 * x) for x in rgba_color[:3]) + + integrate_camera_into_scene( + scene_3d, world_to_camera, current_color, scene_scale + ) + + # Align scene to the observation of the first camera + scene_3d = apply_scene_alignment(scene_3d, camera_matrices) + + print("GLB Scene built") + return scene_3d + + +def integrate_camera_into_scene( + scene: trimesh.Scene, + transform: np.ndarray, + face_colors: tuple, + scene_scale: float, +): + """ + Integrates a fake camera mesh into the 3D scene. + + Args: + scene (trimesh.Scene): The 3D scene to add the camera model. + transform (np.ndarray): Transformation matrix for camera positioning. + face_colors (tuple): Color of the camera face. + scene_scale (float): Scale of the scene. + """ + scene_scale = 12 + cam_width = scene_scale * 0.05 + cam_height = scene_scale * 0.1 + # cam_width = scene_scale * 0.05 + # cam_height = scene_scale * 0.1 + + # Create cone shape for camera + rot_45_degree = np.eye(4) + rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix() + rot_45_degree[2, 3] = -cam_height + + opengl_transform = get_opengl_conversion_matrix() + # Combine transformations + complete_transform = transform @ opengl_transform @ rot_45_degree + camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4) + + # Generate mesh for the camera + slight_rotation = np.eye(4) + slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix() + + vertices_combined = np.concatenate( + [ + camera_cone_shape.vertices, + 0.95 * camera_cone_shape.vertices, + transform_points(slight_rotation, camera_cone_shape.vertices), + ] + ) + vertices_transformed = transform_points(complete_transform, vertices_combined) + + mesh_faces = compute_camera_faces(camera_cone_shape) + + # Add the camera mesh to the scene + camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces) + camera_mesh.visual.face_colors[:, :3] = face_colors + scene.add_geometry(camera_mesh) + + +def apply_scene_alignment( + scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray +) -> trimesh.Scene: + """ + Aligns the 3D scene based on the extrinsics of the first camera. + + Args: + scene_3d (trimesh.Scene): The 3D scene to be aligned. + extrinsics_matrices (np.ndarray): Camera extrinsic matrices. + + Returns: + trimesh.Scene: Aligned 3D scene. + """ + # Set transformations for scene alignment + opengl_conversion_matrix = get_opengl_conversion_matrix() + + # Rotation matrix for alignment (180 degrees around the y-axis) + align_rotation = np.eye(4) + align_rotation[:3, :3] = Rotation.from_euler("y", 0, degrees=True).as_matrix() + + # Apply transformation + initial_transformation = ( + np.linalg.inv(extrinsics_matrices[0]) + @ opengl_conversion_matrix + @ align_rotation + ) + scene_3d.apply_transform(initial_transformation) + return scene_3d + + +def get_opengl_conversion_matrix() -> np.ndarray: + """ + Constructs and returns the OpenGL conversion matrix. 
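# --- Editorial illustration (not part of the original patch) ----------------
# apply_scene_alignment above composes the inverse of the first camera pose
# with the OpenGL conversion matrix documented here. A minimal sketch of what
# that conversion does: it negates the y and z axes of a homogeneous point.
import numpy as np

opengl_conversion = np.diag([1.0, -1.0, -1.0, 1.0])  # same values the helper returns
point_cv = np.array([0.5, 0.2, 2.0, 1.0])            # x right, y down, z forward
point_gl = opengl_conversion @ point_cv              # -> [0.5, -0.2, -2.0, 1.0]
# -----------------------------------------------------------------------------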
+ + Returns: + numpy.ndarray: A 4x4 OpenGL conversion matrix. + """ + # Create an identity matrix + matrix = np.identity(4) + + # Flip the y and z axes + matrix[1, 1] = -1 + matrix[2, 2] = -1 + + return matrix + + +def transform_points( + transformation: np.ndarray, points: np.ndarray, dim: int = None +) -> np.ndarray: + """ + Applies a 4x4 transformation to a set of points. + + Args: + transformation (np.ndarray): Transformation matrix. + points (np.ndarray): Points to be transformed. + dim (int, optional): Dimension for reshaping the result. + + Returns: + np.ndarray: Transformed points. + """ + points = np.asarray(points) + initial_shape = points.shape[:-1] + dim = dim or points.shape[-1] + + # Apply transformation + transformation = transformation.swapaxes( + -1, -2 + ) # Transpose the transformation matrix + points = points @ transformation[..., :-1, :] + transformation[..., -1:, :] + + # Reshape the result + result = points[..., :dim].reshape(*initial_shape, dim) + return result + + +def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray: + """ + Computes the faces for the camera mesh. + + Args: + cone_shape (trimesh.Trimesh): The shape of the camera cone. + + Returns: + np.ndarray: Array of faces for the camera mesh. + """ + # Create pseudo cameras + faces_list = [] + num_vertices_cone = len(cone_shape.vertices) + + for face in cone_shape.faces: + if 0 in face: + continue + v1, v2, v3 = face + v1_offset, v2_offset, v3_offset = face + num_vertices_cone + v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone + + faces_list.extend( + [ + (v1, v2, v2_offset), + (v1, v1_offset, v3), + (v3_offset, v2, v3), + (v1, v2, v2_offset_2), + (v1, v1_offset_2, v3), + (v3_offset_2, v2, v3), + ] + ) + + faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list] + return np.array(faces_list) + + +def segment_sky(image_path, onnx_session, mask_filename=None): + """ + Segments sky from an image using an ONNX model. + Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing + + Args: + image_path: Path to input image + onnx_session: ONNX runtime session with loaded model + mask_filename: Path to save the output mask + + Returns: + np.ndarray: Binary mask where 255 indicates non-sky regions + """ + + assert mask_filename is not None + image = cv2.imread(image_path) + + result_map = run_skyseg(onnx_session, [320, 320], image) + # resize the result_map to the original image size + result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0])) + + # Fix: Invert the mask so that 255 = non-sky, 0 = sky + # The model outputs low values for sky, high values for non-sky + output_mask = np.zeros_like(result_map_original) + output_mask[result_map_original < 32] = 255 # Use threshold of 32 + + os.makedirs(os.path.dirname(mask_filename), exist_ok=True) + cv2.imwrite(mask_filename, output_mask) + return output_mask + + +def run_skyseg(onnx_session, input_size, image): + """ + Runs sky segmentation inference using ONNX model. 
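# --- Editorial usage sketch (not part of the original patch) ----------------
# segment_sky above expects a ready-made onnxruntime session and a writable
# mask path. The model and image filenames below are hypothetical placeholders.
import onnxruntime as ort

from mapanything.utils.hf_utils.viz import segment_sky

session = ort.InferenceSession("skyseg.onnx", providers=["CPUExecutionProvider"])
mask = segment_sky("inputs/photo.jpg", session, mask_filename="masks/photo.png")
# `mask` is uint8 with 255 marking non-sky pixels, as documented above.
# -----------------------------------------------------------------------------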
+ + Args: + onnx_session: ONNX runtime session + input_size: Target size for model input (width, height) + image: Input image in BGR format + + Returns: + np.ndarray: Segmentation mask + """ + + # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast + temp_image = copy.deepcopy(image) + resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1])) + x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB) + x = np.array(x, dtype=np.float32) + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + x = (x / 255 - mean) / std + x = x.transpose(2, 0, 1) + x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32") + + # Inference + input_name = onnx_session.get_inputs()[0].name + output_name = onnx_session.get_outputs()[0].name + onnx_result = onnx_session.run([output_name], {input_name: x}) + + # Post process + onnx_result = np.array(onnx_result).squeeze() + min_value = np.min(onnx_result) + max_value = np.max(onnx_result) + onnx_result = (onnx_result - min_value) / (max_value - min_value) + onnx_result *= 255 + onnx_result = onnx_result.astype("uint8") + + return onnx_result + + +def download_file_from_url(url, filename): + """Downloads a file from a Hugging Face model repo, handling redirects.""" + try: + # Get the redirect URL + response = requests.get(url, allow_redirects=False) + response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx) + + if response.status_code == 302: # Expecting a redirect + redirect_url = response.headers["Location"] + response = requests.get(redirect_url, stream=True) + response.raise_for_status() + else: + print(f"Unexpected status code: {response.status_code}") + return + + with open(filename, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + print(f"Downloaded {filename} successfully.") + + except requests.exceptions.RequestException as e: + print(f"Error downloading file: {e}") diff --git a/mapanything/utils/image.py b/mapanything/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..0542004c26b29a1dbc7047089a4956b1ad166035 --- /dev/null +++ b/mapanything/utils/image.py @@ -0,0 +1,675 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Utility functions for loading, converting, and manipulating images. 
+ +This module provides functions for: +- Converting between different image formats and representations +- Resizing and cropping images to specific resolutions +- Loading and normalizing images for model input +- Handling various image file formats including HEIF/HEIC when available +""" + +import os + +import numpy as np +import PIL.Image +import torch +import torchvision.transforms as tvf +from PIL.ImageOps import exif_transpose + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 + +try: + from pillow_heif import register_heif_opener + + register_heif_opener() + heif_support_enabled = True +except ImportError: + heif_support_enabled = False + +from mapanything.utils.cropping import crop_resize_if_necessary +from mapanything.utils.geometry import recover_pinhole_intrinsics_from_ray_directions +from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT + +# Fixed resolution mappings with precomputed aspect ratios as keys +RESOLUTION_MAPPINGS = { + 518: { + 1.000: (518, 518), # 1:1 + 1.321: (518, 392), # 4:3 + 1.542: (518, 336), # 3:2 + 1.762: (518, 294), # 16:9 + 2.056: (518, 252), # 2:1 + 3.083: (518, 168), # 3.2:1 + 0.757: (392, 518), # 3:4 + 0.649: (336, 518), # 2:3 + 0.567: (294, 518), # 9:16 + 0.486: (252, 518), # 1:2 + }, + 512: { + 1.000: (512, 512), # 1:1 + 1.333: (512, 384), # 4:3 + 1.524: (512, 336), # 3:2 + 1.778: (512, 288), # 16:9 + 2.000: (512, 256), # 2:1 + 3.200: (512, 160), # 3.2:1 + 0.750: (384, 512), # 3:4 + 0.656: (336, 512), # 2:3 + 0.562: (288, 512), # 9:16 + 0.500: (256, 512), # 1:2 + }, +} + +# Precomputed sorted aspect ratio keys for efficient lookup +ASPECT_RATIO_KEYS = { + 518: sorted(RESOLUTION_MAPPINGS[518].keys()), + 512: sorted(RESOLUTION_MAPPINGS[512].keys()), +} + + +def find_closest_aspect_ratio(aspect_ratio, resolution_set): + """ + Find the closest aspect ratio from the resolution mappings using efficient key lookup. + + Args: + aspect_ratio (float): Target aspect ratio + resolution_set (int): Resolution set to use (518 or 512) + + Returns: + tuple: (target_width, target_height) from the resolution mapping + """ + aspect_keys = ASPECT_RATIO_KEYS[resolution_set] + + # Find the closest aspect ratio key using binary search approach + closest_key = min(aspect_keys, key=lambda x: abs(x - aspect_ratio)) + + return RESOLUTION_MAPPINGS[resolution_set][closest_key] + + +def rgb(ftensor, norm_type, true_shape=None): + """ + Convert normalized image tensor to RGB image for visualization. 
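# --- Editorial usage sketch (not part of the original patch) ----------------
# Undo "dinov2" input normalization for display. The random tensor stands in
# for a real normalized network input of shape (3, H, W).
import torch

from mapanything.utils.image import rgb

fake_input = torch.randn(3, 224, 224)      # normalized image tensor (C, H, W)
vis = rgb(fake_input, norm_type="dinov2")  # numpy array (H, W, 3), clipped to [0, 1]
# -----------------------------------------------------------------------------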
+ + Args: + ftensor (torch.Tensor or numpy.ndarray or list): Image tensor or list of image tensors + norm_type (str): Normalization type, see UniCeption IMAGE_NORMALIZATION_DICT keys or use "identity" + true_shape (tuple, optional): If provided, the image will be cropped to this shape (H, W) + + Returns: + numpy.ndarray: RGB image with values in range [0, 1] + """ + if isinstance(ftensor, list): + return [rgb(x, norm_type, true_shape=true_shape) for x in ftensor] + if isinstance(ftensor, torch.Tensor): + ftensor = ftensor.detach().cpu().numpy() # H,W,3 + if ftensor.ndim == 3 and ftensor.shape[0] == 3: + ftensor = ftensor.transpose(1, 2, 0) + elif ftensor.ndim == 4 and ftensor.shape[1] == 3: + ftensor = ftensor.transpose(0, 2, 3, 1) + if true_shape is not None: + H, W = true_shape + ftensor = ftensor[:H, :W] + if ftensor.dtype == np.uint8: + img = np.float32(ftensor) / 255 + else: + if norm_type in IMAGE_NORMALIZATION_DICT.keys(): + img_norm = IMAGE_NORMALIZATION_DICT[norm_type] + mean = img_norm.mean.numpy() + std = img_norm.std.numpy() + elif norm_type == "identity": + mean = 0.0 + std = 1.0 + else: + raise ValueError( + f"Unknown image normalization type: {norm_type}. Available types: identity or {IMAGE_NORMALIZATION_DICT.keys()}" + ) + img = ftensor * std + mean + return img.clip(min=0, max=1) + + +def load_images( + folder_or_list, + resize_mode="fixed_mapping", + size=None, + norm_type="dinov2", + patch_size=14, + verbose=False, + bayer_format=False, + resolution_set=518, + stride=1, +): + """ + Open and convert all images in a list or folder to proper input format for model + + Args: + folder_or_list (str or list): Path to folder or list of image paths. + resize_mode (str): Resize mode - "fixed_mapping", "longest_side", "square", or "fixed_size". Defaults to "fixed_mapping". + size (int or tuple, optional): Required for "longest_side", "square", and "fixed_size" modes. + - For "longest_side" and "square": int value for resize dimension + - For "fixed_size": tuple of (width, height) + norm_type (str, optional): Image normalization type. See UniCeption IMAGE_NORMALIZATION_DICT keys. Defaults to "dinov2". + patch_size (int, optional): Patch size for image processing. Defaults to 14. + verbose (bool, optional): If True, print progress messages. Defaults to False. + bayer_format (bool, optional): If True, read images in Bayer format. Defaults to False. + resolution_set (int, optional): Resolution set to use for "fixed_mapping" mode (518 or 512). Defaults to 518. + stride (int, optional): Load every nth image from the input. stride=1 loads all images, stride=2 loads every 2nd image, etc. Defaults to 1. 
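# --- Editorial usage sketch (not part of the original patch) ----------------
# Load a folder of photos with the default "fixed_mapping" behaviour. The
# folder path is a hypothetical placeholder.
from mapanything.utils.image import load_images

views = load_images(
    "examples/room",              # folder of .jpg/.png (and .heic, if supported) images
    resize_mode="fixed_mapping",  # snap to the closest preset resolution in the 518 set
    norm_type="dinov2",
    verbose=True,
)
print(views[0]["img"].shape)      # torch tensor of shape (1, 3, H, W)
print(views[0]["true_shape"])     # np.int32 array [[H, W]]
# -----------------------------------------------------------------------------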
+ + Returns: + list: List of dictionaries containing image data and metadata + """ + # Validate resize_mode and size parameter requirements + valid_resize_modes = ["fixed_mapping", "longest_side", "square", "fixed_size"] + if resize_mode not in valid_resize_modes: + raise ValueError( + f"Resize_mode must be one of {valid_resize_modes}, got '{resize_mode}'" + ) + + if resize_mode in ["longest_side", "square", "fixed_size"] and size is None: + raise ValueError(f"Size parameter is required for resize_mode='{resize_mode}'") + + # Validate size type based on resize mode + if resize_mode in ["longest_side", "square"]: + if not isinstance(size, int): + raise ValueError( + f"Size must be an int for resize_mode='{resize_mode}', got {type(size)}" + ) + elif resize_mode == "fixed_size": + if not isinstance(size, (tuple, list)) or len(size) != 2: + raise ValueError( + f"Size must be a tuple/list of (width, height) for resize_mode='fixed_size', got {size}" + ) + if not all(isinstance(x, int) for x in size): + raise ValueError( + f"Size values must be integers for resize_mode='fixed_size', got {size}" + ) + + # Get list of image paths + if isinstance(folder_or_list, str): + # If folder_or_list is a string, assume it's a path to a folder + if verbose: + print(f"Loading images from {folder_or_list}") + root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) + elif isinstance(folder_or_list, list): + # If folder_or_list is a list, assume it's a list of image paths + if verbose: + print(f"Loading a list of {len(folder_or_list)} images") + root, folder_content = "", folder_or_list + else: + # If folder_or_list is neither a string nor a list, raise an error + raise ValueError(f"Bad {folder_or_list=} ({type(folder_or_list)})") + + # Define supported image extensions + supported_images_extensions = [".jpg", ".jpeg", ".png"] + if heif_support_enabled: + supported_images_extensions += [".heic", ".heif"] + supported_images_extensions = tuple(supported_images_extensions) + + # First pass: Load all images and collect aspect ratios + loaded_images = [] + aspect_ratios = [] + for i, path in enumerate(folder_content): + # Skip images based on stride + if i % stride != 0: + continue + + # Check if the file has a supported image extension + if not path.lower().endswith(supported_images_extensions): + continue + + try: + if bayer_format: + # If bayer_format is True, read the image in Bayer format + color_bayer = cv2.imread(os.path.join(root, path), cv2.IMREAD_UNCHANGED) + color = cv2.cvtColor(color_bayer, cv2.COLOR_BAYER_RG2BGR) + img = PIL.Image.fromarray(color) + img = exif_transpose(img).convert("RGB") + else: + # Otherwise, read the image normally + img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert( + "RGB" + ) + + W1, H1 = img.size + aspect_ratios.append(W1 / H1) + loaded_images.append((path, img, W1, H1)) + + except Exception as e: + if verbose: + print(f"Warning: Could not load {path}: {e}") + continue + + # Check if any images were loaded + if not loaded_images: + raise ValueError("No valid images found") + + # Calculate average aspect ratio and determine target size + average_aspect_ratio = sum(aspect_ratios) / len(aspect_ratios) + if verbose: + print( + f"Calculated average aspect ratio: {average_aspect_ratio:.3f} from {len(aspect_ratios)} images" + ) + + # Determine target size for all images based on resize mode + if resize_mode == "fixed_mapping": + # Resolution mappings are already compatible with their respective patch sizes + # 518 mappings are divisible by 14, 512 
mappings are divisible by 16 + target_width, target_height = find_closest_aspect_ratio( + average_aspect_ratio, resolution_set + ) + target_size = (target_width, target_height) + elif resize_mode == "square": + target_size = ( + round((size // patch_size)) * patch_size, + round((size // patch_size)) * patch_size, + ) + elif resize_mode == "longest_side": + # Use average aspect ratio to determine size for all images + # Longest side should be the input size + if average_aspect_ratio >= 1: # Landscape or square + # Width is the longest side + target_size = ( + size, + round((size // patch_size) / average_aspect_ratio) * patch_size, + ) + else: # Portrait + # Height is the longest side + target_size = ( + round((size // patch_size) * average_aspect_ratio) * patch_size, + size, + ) + elif resize_mode == "fixed_size": + # Use exact size provided, aligned to patch_size + target_size = ( + (size[0] // patch_size) * patch_size, + (size[1] // patch_size) * patch_size, + ) + + if verbose: + print( + f"Using target resolution {target_size[0]}x{target_size[1]} (W x H) for all images" + ) + + # Get the image normalization function based on the norm_type + if norm_type in IMAGE_NORMALIZATION_DICT.keys(): + img_norm = IMAGE_NORMALIZATION_DICT[norm_type] + ImgNorm = tvf.Compose( + [tvf.ToTensor(), tvf.Normalize(mean=img_norm.mean, std=img_norm.std)] + ) + else: + raise ValueError( + f"Unknown image normalization type: {norm_type}. Available options: {list(IMAGE_NORMALIZATION_DICT.keys())}" + ) + + # Second pass: Resize all images to the same target size + imgs = [] + for path, img, W1, H1 in loaded_images: + # Resize and crop the image to the target size + img = crop_resize_if_necessary(img, resolution=target_size)[0] + + # Normalize image and add it to the list + W2, H2 = img.size + if verbose: + print(f" - Adding {path} with resolution {W1}x{H1} --> {W2}x{H2}") + + imgs.append( + dict( + img=ImgNorm(img)[None], + true_shape=np.int32([img.size[::-1]]), + idx=len(imgs), + instance=str(len(imgs)), + data_norm_type=[norm_type], + ) + ) + + assert imgs, "No images foud at " + root + if verbose: + print(f" (Found {len(imgs)} images)") + + return imgs + + +def preprocess_inputs( + input_views, + resize_mode="fixed_mapping", + size=None, + norm_type="dinov2", + patch_size=14, + resolution_set=518, + verbose=False, +): + """ + Preprocess input_views by determining optimal aspect ratio and resizing all images and multi-modal inputs. + + Similar to load_images function, this function: + (a) Determines the optimal aspect ratio from all input images + (b) Resizes all images and multi-modal inputs using crop_resize_if_necessary + (c) Normalizes images according to the specified normalization type + + Args: + input_views (list): List of dictionaries containing view data. Each view can contain: + - img: Image tensor (H, W, 3) - [0, 255] or PIL Image + - intrinsics: Camera intrinsics (3, 3) + - depth_z: Depth maps (H, W) + - ray_directions: Ray directions (H, W, 3) + - camera_poses: Camera poses (4, 4) or tuple of (quats, trans) - not resized + - is_metric_scale: Boolean value - not resized + resize_mode (str): Resize mode - "fixed_mapping", "longest_side", "square", or "fixed_size". Defaults to "fixed_mapping". + size (int or tuple, optional): Required for "longest_side", "square", and "fixed_size" modes. + norm_type (str, optional): Image normalization type. See UniCeption IMAGE_NORMALIZATION_DICT keys. Defaults to "dinov2". + patch_size (int, optional): Patch size for image processing. Defaults to 14. 
+ resolution_set (int, optional): Resolution set to use for "fixed_mapping" mode (518 or 512). Defaults to 518. + verbose (bool, optional): If True, print progress messages. Defaults to False. + + Returns: + list: List of processed view dictionaries with resized images and multi-modal inputs + """ + # Validate resize_mode and size parameter requirements + valid_resize_modes = ["fixed_mapping", "longest_side", "square", "fixed_size"] + if resize_mode not in valid_resize_modes: + raise ValueError( + f"Resize_mode must be one of {valid_resize_modes}, got '{resize_mode}'" + ) + + if resize_mode in ["longest_side", "square", "fixed_size"] and size is None: + raise ValueError(f"Size parameter is required for resize_mode='{resize_mode}'") + + # Validate size type based on resize mode + if resize_mode in ["longest_side", "square"]: + if not isinstance(size, int): + raise ValueError( + f"Size must be an int for resize_mode='{resize_mode}', got {type(size)}" + ) + elif resize_mode == "fixed_size": + if not isinstance(size, (tuple, list)) or len(size) != 2: + raise ValueError( + f"Size must be a tuple/list of (width, height) for resize_mode='fixed_size', got {size}" + ) + if not all(isinstance(x, int) for x in size): + raise ValueError( + f"Size values must be integers for resize_mode='fixed_size', got {size}" + ) + + if not input_views: + raise ValueError("input_views cannot be empty") + + # First pass: Extract all images and collect aspect ratios + aspect_ratios = [] + for view_idx, view in enumerate(input_views): + if "img" not in view: + if verbose: + print( + f"Warning: View {view_idx} has no 'img' key, skipping for aspect ratio calculation" + ) + continue + + img = view["img"] + + # Handle different image formats (no batch dimension expected) + if isinstance(img, torch.Tensor): + # Tensor format: (H, W, 3) - channel last + if img.ndim == 3 and img.shape[2] == 3: + H, W = img.shape[0], img.shape[1] + else: + raise ValueError( + f"Expected tensor shape (H, W, 3) for img in view {view_idx}, got {img.shape}" + ) + elif isinstance(img, PIL.Image.Image): + W, H = img.size + elif isinstance(img, np.ndarray): + # Array format: (H, W, 3) - channel last + if img.ndim == 3 and img.shape[2] == 3: + H, W = img.shape[0], img.shape[1] + else: + raise ValueError( + f"Expected array shape (H, W, 3) for img in view {view_idx}, got {img.shape}" + ) + else: + raise ValueError(f"Unsupported image type in view {view_idx}: {type(img)}") + + aspect_ratios.append(W / H) + + if not aspect_ratios: + raise ValueError("No valid images found in input_views") + + # Calculate average aspect ratio and determine target size + average_aspect_ratio = sum(aspect_ratios) / len(aspect_ratios) + if verbose: + print( + f"Calculated average aspect ratio: {average_aspect_ratio:.3f} from {len(aspect_ratios)} images" + ) + + # Determine target size for all images based on resize mode + if resize_mode == "fixed_mapping": + # Resolution mappings are already compatible with their respective patch sizes + target_width, target_height = find_closest_aspect_ratio( + average_aspect_ratio, resolution_set + ) + target_size = (target_width, target_height) + elif resize_mode == "square": + target_size = ( + round((size // patch_size)) * patch_size, + round((size // patch_size)) * patch_size, + ) + elif resize_mode == "longest_side": + # Use average aspect ratio to determine size for all images + if average_aspect_ratio >= 1: # Landscape or square + target_size = ( + size, + round((size // patch_size) / average_aspect_ratio) * patch_size, + ) + else: 
# Portrait + target_size = ( + round((size // patch_size) * average_aspect_ratio) * patch_size, + size, + ) + elif resize_mode == "fixed_size": + # Use exact size provided, aligned to patch_size + target_size = ( + (size[0] // patch_size) * patch_size, + (size[1] // patch_size) * patch_size, + ) + + if verbose: + print( + f"Using target resolution {target_size[0]}x{target_size[1]} (W x H) for all views" + ) + + # Get the image normalization function based on the norm_type + if norm_type in IMAGE_NORMALIZATION_DICT.keys(): + img_norm = IMAGE_NORMALIZATION_DICT[norm_type] + ImgNorm = tvf.Compose( + [tvf.ToTensor(), tvf.Normalize(mean=img_norm.mean, std=img_norm.std)] + ) + else: + raise ValueError( + f"Unknown image normalization type: {norm_type}. Available options: {list(IMAGE_NORMALIZATION_DICT.keys())}" + ) + + # Helper function to convert tensor/array to PIL Image + def to_pil_image(img, view_idx): + """Convert tensor or array to PIL Image for processing.""" + if isinstance(img, torch.Tensor): + # Convert tensor to PIL Image for processing - expect (H, W, 3) + if img.ndim != 3 or img.shape[2] != 3: + raise ValueError( + f"Expected tensor shape (H, W, 3) for img in view {view_idx}, got {img.shape}" + ) + # Only multiply with 255 if the image range is within [0, 1] + if img.max() <= 1.0: + img = (img * 255).clamp(0, 255).byte().cpu().numpy() + else: + img = img.clamp(0, 255).byte().cpu().numpy() + return PIL.Image.fromarray(img) + elif isinstance(img, np.ndarray): + # Expect (H, W, 3) format + if img.ndim != 3 or img.shape[2] != 3: + raise ValueError( + f"Expected array shape (H, W, 3) for img in view {view_idx}, got {img.shape}" + ) + if img.dtype != np.uint8: + img = (img * 255).clip(0, 255).astype(np.uint8) + return PIL.Image.fromarray(img) + elif isinstance(img, PIL.Image.Image): + return img + else: + raise ValueError(f"Unsupported image type in view {view_idx}: {type(img)}") + + # Helper function to convert tensor to numpy array + def to_numpy(data, expected_shape, name, view_idx): + """Convert tensor to numpy array and validate shape.""" + if isinstance(data, torch.Tensor): + data = data.cpu().numpy() + + if not isinstance(data, np.ndarray): + raise ValueError( + f"Expected tensor or array for {name} in view {view_idx}, got {type(data)}" + ) + + if data.shape != expected_shape and expected_shape is not None: + raise ValueError( + f"Expected shape {expected_shape} for {name} in view {view_idx}, got {data.shape}" + ) + + return data + + # Second pass: Resize all images and multi-modal inputs + processed_views = [] + for view_idx, view in enumerate(input_views): + # Convert image to PIL format + if "img" not in view: + raise ValueError(f"View {view_idx} missing required 'img' key") + + img = to_pil_image(view["img"], view_idx) + + # Prepare inputs for crop_resize_if_necessary + depthmap = None + intrinsics = None + + # Handle depth_z + if "depth_z" in view: + depthmap = to_numpy(view["depth_z"], None, "depth_z", view_idx) + if depthmap.ndim != 2: + raise ValueError( + f"Expected shape (H, W) for depth_z in view {view_idx}, got {depthmap.shape}" + ) + + # Enforce that only one of intrinsics and ray_directions is provided + has_intrinsics = "intrinsics" in view + has_ray_directions = "ray_directions" in view + + if has_intrinsics and has_ray_directions: + raise ValueError( + f"View {view_idx} cannot have both 'intrinsics' and 'ray_directions'. " + "Please provide only one as they are redundant (ray_directions can be used to recover intrinsics)." 
+ ) + + # Handle intrinsics + if has_intrinsics: + intrinsics = to_numpy(view["intrinsics"], (3, 3), "intrinsics", view_idx) + + # Handle ray_directions by recovering intrinsics from them + if has_ray_directions: + ray_dirs = to_numpy( + view["ray_directions"], None, "ray_directions", view_idx + ) + if ray_dirs.ndim != 3 or ray_dirs.shape[2] != 3: + raise ValueError( + f"Expected shape (H, W, 3) for ray_directions in view {view_idx}, got {ray_dirs.shape}" + ) + + # Convert ray directions to torch tensor for the geometry function + ray_dirs_torch = torch.from_numpy(ray_dirs) + + # Recover intrinsics from ray directions + recovered_intrinsics = recover_pinhole_intrinsics_from_ray_directions( + ray_dirs_torch + ) + recovered_intrinsics = recovered_intrinsics.cpu().numpy() + intrinsics = recovered_intrinsics + + # Process all inputs with a single call to crop_resize_if_necessary + results = crop_resize_if_necessary( + image=img, + resolution=target_size, + depthmap=depthmap, + intrinsics=intrinsics, + ) + + # Unpack results based on what was provided + processed_view = {} + result_idx = 0 + + # Image is always first - normalize it after resizing + resized_img = results[result_idx] + processed_view["img"] = ImgNorm(resized_img)[ + None + ] # Add batch dimension like load_images + processed_view["data_norm_type"] = [norm_type] # Add normalization type + result_idx += 1 + + # Depth is next if provided - add batch dimension + if depthmap is not None: + processed_view["depth_z"] = torch.from_numpy(results[result_idx])[None] + result_idx += 1 + + # Intrinsics is next if provided - add batch dimension + if intrinsics is not None: + processed_view["intrinsics"] = torch.from_numpy(results[result_idx])[None] + result_idx += 1 + + # Handle camera_poses with batch dimension if present + if "camera_poses" in view: + camera_poses = view["camera_poses"] + if isinstance(camera_poses, tuple): + # Tuple format (quats, trans) - add batch dimension to both components + quats, trans = camera_poses + if isinstance(quats, torch.Tensor): + quats_batched = quats[None] + elif isinstance(quats, np.ndarray): + quats_batched = torch.from_numpy(quats)[None] + else: + quats_batched = torch.tensor(quats)[None] + if isinstance(trans, torch.Tensor): + trans_batched = trans[None] + elif isinstance(trans, np.ndarray): + trans_batched = torch.from_numpy(trans)[None] + else: + trans_batched = torch.tensor(trans)[None] + processed_view["camera_poses"] = (quats_batched, trans_batched) + else: + # Matrix format - add batch dimension + if isinstance(camera_poses, torch.Tensor): + processed_view["camera_poses"] = camera_poses[None] + elif isinstance(camera_poses, np.ndarray): + processed_view["camera_poses"] = torch.from_numpy(camera_poses)[ + None + ] + else: + raise ValueError( + f"Unsupported camera_poses format: {type(camera_poses)}. Expected tuple (quats, trans) or matrix (tensor/array)." 
+ ) + + # Copy over any other keys that don't need resizing or batch dimensions + for key, value in view.items(): + if key not in [ + "img", + "depth_z", + "intrinsics", + "ray_directions", + "camera_poses", + ]: + processed_view[key] = value + + processed_views.append(processed_view) + + if verbose: + print(f"Processed view {view_idx} with keys: {list(processed_view.keys())}") + + if verbose: + print(f"Successfully processed {len(processed_views)} views") + + return processed_views diff --git a/mapanything/utils/inference.py b/mapanything/utils/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..1c570d91131c452d233cb5b028965c15e7cacf57 --- /dev/null +++ b/mapanything/utils/inference.py @@ -0,0 +1,480 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Inference utilities. +""" + +import warnings +from typing import Any, Dict, List + +import numpy as np +import torch + +from mapanything.utils.geometry import ( + depth_edge, + get_rays_in_camera_frame, + normals_edge, + points_to_normals, + quaternion_to_rotation_matrix, + recover_pinhole_intrinsics_from_ray_directions, + rotation_matrix_to_quaternion, +) +from mapanything.utils.image import rgb + +# Hard constraints - exactly what users can provide +ALLOWED_VIEW_KEYS = { + "img", # Required - input images + "data_norm_type", # Required - normalization type of the input images + "depth_z", # Optional - Z depth maps + "ray_directions", # Optional - ray directions in camera frame + "intrinsics", # Optional - pinhole camera intrinsics (conflicts with ray_directions) + "camera_poses", # Optional - camera poses + "is_metric_scale", # Optional - whether inputs are metric scale + "true_shape", # Optional - original image shape + "idx", # Optional - index of the view + "instance", # Optional - instance info of the view +} + +REQUIRED_KEYS = {"img", "data_norm_type"} + +# Define conflicting keys that cannot be used together +CONFLICTING_KEYS = [ + ("intrinsics", "ray_directions") # Both represent camera projection +] + + +def loss_of_one_batch_multi_view( + batch, + model, + criterion, + device, + use_amp=False, + amp_dtype="bf16", + ret=None, + ignore_keys=None, +): + """ + Calculate loss for a batch with multiple views. + + Args: + batch (list): List of view dictionaries containing input data. + model (torch.nn.Module): Model to run inference with. + criterion (callable, optional): Loss function to compute the loss. + device (torch.device): Device to run the computation on. + use_amp (bool, optional): Whether to use automatic mixed precision. Defaults to False. + amp_dtype (str, optional): Floating point type to use for automatic mixed precision. Options: ["fp32", "fp16", "bf16"]. Defaults to "bf16". + ret (str, optional): If provided, return only the specified key from the result dictionary. + ignore_keys (set, optional): Set of keys to ignore when moving tensors to device. + Defaults to {"dataset", "label", "instance", + "idx", "true_shape", "rng", "data_norm_type"}. + + Returns: + dict or Any: If ret is None, returns a dictionary containing views, predictions, and loss. + Otherwise, returns the value associated with the ret key. 
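# --- Editorial usage sketch (not part of the original patch) ----------------
# Run a forward pass without computing a loss by passing criterion=None.
# `model` and `batch` (a list of preprocessed view dicts) are assumed to come
# from the surrounding codebase and are not constructed here.
import torch

from mapanything.utils.inference import loss_of_one_batch_multi_view

device = torch.device("cuda")
result = loss_of_one_batch_multi_view(
    batch, model, criterion=None, device=device, use_amp=True, amp_dtype="bf16"
)
preds = [result[f"pred{i + 1}"] for i in range(len(batch))]  # one prediction dict per view
assert result["loss"] is None                                # no criterion was supplied
# -----------------------------------------------------------------------------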
+ """ + # Move necessary tensors to device + if ignore_keys is None: + ignore_keys = set( + [ + "depthmap", + "dataset", + "label", + "instance", + "idx", + "true_shape", + "rng", + "data_norm_type", + ] + ) + for view in batch: + for name in view.keys(): + if name in ignore_keys: + continue + view[name] = view[name].to(device, non_blocking=True) + + # Determine the mixed precision floating point type + if use_amp: + if amp_dtype == "fp16": + amp_dtype = torch.float16 + elif amp_dtype == "bf16": + if torch.cuda.is_bf16_supported(): + amp_dtype = torch.bfloat16 + else: + warnings.warn( + "bf16 is not supported on this device. Using fp16 instead." + ) + amp_dtype = torch.float16 + elif amp_dtype == "fp32": + amp_dtype = torch.float32 + else: + amp_dtype = torch.float32 + + # Run model and compute loss + with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype): + preds = model(batch) + with torch.autocast("cuda", enabled=False): + loss = criterion(batch, preds) if criterion is not None else None + + result = {f"view{i + 1}": view for i, view in enumerate(batch)} + result.update({f"pred{i + 1}": pred for i, pred in enumerate(preds)}) + result["loss"] = loss + + return result[ret] if ret else result + + +def validate_input_views_for_inference( + views: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """ + Strict validation and preprocessing of input views. + + Args: + views: List of view dictionaries + + Returns: + Validated and preprocessed views + + Raises: + ValueError: For invalid keys, missing required keys, conflicting inputs, or invalid camera pose constraints + """ + # Ensure input is not empty + if not views: + raise ValueError("At least one view must be provided") + + # Track which views have camera poses + views_with_poses = [] + + # Validate each view + for view_idx, view in enumerate(views): + # Check for invalid keys + provided_keys = set(view.keys()) + invalid_keys = provided_keys - ALLOWED_VIEW_KEYS + if invalid_keys: + raise ValueError( + f"View {view_idx} contains invalid keys: {invalid_keys}. " + f"Allowed keys are: {sorted(ALLOWED_VIEW_KEYS)}" + ) + + # Check for missing required keys + missing_keys = REQUIRED_KEYS - provided_keys + if missing_keys: + raise ValueError(f"View {view_idx} missing required keys: {missing_keys}") + + # Check for conflicting keys + for conflict_set in CONFLICTING_KEYS: + present_conflicts = [key for key in conflict_set if key in provided_keys] + if len(present_conflicts) > 1: + raise ValueError( + f"View {view_idx} contains conflicting keys: {present_conflicts}. " + f"Only one of {conflict_set} can be provided at a time." + ) + + # Check depth constraint: If depth is provided, intrinsics or ray_directions must also be provided + if "depth_z" in provided_keys: + if ( + "intrinsics" not in provided_keys + and "ray_directions" not in provided_keys + ): + raise ValueError( + f"View {view_idx} depth constraint violation: If 'depth_z' is provided, " + f"then 'intrinsics' or 'ray_directions' must also be provided. " + f"Z Depth values require camera calibration information to be meaningful for an image." + ) + + # Track views with camera poses + if "camera_poses" in provided_keys: + views_with_poses.append(view_idx) + + # Cross-view constraint: If any view has camera_poses, view 0 must have them too + if views_with_poses and 0 not in views_with_poses: + raise ValueError( + f"Camera pose constraint violation: Views {views_with_poses} have camera_poses, " + f"but view 0 (reference view) does not. 
When using camera_poses, the first view " + f"must also provide camera_poses to serve as the reference frame." + ) + + return views + + +def preprocess_input_views_for_inference( + views: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """ + Pre-process input views to match the expected internal input format. + + The following steps are performed: + 1. Convert intrinsics to ray directions when required. If ray directions are already provided, unit normalize them. + 2. Convert depth_z to depth_along_ray + 3. Convert camera_poses to the expected input keys (camera_pose_quats and camera_pose_trans) + 4. Default is_metric_scale to True when not provided + + Args: + views: List of view dictionaries + + Returns: + Preprocessed views with consistent internal format + """ + processed_views = [] + + for view_idx, view in enumerate(views): + # Copy the view dictionary to avoid modifying the original input + processed_view = dict(view) + + # Step 1: Convert intrinsics to ray_directions when required. If ray_directions are provided, unit normalize them. + if "intrinsics" in view: + images = view["img"] + height, width = images.shape[-2:] + intrinsics = view["intrinsics"] + _, ray_directions = get_rays_in_camera_frame( + intrinsics=intrinsics, + height=height, + width=width, + normalize_to_unit_sphere=True, + ) + processed_view["ray_directions"] = ray_directions + del processed_view["intrinsics"] + elif "ray_directions" in view: + ray_directions = view["ray_directions"] + ray_norm = torch.norm(ray_directions, dim=-1, keepdim=True) + processed_view["ray_directions"] = ray_directions / (ray_norm + 1e-8) + + # Step 2: Convert depth_z to depth_along_ray + if "depth_z" in view: + depth_z = view["depth_z"] + ray_directions = processed_view["ray_directions"] + ray_directions_unit_plane = ray_directions / ray_directions[..., 2:3] + pts3d_cam = depth_z[..., None] * ray_directions_unit_plane + depth_along_ray = torch.norm(pts3d_cam, dim=-1, keepdim=True) + processed_view["depth_along_ray"] = depth_along_ray + del processed_view["depth_z"] + + # Step 3: Convert camera_poses to expected input keys + if "camera_poses" in view: + camera_poses = view["camera_poses"] + if isinstance(camera_poses, tuple) and len(camera_poses) == 2: + quats, trans = camera_poses + processed_view["camera_pose_quats"] = quats + processed_view["camera_pose_trans"] = trans + elif torch.is_tensor(camera_poses) and camera_poses.shape[-2:] == (4, 4): + rotation_matrices = camera_poses[:, :3, :3] + translation_vectors = camera_poses[:, :3, 3] + quats = rotation_matrix_to_quaternion(rotation_matrices) + processed_view["camera_pose_quats"] = quats + processed_view["camera_pose_trans"] = translation_vectors + else: + raise ValueError( + f"View {view_idx}: camera_poses must be either a tuple of (quats, trans) " + f"or a tensor of (B, 4, 4) transformation matrices." 
+ ) + del processed_view["camera_poses"] + + # Step 4: Default is_metric_scale to True when not provided + if "is_metric_scale" not in processed_view: + # Get batch size from the image tensor + batch_size = view["img"].shape[0] + # Default to True for all samples in the batch + processed_view["is_metric_scale"] = torch.ones( + batch_size, dtype=torch.bool, device=view["img"].device + ) + + # Rename keys to match expected model input format + if "ray_directions" in processed_view: + processed_view["ray_directions_cam"] = processed_view["ray_directions"] + del processed_view["ray_directions"] + + # Append the processed view to the list + processed_views.append(processed_view) + + return processed_views + + +def postprocess_model_outputs_for_inference( + raw_outputs: List[Dict[str, torch.Tensor]], + input_views: List[Dict[str, Any]], + apply_mask: bool = True, + mask_edges: bool = True, + edge_normal_threshold: float = 5.0, + edge_depth_threshold: float = 0.03, + apply_confidence_mask: bool = False, + confidence_percentile: float = 10, +) -> List[Dict[str, torch.Tensor]]: + """ + Post-process raw model outputs by copying raw outputs and adding essential derived fields. + + This function simplifies the raw model outputs by: + 1. Copying all raw outputs as-is + 2. Adding denormalized images (img_no_norm) + 3. Adding Z depth (depth_z) from camera frame points + 4. Recovering pinhole camera intrinsics from ray directions + 5. Adding camera pose matrices (camera_poses) if pose data is available + 6. Applying mask to dense geometry outputs if requested (supports edge masking and confidence masking) + + Args: + raw_outputs: List of raw model output dictionaries, one per view + input_views: List of original input view dictionaries, one per view + apply_mask: Whether to apply non-ambiguous mask to dense outputs. Defaults to True. + mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True. + apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False. + confidence_percentile: The percentile to use for the confidence threshold. Defaults to 10. + + Returns: + List of processed output dictionaries containing: + - All original raw outputs (after masking dense geometry outputs if requested) + - 'img_no_norm': Denormalized RGB images (B, H, W, 3) + - 'depth_z': Z depth from camera frame (B, H, W, 1) if points in camera frame available + - 'intrinsics': Recovered pinhole camera intrinsics (B, 3, 3) if ray directions available + - 'camera_poses': 4x4 pose matrices (B, 4, 4) if pose data available + - 'mask': comprehensive mask for dense geometry outputs (B, H, W, 1) if requested + + """ + processed_outputs = [] + + for view_idx, (raw_output, original_view) in enumerate( + zip(raw_outputs, input_views) + ): + # Start by copying all raw outputs + processed_output = dict(raw_output) + + # 1. Add denormalized images + img = original_view["img"] # Shape: (B, 3, H, W) + data_norm_type = original_view["data_norm_type"][0] + img_hwc = rgb(img, data_norm_type) + + # Convert numpy back to torch if needed (rgb returns numpy) + if isinstance(img_hwc, np.ndarray): + img_hwc = torch.from_numpy(img_hwc).to(img.device) + + processed_output["img_no_norm"] = img_hwc + + # 2. Add Z depth if we have camera frame points + if "pts3d_cam" in processed_output: + processed_output["depth_z"] = processed_output["pts3d_cam"][..., 2:3] + + # 3. 
Recover pinhole camera intrinsics from ray directions if available + if "ray_directions" in processed_output: + intrinsics = recover_pinhole_intrinsics_from_ray_directions( + processed_output["ray_directions"] + ) + processed_output["intrinsics"] = intrinsics + + # 4. Add camera pose matrices if both translation and quaternions are available + if "cam_trans" in processed_output and "cam_quats" in processed_output: + cam_trans = processed_output["cam_trans"] # (B, 3) + cam_quats = processed_output["cam_quats"] # (B, 4) + batch_size = cam_trans.shape[0] + + # Convert quaternions to rotation matrices + rotation_matrices = quaternion_to_rotation_matrix(cam_quats) # (B, 3, 3) + + # Create 4x4 pose matrices + pose_matrices = ( + torch.eye(4, device=img.device).unsqueeze(0).repeat(batch_size, 1, 1) + ) + pose_matrices[:, :3, :3] = rotation_matrices + pose_matrices[:, :3, 3] = cam_trans + + processed_output["camera_poses"] = pose_matrices # (B, 4, 4) + + # 5. Apply comprehensive mask to dense geometry outputs if requested + if apply_mask: + final_mask = None + + # Start with non-ambiguous mask if available + if "non_ambiguous_mask" in processed_output: + non_ambiguous_mask = ( + processed_output["non_ambiguous_mask"].cpu().numpy() + ) # (B, H, W) + final_mask = non_ambiguous_mask + + # Apply confidence mask if requested and available + if apply_confidence_mask and "conf" in processed_output: + confidences = processed_output["conf"].cpu() # (B, H, W) + # Compute percentile threshold for each batch element + batch_size = confidences.shape[0] + conf_mask = torch.zeros_like(confidences, dtype=torch.bool) + percentile_threshold = ( + torch.quantile( + confidences.reshape(batch_size, -1), + confidence_percentile / 100.0, + dim=1, + ) + .unsqueeze(-1) + .unsqueeze(-1) + ) # Shape: (B, 1, 1) + + # Compute mask for each batch element + conf_mask = confidences > percentile_threshold + conf_mask = conf_mask.numpy() + + if final_mask is not None: + final_mask = final_mask & conf_mask + else: + final_mask = conf_mask + + # Apply edge mask if requested and we have the required data + if mask_edges and final_mask is not None and "pts3d" in processed_output: + # Get 3D points for edge computation + pred_pts3d = processed_output["pts3d"].cpu().numpy() # (B, H, W, 3) + batch_size, height, width = final_mask.shape + + edge_masks = [] + for b in range(batch_size): + batch_final_mask = final_mask[b] # (H, W) + batch_pts3d = pred_pts3d[b] # (H, W, 3) + + if batch_final_mask.any(): # Only compute if we have valid points + # Compute normals and normal-based edge mask + normals, normals_mask = points_to_normals( + batch_pts3d, mask=batch_final_mask + ) + normal_edges = normals_edge( + normals, tol=edge_normal_threshold, mask=normals_mask + ) + + # Compute depth-based edge mask + depth_z = ( + processed_output["depth_z"][b].squeeze(-1).cpu().numpy() + ) + depth_edges = depth_edge( + depth_z, rtol=edge_depth_threshold, mask=batch_final_mask + ) + + # Combine both edge types + edge_mask = ~(depth_edges & normal_edges) + edge_masks.append(edge_mask) + else: + # No valid points, keep all as invalid + edge_masks.append(np.zeros_like(batch_final_mask, dtype=bool)) + + # Stack batch edge masks and combine with final mask + edge_mask = np.stack(edge_masks, axis=0) # (B, H, W) + final_mask = final_mask & edge_mask + + # Apply final mask to dense geometry outputs if we have a mask + if final_mask is not None: + # Convert mask to torch tensor + final_mask_torch = torch.from_numpy(final_mask).to( + processed_output["pts3d"].device + ) 
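+ # Add a trailing channel dimension so the (B, H, W) boolean mask broadcasts
+ # against the (B, H, W, C) dense geometry tensors that get zeroed out below.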
+ final_mask_torch = final_mask_torch.unsqueeze(-1) # (B, H, W, 1) + + # Apply mask to dense geometry outputs (zero out invalid regions) + dense_geometry_keys = [ + "pts3d", + "pts3d_cam", + "depth_along_ray", + "depth_z", + ] + for key in dense_geometry_keys: + if key in processed_output: + processed_output[key] = processed_output[key] * final_mask_torch + + # Add mask to processed output + processed_output["mask"] = final_mask_torch + + processed_outputs.append(processed_output) + + return processed_outputs diff --git a/mapanything/utils/metrics.py b/mapanything/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..250cc8fec571b4e2fce15381bb96e0b3527108ff --- /dev/null +++ b/mapanything/utils/metrics.py @@ -0,0 +1,509 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Utils for Metrics +Source for Pose AUC Metrics: VGGT +""" + +import math + +import numpy as np +import torch +import torch.nn.functional as F + + +def l2_distance_of_unit_quats_to_angular_error(l2_distance): + """ + Converts a given L2 distance (for unit quaternions) to the angular error in degrees. + For two quaternions differing by an angle θ the relationship is: + L2 distance = 2 * sin(θ/4) + Hence, the angular error in degrees is computed as: + 4 * asin(l2_distance / 2) * (180/π) + + Args: + l2_distance: L2 distance between two unit quaternions (torch.Tensor, shape: (N,)) + Returns: + angular_error_degrees: Angular error in degrees (torch.Tensor, shape: (N,)) + """ + angular_error_radians = 4 * torch.asin(l2_distance / 2) + angular_error_degrees = angular_error_radians * 180.0 / math.pi + + return angular_error_degrees + + +def l2_distance_of_unit_ray_directions_to_angular_error(l2_distance): + """ + Converts a given L2 distance (for unit ray directions) to the angular error in degrees. + For two unit ray directions differing by an angle θ the relationship is: + L2 distance = 2 * sin(θ/2) + Hence, the angular error in degrees is computed as: + 2 * asin(l2_distance / 2) * (180/π) + + Args: + l2_distance: L2 distance between two unit ray directions (torch.Tensor, shape: (N,)) + Returns: + angular_error_degrees: Angular error in degrees (torch.Tensor, shape: (N,)) + """ + angular_error_radians = 2 * torch.asin(l2_distance / 2) + angular_error_degrees = angular_error_radians * 180.0 / math.pi + + return angular_error_degrees + + +def valid_mean(arr, mask, axis=None, keepdims=np._NoValue): + """Compute mean of elements across given dimensions of an array, considering only valid elements. + + Args: + arr: The array to compute the mean. + mask: Array with numerical or boolean values for element weights or validity. For bool, False means invalid. + axis: Dimensions to reduce. + keepdims: If true, retains reduced dimensions with length 1. + + Returns: + Mean array/scalar and a valid array/scalar that indicates where the mean could be computed successfully. 
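+
+ Example (illustrative):
+     arr = np.array([1.0, 2.0, 3.0, 4.0])
+     mask = np.array([True, True, False, True])
+     mean, ok = valid_mean(arr, mask)  # mean ~= 2.333 over the 3 valid entries, ok is True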
+ """ + + mask = mask.astype(arr.dtype) if mask.dtype == bool else mask + num_valid = np.sum(mask, axis=axis, keepdims=keepdims) + masked_arr = arr * mask + masked_arr_sum = np.sum(masked_arr, axis=axis, keepdims=keepdims) + + with np.errstate(divide="ignore", invalid="ignore"): + valid_mean = masked_arr_sum / num_valid + is_valid = np.isfinite(valid_mean) + valid_mean = np.nan_to_num(valid_mean, nan=0, posinf=0, neginf=0) + + return valid_mean, is_valid + + +def thresh_inliers(gt, pred, thresh=1.03, mask=None, output_scaling_factor=1.0): + """Computes the inlier (=error within a threshold) ratio for a predicted and ground truth dense map of size H x W x C. + + Args: + gt: Ground truth depth map as numpy array of shape HxW. Negative or 0 values are invalid and ignored. + pred: Predicted depth map as numpy array of shape HxW. + thresh: Threshold for the relative difference between the prediction and ground truth. Default: 1.03 + mask: Array of shape HxW with boolean values to indicate validity. For bool, False means invalid. Default: None + output_scaling_factor: Scaling factor that is applied after computing the metrics (e.g. to get [%]). Default: 1 + + Returns: + Scalar that indicates the inlier ratio. Scalar is np.nan if the result is invalid. + """ + # Compute the norms + gt_norm = np.linalg.norm(gt, axis=-1) + pred_norm = np.linalg.norm(pred, axis=-1) + + gt_norm_valid = (gt_norm) > 0 + if mask is not None: + combined_mask = mask & gt_norm_valid + else: + combined_mask = gt_norm_valid + + with np.errstate(divide="ignore", invalid="ignore"): + rel_1 = np.nan_to_num( + gt_norm / pred_norm, nan=thresh + 1, posinf=thresh + 1, neginf=thresh + 1 + ) # pred=0 should be an outlier + rel_2 = np.nan_to_num( + pred_norm / gt_norm, nan=0, posinf=0, neginf=0 + ) # gt=0 is masked out anyways + + max_rel = np.maximum(rel_1, rel_2) + inliers = ((0 < max_rel) & (max_rel < thresh)).astype( + np.float32 + ) # 1 for inliers, 0 for outliers + + inlier_ratio, valid = valid_mean(inliers, combined_mask) + + inlier_ratio = inlier_ratio * output_scaling_factor + inlier_ratio = inlier_ratio if valid else np.nan + + return inlier_ratio + + +def m_rel_ae(gt, pred, mask=None, output_scaling_factor=1.0): + """Computes the mean-relative-absolute-error for a predicted and ground truth dense map of size HxWxC. + + Args: + gt: Ground truth map as numpy array of shape H x W x C. + pred: Predicted map as numpy array of shape H x W x C. + mask: Array of shape HxW with boolean values to indicate validity. For bool, False means invalid. Default: None + output_scaling_factor: Scaling factor that is applied after computing the metrics (e.g. to get [%]). Default: 1 + + Returns: + Scalar that indicates the mean-relative-absolute-error. Scalar is np.nan if the result is invalid. + """ + error_norm = np.linalg.norm(pred - gt, axis=-1) + gt_norm = np.linalg.norm(gt, axis=-1) + + gt_norm_valid = (gt_norm) > 0 + if mask is not None: + combined_mask = mask & gt_norm_valid + else: + combined_mask = gt_norm_valid + + with np.errstate(divide="ignore", invalid="ignore"): + rel_ae = np.nan_to_num(error_norm / gt_norm, nan=0, posinf=0, neginf=0) + + m_rel_ae, valid = valid_mean(rel_ae, combined_mask) + + m_rel_ae = m_rel_ae * output_scaling_factor + m_rel_ae = m_rel_ae if valid else np.nan + + return m_rel_ae + + +def align(model, data): + """Align two trajectories using the method of Horn (closed-form). 
+ + Args: + model -- first trajectory (3xn) + data -- second trajectory (3xn) + + Returns: + rot -- rotation matrix (3x3) + trans -- translation vector (3x1) + trans_error -- translational error per point (1xn) + + """ + np.set_printoptions(precision=3, suppress=True) + model_zerocentered = model - model.mean(1).reshape((3, -1)) + data_zerocentered = data - data.mean(1).reshape((3, -1)) + + W = np.zeros((3, 3)) + for column in range(model.shape[1]): + W += np.outer(model_zerocentered[:, column], data_zerocentered[:, column]) + U, d, Vh = np.linalg.linalg.svd(W.transpose()) + S = np.matrix(np.identity(3)) + if np.linalg.det(U) * np.linalg.det(Vh) < 0: + S[2, 2] = -1 + rot = U * S * Vh + trans = data.mean(1).reshape((3, -1)) - rot * model.mean(1).reshape((3, -1)) + + model_aligned = rot * model + trans + alignment_error = model_aligned - data + + trans_error = np.sqrt(np.sum(np.multiply(alignment_error, alignment_error), 0)).A[0] + + return rot, trans, trans_error + + +def evaluate_ate(gt_traj, est_traj): + """ + Input : + gt_traj: list of 4x4 matrices + est_traj: list of 4x4 matrices + len(gt_traj) == len(est_traj) + """ + gt_traj_pts = [gt_traj[idx][:3, 3] for idx in range(len(gt_traj))] + est_traj_pts = [est_traj[idx][:3, 3] for idx in range(len(est_traj))] + + gt_traj_pts = torch.stack(gt_traj_pts).detach().cpu().numpy().T + est_traj_pts = torch.stack(est_traj_pts).detach().cpu().numpy().T + + _, _, trans_error = align(gt_traj_pts, est_traj_pts) + + avg_trans_error = trans_error.mean() + + return avg_trans_error + + +def build_pair_index(N, B=1): + """ + Build indices for all possible pairs of frames. + + Args: + N: Number of frames + B: Batch size + + Returns: + i1, i2: Indices for all possible pairs + """ + i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1) + i1, i2 = [(i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_]] + return i1, i2 + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. 
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) # pylint: disable=not-callable + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) + + +def rotation_angle(rot_gt, rot_pred, batch_size=None, eps=1e-15): + """ + Calculate rotation angle error between ground truth and predicted rotations. + + Args: + rot_gt: Ground truth rotation matrices + rot_pred: Predicted rotation matrices + batch_size: Batch size for reshaping the result + eps: Small value to avoid numerical issues + + Returns: + Rotation angle error in degrees + """ + q_pred = mat_to_quat(rot_pred) + q_gt = mat_to_quat(rot_gt) + + loss_q = (1 - (q_pred * q_gt).sum(dim=1) ** 2).clamp(min=eps) + err_q = torch.arccos(1 - 2 * loss_q) + + rel_rangle_deg = err_q * 180 / np.pi + + if batch_size is not None: + rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1) + + return rel_rangle_deg + + +def translation_angle(tvec_gt, tvec_pred, batch_size=None, ambiguity=True): + """ + Calculate translation angle error between ground truth and predicted translations. + + Args: + tvec_gt: Ground truth translation vectors + tvec_pred: Predicted translation vectors + batch_size: Batch size for reshaping the result + ambiguity: Whether to handle direction ambiguity + + Returns: + Translation angle error in degrees + """ + rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred) + rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi + + if ambiguity: + rel_tangle_deg = torch.min(rel_tangle_deg, (180 - rel_tangle_deg).abs()) + + if batch_size is not None: + rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1) + + return rel_tangle_deg + + +def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6): + """ + Normalize the translation vectors and compute the angle between them. 
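+ The returned error is arccos(|cos(theta)|) (up to the eps clamp), so it lies in
+ [0, pi/2] radians and is insensitive to the sign of the direction vectors.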
+ + Args: + t_gt: Ground truth translation vectors + t: Predicted translation vectors + eps: Small value to avoid division by zero + default_err: Default error value for invalid cases + + Returns: + Angular error between translation vectors in radians + """ + t_norm = torch.norm(t, dim=1, keepdim=True) + t = t / (t_norm + eps) + + t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True) + t_gt = t_gt / (t_gt_norm + eps) + + loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps) + err_t = torch.acos(torch.sqrt(1 - loss_t)) + + err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err + return err_t + + +def calculate_auc_np(r_error, t_error, max_threshold=30): + """ + Calculate the Area Under the Curve (AUC) for the given error arrays using NumPy. + + Args: + r_error: numpy array representing R error values (Degree) + t_error: numpy array representing T error values (Degree) + max_threshold: Maximum threshold value for binning the histogram + + Returns: + AUC value and the normalized histogram + """ + error_matrix = np.concatenate((r_error[:, None], t_error[:, None]), axis=1) + max_errors = np.max(error_matrix, axis=1) + bins = np.arange(max_threshold + 1) + histogram, _ = np.histogram(max_errors, bins=bins) + num_pairs = float(len(max_errors)) + normalized_histogram = histogram.astype(float) / num_pairs + return np.mean(np.cumsum(normalized_histogram)), normalized_histogram + + +def closed_form_inverse_se3(se3, R=None, T=None): + """ + Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. + + If `R` and `T` are provided, they must correspond to the rotation and translation + components of `se3`. Otherwise, they will be extracted from `se3`. + + Args: + se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. + R (optional): Nx3x3 array or tensor of rotation matrices. + T (optional): Nx3x1 array or tensor of translation vectors. + + Returns: + Inverted SE3 matrices with the same type and device as `se3`. + + Shapes: + se3: (N, 4, 4) + R: (N, 3, 3) + T: (N, 3, 1) + """ + # Check if se3 is a numpy array or a torch tensor + is_numpy = isinstance(se3, np.ndarray) + + # Validate shapes + if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): + raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") + + # Extract R and T if not provided + if R is None: + R = se3[:, :3, :3] # (N,3,3) + if T is None: + T = se3[:, :3, 3:] # (N,3,1) + + # Transpose R + if is_numpy: + # Compute the transpose of the rotation for NumPy + R_transposed = np.transpose(R, (0, 2, 1)) + # -R^T t for NumPy + top_right = -np.matmul(R_transposed, T) + inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) + else: + R_transposed = R.transpose(1, 2) # (N,3,3) + top_right = -torch.bmm(R_transposed, T) # (N,3,1) + inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) + inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) + + inverted_matrix[:, :3, :3] = R_transposed + inverted_matrix[:, :3, 3:] = top_right + + return inverted_matrix + + +def se3_to_relative_pose_error(pred_se3, gt_se3, num_frames): + """ + Compute rotation and translation errors between predicted and ground truth poses. 
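+ Errors are computed over all unique frame pairs using relative poses, making the
+ comparison independent of the global reference frame of the trajectories.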
+ + Args: + pred_se3: Predicted SE(3) transformations + gt_se3: Ground truth SE(3) transformations + num_frames: Number of frames + + Returns: + Rotation and translation angle errors in degrees + """ + pair_idx_i1, pair_idx_i2 = build_pair_index(num_frames) + + # Compute relative camera poses between pairs + # We use closed_form_inverse to avoid potential numerical loss by torch.inverse() + relative_pose_gt = closed_form_inverse_se3(gt_se3[pair_idx_i1]).bmm( + gt_se3[pair_idx_i2] + ) + relative_pose_pred = closed_form_inverse_se3(pred_se3[pair_idx_i1]).bmm( + pred_se3[pair_idx_i2] + ) + + # Compute the difference in rotation and translation + rel_rangle_deg = rotation_angle( + relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3] + ) + rel_tangle_deg = translation_angle( + relative_pose_gt[:, :3, 3], relative_pose_pred[:, :3, 3] + ) + + return rel_rangle_deg, rel_tangle_deg diff --git a/mapanything/utils/misc.py b/mapanything/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..e99b15d72af70886a0ab29d09f791ea7b1d04b1f --- /dev/null +++ b/mapanything/utils/misc.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Miscellaneous utility functions. +""" + +import logging +import os +import random + +import numpy as np +import torch + + +class StreamToLogger: + """ + A class that redirects stream writes to a logger. + + This class can be used to redirect stdout or stderr to a logger + by implementing a file-like interface with write and flush methods. + + Parameters: + - logger: A logger instance that will receive the log messages + - log_level: The logging level to use (default: logging.INFO) + """ + + def __init__(self, logger, log_level=logging.INFO): + self.logger = logger + self.log_level = log_level + self.linebuf = "" + + def write(self, buf): + """ + Write the buffer content to the logger. + + Parameters: + - buf: The string buffer to write + """ + for line in buf.rstrip().splitlines(): + self.logger.log(self.log_level, line.rstrip()) + + def flush(self): + """ + Flush method to comply with file-like object interface. + This method is required but does nothing in this implementation. + """ + pass + + +def seed_everything(seed: int = 42): + """ + Set the `seed` value for torch and numpy seeds. Also turns on + deterministic execution for cudnn. + + Parameters: + - seed: A hashable seed value + """ + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + print(f"Seed set to: {seed}") + + +def invalid_to_nans(arr, valid_mask, ndim=999): + """ + Replace invalid values in an array with NaN values based on a validity mask. + + Parameters: + - arr: Input array (typically a PyTorch tensor) + - valid_mask: Boolean mask indicating valid elements (True) and invalid elements (False) + - ndim: Maximum number of dimensions to keep; flattens dimensions if arr.ndim > ndim + + Returns: + - Modified array with invalid values replaced by NaN + """ + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = float("nan") + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr + + +def invalid_to_zeros(arr, valid_mask, ndim=999): + """ + Replace invalid values in an array with zeros based on a validity mask. 
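+ Unlike invalid_to_nans, this variant also returns the number of valid elements per
+ sample, which is convenient for normalizing sums taken over the masked regions.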
+ + Parameters: + - arr: Input array (typically a PyTorch tensor) + - valid_mask: Boolean mask indicating valid elements (True) and invalid elements (False) + - ndim: Maximum number of dimensions to keep; flattens dimensions if arr.ndim > ndim + + Returns: + - Tuple containing: + - Modified array with invalid values replaced by zeros + - nnz: Number of non-zero (valid) elements per sample in the batch + """ + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = 0 + nnz = valid_mask.view(len(valid_mask), -1).sum(1) + else: + nnz = ( + arr[..., 0].numel() // len(arr) if len(arr) else 0 + ) # Number of pixels per image + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr, nnz diff --git a/mapanything/utils/parallel.py b/mapanything/utils/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..7738fc44ec4e381dbcc97786e2428e01045d5553 --- /dev/null +++ b/mapanything/utils/parallel.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Utility functions for multiprocessing +""" + +import os +from multiprocessing.dummy import Pool as ThreadPool + +import torch +from torch.multiprocessing import Pool as TorchPool, set_start_method +from tqdm import tqdm + + +def cpu_count(): + """ + Returns the number of available CPUs for the python process + """ + return len(os.sched_getaffinity(0)) + + +def parallel_threads( + function, + args, + workers=0, + star_args=False, + kw_args=False, + front_num=1, + Pool=ThreadPool, + ordered_res=True, + **tqdm_kw, +): + """tqdm but with parallel execution. + + Will essentially return + res = [ function(arg) # default + function(*arg) # if star_args is True + function(**arg) # if kw_args is True + for arg in args] + + Note: + the first elements of args will not be parallelized. + This can be useful for debugging. + """ + # Determine the number of workers + while workers <= 0: + workers += cpu_count() + + # Convert args to an iterable + try: + n_args_parallel = len(args) - front_num + except TypeError: + n_args_parallel = None + args = iter(args) + + # Sequential execution for the first few elements (useful for debugging) + front = [] + while len(front) < front_num: + try: + a = next(args) + except StopIteration: + return front # end of the iterable + front.append( + function(*a) if star_args else function(**a) if kw_args else function(a) + ) + + # Parallel execution using multiprocessing.dummy + out = [] + with Pool(workers) as pool: + if star_args: + map_func = pool.imap if ordered_res else pool.imap_unordered + futures = map_func(starcall, [(function, a) for a in args]) + elif kw_args: + map_func = pool.imap if ordered_res else pool.imap_unordered + futures = map_func(starstarcall, [(function, a) for a in args]) + else: + map_func = pool.imap if ordered_res else pool.imap_unordered + futures = map_func(function, args) + # Track progress with tqdm + for f in tqdm(futures, total=n_args_parallel, **tqdm_kw): + out.append(f) + return front + out + + +def cuda_parallel_threads( + function, + args, + workers=0, + star_args=False, + kw_args=False, + front_num=1, + Pool=TorchPool, + ordered_res=True, + **tqdm_kw, +): + """ + Parallel execution of a function using torch.multiprocessing with CUDA support. + This is the CUDA variant of the parallel_threads function. 
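+ It forces the "spawn" start method and uses torch.multiprocessing.Pool so that
+ CUDA can be used safely inside the worker processes.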
+ """ + # Set the start method for multiprocessing + set_start_method("spawn", force=True) + + # Determine the number of workers + while workers <= 0: + workers += torch.multiprocessing.cpu_count() + + # Convert args to an iterable + try: + n_args_parallel = len(args) - front_num + except TypeError: + n_args_parallel = None + args = iter(args) + + # Sequential execution for the first few elements (useful for debugging) + front = [] + while len(front) < front_num: + try: + a = next(args) + except StopIteration: + return front # End of the iterable + front.append( + function(*a) if star_args else function(**a) if kw_args else function(a) + ) + + # Parallel execution using torch.multiprocessing + out = [] + with Pool(workers) as pool: + if star_args: + map_func = pool.imap if ordered_res else pool.imap_unordered + futures = map_func(starcall, [(function, a) for a in args]) + elif kw_args: + map_func = pool.imap if ordered_res else pool.imap_unordered + futures = map_func(starstarcall, [(function, a) for a in args]) + else: + map_func = pool.imap if ordered_res else pool.imap_unordered + futures = map_func(function, args) + # Track progress with tqdm + for f in tqdm(futures, total=n_args_parallel, **tqdm_kw): + out.append(f) + return front + out + + +def parallel_processes(*args, **kwargs): + """Same as parallel_threads, with processes""" + import multiprocessing as mp + + kwargs["Pool"] = mp.Pool + return parallel_threads(*args, **kwargs) + + +def starcall(args): + """convenient wrapper for Process.Pool""" + function, args = args + return function(*args) + + +def starstarcall(args): + """convenient wrapper for Process.Pool""" + function, args = args + return function(**args) diff --git a/mapanything/utils/timing.py b/mapanything/utils/timing.py new file mode 100644 index 0000000000000000000000000000000000000000..078f587fcf612361584b24a3aeadab8f4446c2da --- /dev/null +++ b/mapanything/utils/timing.py @@ -0,0 +1,309 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Utility functions for timing code blocks +""" + +import time +from contextlib import ContextDecorator + +import numpy as np + + +class BlockTimeManager: + """ + Manages a collection of timers and their formatting options. + + This class serves as a central registry for Timer objects, allowing them to be + accessed by name and maintaining their formatting preferences. + + Attributes: + timers (dict): Dictionary mapping timer names to Timer objects + timer_fmts (dict): Dictionary mapping timer names to their display formats + window_size (int): Default window size for calculating windowed averages + buf_size (int): Default buffer size for storing timing measurements + """ + + def __init__(self, window_size=10, buf_size=100000): + self.timers = dict() + self.timer_fmts = dict() + self.window_size = window_size + self.buf_size = buf_size + + +btm = BlockTimeManager(window_size=100000) + + +class Timer: + """ + Core timing class that tracks execution times. + + This class provides the fundamental timing functionality, storing timing measurements + and calculating various statistics. 
+ + Attributes: + name (str): Identifier for this timer + buf_size (int): Maximum number of timing measurements to store + window_size (int): Number of most recent measurements to use for windowed statistics + measures_arr (numpy.ndarray): Array storing start and end times of measurements + current_start (float or None): Start time of current measurement + current_end (float or None): End time of current measurement + """ + + def __init__(self, name, window_size, buf_size=100000): + self.name = name + self.buf_size = buf_size + self.window_size = window_size + self.init() + + def init(self): + """Initialize or reset the timer's state.""" + self.measures_arr = np.empty((0, 2)) # LIFO + self.current_start = None + self.current_end = None + + def reset(self): + """Reset the timer to its initial state.""" + self.init() + + def tic(self): + """Start a new timing measurement.""" + if self.current_start is not None: + # another tic executed before a toc + self.toc() + self.current_start = time.perf_counter() + + def toc(self): + """End the current timing measurement.""" + self.current_end = time.perf_counter() + self._add_current_measure() + + def _add_current_measure(self): + """Add the current timing measurement to the measurements array.""" + self.measures_arr = np.concatenate( + [ + np.array([[self.current_start, self.current_end]]), + self.measures_arr[: self.buf_size], + ] + ) + self.current_start = None + self.current_end = None + + @property + def avg(self) -> float: + """Calculate the average execution time across all measurements.""" + return np.mean(self.measures_arr[:, 1] - self.measures_arr[:, 0]) + + @property + def wavg(self) -> float: + """Calculate the windowed average execution time using the most recent measurements.""" + return np.mean( + self.measures_arr[: self.window_size, 1] + - self.measures_arr[: self.window_size, 0] + ) + + @property + def max(self) -> float: + """Return the maximum execution time.""" + return np.max(self.measures_arr[:, 1] - self.measures_arr[:, 0]) + + @property + def min(self) -> float: + """Return the minimum execution time.""" + return np.min(self.measures_arr[:, 1] - self.measures_arr[:, 0]) + + @property + def total(self) -> float: + """Return the total execution time across all measurements.""" + return np.sum(self.measures_arr[:, 1] - self.measures_arr[:, 0]) + + @property + def latest(self) -> float: + """Return the most recent execution time.""" + return self.measures_arr[0, 1] - self.measures_arr[0, 0] + + @property + def median(self) -> float: + """Return the median execution time.""" + return np.median(self.measures_arr[:, 1] - self.measures_arr[:, 0]) + + @property + def var(self) -> float: + """Return the variance of execution times.""" + return np.var(self.measures_arr[:, 1] - self.measures_arr[:, 0]) + + +class BlockTimer(ContextDecorator): + """ + A context manager and decorator for timing code blocks. + + This class provides a convenient interface for timing code execution, either as a + context manager (with statement) or as a decorator. It uses the Timer class for + the actual timing functionality. 
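+
+ Example (illustrative; run_forward_pass is a placeholder for your own code):
+     with BlockTimer("forward", "default"):
+         run_forward_pass()
+
+     @BlockTimer("step", "default")
+     def train_step():
+         ...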
+ + Attributes: + name (str): Identifier for this timer + fmt (str or None): Format string for displaying timing information + timer (Timer): The underlying Timer object + num_calls (int): Number of times this timer has been called + """ + + @staticmethod + def timers(): + """Return a list of all registered timer names.""" + return list(btm.timers.keys()) + + def __init__(self, name, fmt=None, window_size=100): + self.name = name + if name in btm.timers: + self.timer = btm.timers[name] + # restore format + self.fmt = fmt if fmt is not None else btm.timer_fmts[name] + else: + self.timer = Timer(name, btm.window_size, btm.buf_size) + btm.timers[name] = self.timer + btm.timer_fmts[name] = fmt + self.timer.window_size = window_size + self._default_fmt = "[{name}] num: {num} latest: {latest:.4f} --wind_avg: {wavg:.4f} -- avg: {avg:.4f} --var: {var:.4f} -- total: {total:.4f}" + if fmt == "default": + self.fmt = self._default_fmt + # extend here for new formats + else: + self.fmt = None + + self.num_calls = 0 + + def __enter__(self) -> "Timer": + """Start timing when entering a context.""" + self.tic() + return self + + def __exit__(self, *args): + """End timing when exiting a context and optionally display results.""" + self.toc() + if self.fmt is not None: + print(str(self)) + + def __str__(self) -> str: + """Return a string representation of the timer.""" + return self.display() + + def reset(self): + """Reset the timer and call counter.""" + self.timer.reset() + self.num_calls = 0 + + def display(self, fmt=None): + """ + Format and return timing information. + + Args: + fmt (str, optional): Format string to use. If None, uses the timer's format. + + Returns: + str: Formatted timing information + """ + if fmt is None: + if self.fmt is not None: + fmt = self.fmt + else: + fmt = self._default_fmt + return fmt.format( + name=self.name, + num=self.num_calls, + latest=self.latest, + wavg=self.wavg, + avg=self.avg, + var=self.var, + total=self.total, + ) + + def tic(self): + """Start a new timing measurement and increment the call counter.""" + self.timer.tic() + self.num_calls += 1 + + def toc(self, display=False): + """ + End the current timing measurement. 
+ + Args: + display (bool): Whether to return a formatted display string + + Returns: + str or None: Formatted timing information if display is True + """ + self.timer.toc() + if display: + return self.display() + + @property + def latest(self) -> float: + """Return the most recent execution time.""" + return self.timer.latest + + @property + def avg(self) -> float: + """Return the average execution time.""" + return self.timer.avg + + @property + def wavg(self) -> float: + """Return the windowed average execution time.""" + return self.timer.wavg + + @property + def max(self) -> float: + """Return the maximum execution time.""" + return self.timer.max + + @property + def min(self) -> float: + """Return the minimum execution time.""" + return self.timer.min + + @property + def total(self) -> float: + """Return the total execution time.""" + return self.timer.total + + @property + def median(self) -> float: + """Return the median execution time.""" + return self.timer.median + + @property + def var(self) -> float: + """Return the variance of execution times.""" + return self.timer.var + + +if __name__ == "__main__": + + @BlockTimer("fct", "default") + def fct(bobo): + time.sleep(0.5) + + fct(2) + + for i in range(10): + with BlockTimer("affe", "default"): + time.sleep(0.1) + for i in range(1000): + with BlockTimer("test", None): + time.sleep(0.001) + + # BlockTimer("test").display = f"""avg: {BlockTimer("test").avg} total: {BlockTimer("test").total}""" + # print(str(BlockTimer("test"))) + + print(BlockTimer("test")) + BlockTimer("test").tic() + BlockTimer("t2", "default").tic() + time.sleep(0.4) + print(BlockTimer("t2").toc(True)) + + time.sleep(0.4) + print(BlockTimer("test").toc(True)) diff --git a/mapanything/utils/train_tools.py b/mapanything/utils/train_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..3f824e5cc17c1bd00724e5825fb18bcb4c334e11 --- /dev/null +++ b/mapanything/utils/train_tools.py @@ -0,0 +1,983 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Utility functions for training deep learning models, particularly focused on distributed training, +metric logging, and gradient handling. + +This module provides tools for: +- Tracking and logging metrics during training +- Setting up distributed training environments +- Handling gradient scaling and normalization +- Managing learning rates and parameter groups +- Saving and loading model checkpoints + +References: CroCo (https://github.com/naver/croco) +""" + +import builtins +import datetime +import json +import math +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import torch +import torch.distributed as dist +from torch import inf + + +class SmoothedValue(object): + """ + Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
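+ Only the running count and total are all-reduced across processes; windowed
+ statistics such as median, avg and max remain process-local.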
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + """ + Logger for tracking and displaying training metrics. + + This class maintains a collection of metrics during training, provides + methods to update them, and formats them for display. It also handles + synchronization of metrics across processes in distributed training. + """ + + def __init__(self, delimiter="\t", print_per_view_stats=False): + """ + Initialize the MetricLogger. + + Args: + delimiter (str, optional): Delimiter for formatting output. Defaults to "\t". + print_per_view_stats (bool, optional): Whether to print per-view statistics. Defaults to False. + """ + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + self.print_per_view_stats = print_per_view_stats + + def update(self, **kwargs): + """ + Update metrics with new values. + + Args: + **kwargs: Key-value pairs where keys are metric names and values are metric values + Values can be tensors or numbers + + Raises: + AssertionError: If a value is not a float or int after conversion from tensor + """ + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + """ + Get a meter by attribute name. + + This allows accessing meters as attributes of the logger. + + Args: + attr (str): Name of the attribute to get + + Returns: + SmoothedValue: The meter corresponding to the attribute name + + Raises: + AttributeError: If the attribute doesn't exist as a meter or regular attribute + """ + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + """ + Format all metrics as a string. + + Returns: + str: Formatted string containing all metrics + """ + loss_str = [] + for name, meter in self.meters.items(): + # Skip printing per-view stats if not enabled + if not self.print_per_view_stats and "view" in name: + continue + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + """ + Synchronize metrics across processes in distributed training. + + This method calls synchronize_between_processes on each meter to + ensure consistent values across all processes. + """ + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + """ + Add a custom meter to the logger. 
+ + Args: + name (str): Name of the meter + meter (SmoothedValue): The meter to add + """ + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None, max_iter=None): + """ + Log metrics at regular intervals while iterating. + + This method wraps an iterable and logs metrics every print_freq iterations. + It also tracks iteration time, data loading time, and memory usage. + + Args: + iterable: Iterable to iterate over (typically a data loader) + print_freq (int): How often to log metrics (in iterations) + header (str, optional): Header string to print before metrics. Defaults to None. + max_iter (int, optional): Maximum number of iterations. Defaults to None. + + Yields: + object: Items from the original iterable + """ + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable) + space_fmt = ":" + str(len(str(len_iterable))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for it, obj in enumerate(iterable): + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len_iterable - 1: + eta_seconds = iter_time.global_avg * (len_iterable - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len_iterable, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len_iterable, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + if max_iter and it >= max_iter: + break + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len_iterable + ) + ) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process. + + It replaces the built-in print function with a custom version that only prints + when the current process is the master process or when explicitly forced. + + Args: + is_master (bool): Whether the current process is the master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + # force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print("[{}] ".format(now), end="") # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + """ + Check if distributed training is available and initialized. + + Returns: + bool: True if distributed training is available and initialized, False otherwise + """ + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + """ + Get the number of processes in the distributed training group. 
+ + Returns: + int: Number of processes in the distributed group, or 1 if not using distributed training + """ + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + """ + Get the rank of the current process in the distributed training group. + + Returns: + int: Rank of the current process, or 0 if not using distributed training + """ + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + """ + Check if the current process is the main process (rank 0). + + Returns: + bool: True if the current process is the main process, False otherwise + """ + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + """ + Save a PyTorch object only on the master process. + + This function is useful in distributed training to avoid multiple processes + trying to save the same file simultaneously. + + Args: + *args: Positional arguments to pass to torch.save() + **kwargs: Keyword arguments to pass to torch.save() + """ + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + """ + Initialize distributed training mode. + + This function sets up the distributed training environment based on environment + variables and command-line arguments. It initializes the process group, + sets the appropriate device, and configures printing for the distributed setup. + + Args: + args: Arguments object containing distributed training configuration. + Expected to have attributes like dist_url, and will be modified + to include rank, world_size, gpu, and distributed flag. + """ + nodist = args.nodist if hasattr(args, "nodist") else False + if "RANK" in os.environ and "WORLD_SIZE" in os.environ and not nodist: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + else: + print("Not using distributed mode") + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}): {}, gpu {}".format( + args.rank, args.dist_url, args.gpu + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + """ + A gradient scaler that handles gradient scaling and norm computation for mixed precision training. + + This class wraps PyTorch's GradScaler to provide additional functionality for gradient norm tracking + and clipping during mixed precision training. + """ + + state_dict_key = "amp_scaler" + + def __init__(self, enabled=True): + """Initialize the scaler. + + Args: + enabled (bool): Whether to enable gradient scaling. Default: True + """ + self._scaler = torch.GradScaler("cuda", enabled=enabled) + + def __call__( + self, + loss, + optimizer, + clip_grad=None, + parameters=None, + create_graph=False, + update_grad=True, + ): + """Scales loss and performs backward pass with optional gradient clipping. + + Args: + loss: The loss to backpropagate + optimizer: The optimizer being used + clip_grad: Max norm for gradient clipping. 
None means no clipping + parameters: Model parameters or list of parameters for gradient norm computation + create_graph: Whether to create graph during backward pass + update_grad: Whether to update gradients + + Returns: + norm: The gradient norm if computed, else None. Returns list of norms if parameters is a list. + """ + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_( + optimizer + ) # unscale the gradients of optimizer's assigned params in-place + if isinstance(parameters, (list, tuple)): + norm = [ + torch.nn.utils.clip_grad_norm_(p, clip_grad) for p in parameters + ] + else: + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + """Returns the state dict of the underlying scaler. + + Returns: + dict: The state dict of the gradient scaler + """ + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + """Loads the state dict into the underlying scaler. + + Args: + state_dict: The state dict to load + """ + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + """ + Calculate the gradient norm of parameters. + + This function computes the norm of gradients for a set of parameters. It can handle + both single parameter groups and multiple parameter groups (list/tuple of parameters). + + Args: + parameters: A tensor or iterable of tensors or iterable of iterables of tensors + containing model parameters for which to compute gradient norms + norm_type (float): Type of norm to use (e.g., 2.0 for L2 norm, inf for infinity norm) + + Returns: + torch.Tensor: The computed gradient norm. If parameters is a list/tuple of parameter + groups, returns a list of norms, one for each group. + """ + if isinstance(parameters, (list, tuple)): + # If parameters is already a list/tuple, process each parameter group + all_norms = [] + for params in parameters: + if isinstance(params, torch.Tensor): + params = [params] + params = [p for p in params if p.grad is not None] + if len(params) > 0: + device = params[0].grad.device + if norm_type == inf: + group_norm = max( + p.grad.detach().abs().max().to(device) for p in params + ) + else: + group_norm = torch.norm( + torch.stack( + [ + torch.norm(p.grad.detach(), norm_type).to(device) + for p in params + ] + ), + norm_type, + ) + else: + group_norm = torch.tensor(0.0) + all_norms.append(group_norm) + return all_norms + + # Original logic for single parameter group + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.0) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm( + torch.stack( + [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] + ), + norm_type, + ) + return total_norm + + +def save_model( + args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None +): + """ + Save model checkpoint to disk. 
+ + This function saves the model state, optimizer state, loss scaler state, + training arguments, current epoch, and optionally the best metric value so far. + The checkpoint is only saved on the master process in distributed training. + + Args: + args: Arguments containing output directory information + epoch (int): Current training epoch + model_without_ddp (torch.nn.Module): Model without DistributedDataParallel wrapper + optimizer (torch.optim.Optimizer): Optimizer instance + loss_scaler: Gradient scaler for mixed precision training + fname (str, optional): Custom filename suffix. If None, uses the epoch number. Defaults to None. + best_so_far (float, optional): Best metric value achieved so far. Defaults to None. + """ + output_dir = Path(args.output_dir) + if fname is None: + fname = str(epoch) + checkpoint_path = output_dir / ("checkpoint-%s.pth" % fname) + to_save = { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "scaler": loss_scaler.state_dict(), + "args": args, + "epoch": epoch, + } + if best_so_far is not None: + to_save["best_so_far"] = best_so_far + print(f">> Saving model to {checkpoint_path} ...") + save_on_master(to_save, checkpoint_path) + + +def load_model(train_args, model_without_ddp, optimizer, loss_scaler): + """ + Load model checkpoint from disk or URL. + + This function loads a saved checkpoint, restoring the model state, optimizer state, + loss scaler state, and training epoch. It can load from a local file or a URL. + + Args: + train_args: Training arguments containing resume information + model_without_ddp (torch.nn.Module): Model without DistributedDataParallel wrapper + optimizer (torch.optim.Optimizer): Optimizer instance + loss_scaler: Gradient scaler for mixed precision training + + Returns: + float or None: Best metric value from the checkpoint if available, otherwise None + """ + train_args.start_epoch = 0 + best_so_far = None + if train_args.resume and train_args.resume_ckpt is not None: + if train_args.resume_ckpt.startswith("https"): + checkpoint = torch.hub.load_state_dict_from_url( + train_args.resume_ckpt, map_location="cpu", check_hash=True + ) + else: + checkpoint = torch.load( + train_args.resume_ckpt, map_location="cpu", weights_only=False + ) + print("Resume checkpoint %s" % train_args.resume_ckpt) + model_without_ddp.load_state_dict(checkpoint["model"], strict=False) + train_args.start_epoch = checkpoint["epoch"] + 1 + optimizer.load_state_dict(checkpoint["optimizer"]) + if "scaler" in checkpoint: + loss_scaler.load_state_dict(checkpoint["scaler"]) + if "best_so_far" in checkpoint: + best_so_far = checkpoint["best_so_far"] + print(" & best_so_far={:g}".format(best_so_far)) + else: + print("") + print( + "With optim & sched! start_epoch={:d}".format(train_args.start_epoch), + end="", + ) + return best_so_far + + +def all_reduce_mean(x): + """ + Compute the mean of a value across all processes in distributed training. + + This function takes a value, reduces it across all processes using all_reduce, + and returns the mean value. + + Args: + x: The value to reduce (typically a scalar) + + Returns: + float: The mean value across all processes + """ + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + + +def _replace(text, src, tgt, rm=""): + """ + Advanced string replacement utility. 
+ + Given a text: + - replace all elements in src by the corresponding element in tgt + - remove all elements in rm + + Args: + text (str): The input text to modify + src (str): String of characters to replace + tgt (str): String of replacement characters (must be same length as src or length 1) + rm (str, optional): String of characters to remove. Defaults to "". + + Returns: + str: The modified text after replacements and removals + + Raises: + AssertionError: If src and tgt have different lengths (unless tgt has length 1) + """ + if len(tgt) == 1: + tgt = tgt * len(src) + assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len" + for s, t in zip(src, tgt): + text = text.replace(s, t) + for c in rm: + text = text.replace(c, "") + return text + + +def filename(obj): + """ + Transform a Python object or command into a proper filename. + + This function converts a Python object or command string into a valid filename + by replacing special characters and ensuring the filename is not too long. + + Special replacements: + - \1 gets replaced by slash '/' + - \2 gets replaced by comma ',' + + Args: + obj: The Python object or string to convert to a filename + + Returns: + str: A valid filename derived from the input object + + Raises: + AssertionError: If any part of the resulting path is longer than 256 characters + """ + if not isinstance(obj, str): + obj = repr(obj) + obj = str(obj).replace("()", "") + obj = _replace(obj, "_,(*/\1\2", "-__x%/,", rm=" )'\"") + assert all(len(s) < 256 for s in obj.split(os.sep)), ( + "filename too long (>256 characters):\n" + obj + ) + return obj + + +def compute_effective_lrs(train_args): + """ + Compute the effective learning rates based on batch size scaling. + + This function calculates the effective learning rates for the main model and + any submodules based on the effective batch size (accounting for gradient accumulation + and distributed training) and the base learning rates. 
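+
+    For example (hypothetical values): with blr = 1e-4, base_eff_batch_size = 64
+    and an effective batch size of batch_size * accum_iter * world_size = 256,
+    the scaled default learning rate becomes 1e-4 * sqrt(256 / 64) = 2e-4.
+    The same square-root scaling is applied to each submodule for which only
+    blr is specified.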
+ + Args: + train_args: Training arguments containing batch size, accumulation iterations, + learning rates, and submodule configurations + + Returns: + train_args: Updated training arguments with computed effective learning rates + """ + + # Compute the effective batch size + eff_batch_size = train_args.batch_size * train_args.accum_iter * get_world_size() + print("Accumulate grad iterations: %d" % train_args.accum_iter) + print("Effective batch size: %d" % eff_batch_size) + # Compute the effective default learning rate + if train_args.lr is None: # only base_lr is specified + train_args.lr = train_args.blr * math.sqrt( + eff_batch_size / train_args.base_eff_batch_size + ) + print( + f"Base default lr for effective batch size {eff_batch_size}: %.2e" + % (train_args.lr * math.sqrt(train_args.base_eff_batch_size / eff_batch_size)) + ) + print("Actual default lr: %.2e" % train_args.lr) + for submodule, config in train_args.submodule_configs.items(): + if config.get("lr") is None: # only base_lr is specified + config["lr"] = config["blr"] * math.sqrt( + eff_batch_size / train_args.base_eff_batch_size + ) + print( + f"Submodule {submodule} base lr for effective batch size {eff_batch_size}: %.2e" + % ( + config["lr"] + * math.sqrt(train_args.base_eff_batch_size / eff_batch_size) + ) + ) + print(f"Submodule {submodule} actual lr: %.2e" % config["lr"]) + + return train_args + + +def get_parameter_groups( + model, + lr, + weight_decay, + skip_list=[], + submodule_configs=None, + warn_not_in_submodule=False, +): + """ + Get parameter groups for optimizer with customized learning rates and weight decay. + + This function organizes model parameters into groups for the optimizer, allowing + different learning rates and weight decay values for different parts of the model. + Parameters are grouped by: + 1. Whether they should have weight decay applied (bias terms and 1D tensors typically don't) + 2. Which submodule they belong to (if submodule_configs is provided) + + Args: + model (torch.nn.Module): Model to get parameter groups for + lr (float): Default learning rate for parameters not in submodule_configs + weight_decay (float): Default weight decay for parameters not in submodule_configs + skip_list (list): List of parameter names to skip weight decay for + submodule_configs (dict, optional): Dictionary mapping submodule prefixes to configs + with 'lr' and 'weight_decay' keys + warn_not_in_submodule (bool, optional): Whether to warn if a parameter does not + belong to any submodule. Defaults to False. 
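+
+        A hypothetical submodule_configs (prefixes are matched against
+        named_parameters() names via startswith):
+
+            submodule_configs = {
+                "encoder.": {"lr": 1e-5, "weight_decay": 0.05},
+                "pred_head.": {"lr": 1e-4, "weight_decay": 0.0},
+            }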
+ + Returns: + tuple: A tuple containing: + - parameter_group_vars (list): List of parameter groups for optimizer + - parameter_group_name_to_idx_map (dict): Mapping from submodule name to parameter group indices + - parameter_group_idx_to_name_map (dict): Mapping from parameter group index to submodule name + """ + + if submodule_configs is None: + submodule_configs = {} + + parameter_group_names = {} + parameter_group_vars = {} + parameter_group_name_to_idx_map = {} + parameter_group_idx_to_name_map = {} + mapping_index = 0 + + for name, param in model.named_parameters(): + # Skip frozen parameters + if not param.requires_grad: + continue + + # Determine the submodule this parameter belongs to + submodule_name = None + for submodule, config in submodule_configs.items(): + if name.startswith(submodule): + submodule_name = submodule + break + + if submodule_name: + config = submodule_configs[submodule_name] + this_weight_decay = config.get("weight_decay", weight_decay) + this_lr = config.get("lr", lr) + # Freeze the parameters if lr is 0 + if this_lr == 0: + param.requires_grad = False + continue + else: + this_weight_decay = weight_decay + this_lr = lr + if warn_not_in_submodule and submodule_configs is not None: + print( + f"Warning: Parameter {name} does not belong to any submodule in {submodule_configs.keys()}." + ) + + # Assign weight decay values + if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: + group_name = f"{submodule_name}_no_decay" if submodule_name else "no_decay" + this_weight_decay = 0.0 + else: + group_name = f"{submodule_name}_decay" if submodule_name else "decay" + + if group_name not in parameter_group_names: + parameter_group_names[group_name] = { + "weight_decay": this_weight_decay, + "lr": this_lr, + "params": [], + } + parameter_group_vars[group_name] = { + "weight_decay": this_weight_decay, + "lr": this_lr, + "params": [], + } + submodule_name_mapping = submodule_name if submodule_name else "default" + if submodule_name_mapping not in parameter_group_name_to_idx_map: + parameter_group_name_to_idx_map[submodule_name_mapping] = [ + mapping_index + ] + else: + parameter_group_name_to_idx_map[submodule_name_mapping].append( + mapping_index + ) + parameter_group_idx_to_name_map[mapping_index] = submodule_name_mapping + mapping_index += 1 + + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + + # Print the parameter groups + print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) + + return ( + list(parameter_group_vars.values()), + parameter_group_name_to_idx_map, + parameter_group_idx_to_name_map, + ) + + +def adjust_learning_rate( + optimizer, + epoch, + train_args, + parameter_group_idx_to_name_map, + submodule_configs=None, +): + """ + Adjust the learning rate based on the schedule type and current epoch. + + This function updates the learning rates for all parameter groups in the optimizer + according to the specified learning rate schedule. Different submodules can have + different learning rate schedules. + + Currently supported schedule types: + - linear_warmup_half_cycle_cosine_decay: Linear warmup followed by cosine decay + + Args: + optimizer (torch.optim.Optimizer): The optimizer to update + epoch (int): Current training epoch + train_args: Training arguments containing schedule type, warmup epochs, etc. 
+ parameter_group_idx_to_name_map (dict): Mapping from parameter group index to submodule name + submodule_configs (dict, optional): Dictionary of submodule-specific configurations + for learning rate schedules + + Raises: + ValueError: If an unsupported schedule type is specified + """ + + if submodule_configs is None: + submodule_configs = {} + + for group_num, param_group in enumerate(optimizer.param_groups): + submodule_name = parameter_group_idx_to_name_map.get(group_num) + + if submodule_name in submodule_configs: + config = submodule_configs[submodule_name] + lr = config.get("lr", train_args.lr) + warmup_epochs = config.get("warmup_epochs", train_args.warmup_epochs) + min_lr = config.get("min_lr", train_args.min_lr) + schedule_type = config.get("schedule_type", train_args.schedule_type) + else: + lr = train_args.lr + warmup_epochs = train_args.warmup_epochs + min_lr = train_args.min_lr + schedule_type = train_args.schedule_type + + if schedule_type == "linear_warmup_half_cycle_cosine_decay": + if epoch < warmup_epochs: + lr = lr * epoch / warmup_epochs + else: + lr = min_lr + (lr - min_lr) * 0.5 * ( + 1.0 + + math.cos( + math.pi + * (epoch - warmup_epochs) + / (train_args.epochs - warmup_epochs) + ) + ) + else: + raise ValueError(f"Schedule type {schedule_type} not implemented") + + param_group["lr"] = lr + + +def debug_after_backward( + model, + check_missing_gradients=True, + check_gradient_mismatch=False, + target_size=(256, 256, 1, 1), + target_stride=(256, 1, 256, 256), +): + """ + Debugging function to check for gradient issues after backward pass. + + This function performs two types of gradient debugging: + 1. Gradient mismatch: Checks for parameters with specific gradient shapes and strides + that might indicate incorrect gradient computation. + 2. Missing gradients: Identifies parameters that require gradients but didn't receive any. + + Args: + model (torch.nn.Module): The model to check gradients for + check_missing_gradients (bool, optional): Whether to check for missing gradients. Defaults to True. + check_gradient_mismatch (bool, optional): Whether to check for gradient mismatches. Defaults to False. + target_size (tuple, optional): Target tensor size to check for gradient mismatch. Defaults to (256, 256, 1, 1). + target_stride (tuple, optional): Target tensor stride to check for gradient mismatch. Defaults to (256, 1, 256, 256). + """ + # Debug for missing gradients + if check_missing_gradients: + missing_grad_params = [] + for name, param in model.named_parameters(): + if param.requires_grad and param.grad is None: + missing_grad_params.append(name) + + if missing_grad_params: + print("Parameters requiring gradients but missing gradients:") + for name in missing_grad_params: + print(f" - {name}") + else: + print("All parameters requiring gradients received gradients!") + + # Debug for gradient mismatch + if check_gradient_mismatch: + for name, param in model.named_parameters(): + grad = param.grad + if grad is None: + continue + if grad.size() == target_size and grad.stride() == target_stride: + print(f"Found parameter with incorrect gradient: '{name}'") + print(f"Gradient shape: {grad.size()}, strides: {grad.stride()}") diff --git a/mapanything/utils/viz.py b/mapanything/utils/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..a25bc91b35809fbf6eac9bb39d74540fed56fd00 --- /dev/null +++ b/mapanything/utils/viz.py @@ -0,0 +1,266 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Utility functions for visualization +""" + +from argparse import ArgumentParser, Namespace +from distutils.util import strtobool + +import numpy as np +import rerun as rr +import trimesh + +from mapanything.utils.hf_utils.viz import image_mesh + + +def log_posed_rgbd_data_to_rerun( + image, depthmap, pose, intrinsics, base_name, mask=None +): + """ + Log camera and image data to Rerun visualization tool. + + Parameters + ---------- + image : numpy.ndarray + RGB image to be logged + depthmap : numpy.ndarray + Depth map corresponding to the image + pose : numpy.ndarray + 4x4 camera pose matrix with rotation (3x3) and translation (3x1) + intrinsics : numpy.ndarray + Camera intrinsic matrix + base_name : str + Base name for the logged entities in Rerun + mask : numpy.ndarray, optional + Optional segmentation mask for the depth image + """ + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if mask is not None: + rr.log( + f"{base_name}/pinhole/depth_mask", + rr.SegmentationImage(mask), + ) + + +def str2bool(v): + return bool(strtobool(v)) + + +def script_add_rerun_args(parser: ArgumentParser) -> None: + """ + Add common Rerun script arguments to `parser`. + + Change Log from https://github.com/rerun-io/rerun/blob/29eb8954b08e59ff96943dc0677f46f7ea4ea734/rerun_py/rerun_sdk/rerun/script_helpers.py#L65: + - Added default portforwarding url for ease of use + - Update parser types + + Parameters + ---------- + parser : ArgumentParser + The parser to add arguments to. + + Returns + ------- + None + """ + parser.add_argument( + "--headless", + type=str2bool, + nargs="?", + const=True, + default=True, + help="Don't show GUI", + ) + parser.add_argument( + "--connect", + dest="connect", + type=str2bool, + nargs="?", + const=True, + default=True, + help="Connect to an external viewer", + ) + parser.add_argument( + "--serve", + dest="serve", + type=str2bool, + nargs="?", + const=True, + default=False, + help="Serve a web viewer (WARNING: experimental feature)", + ) + parser.add_argument( + "--url", + type=str, + default="rerun+http://127.0.0.1:2004/proxy", + help="Connect to this HTTP(S) URL", + ) + parser.add_argument( + "--save", type=str, default=None, help="Save data to a .rrd file at this path" + ) + parser.add_argument( + "-o", + "--stdout", + dest="stdout", + action="store_true", + help="Log data to standard output, to be piped into a Rerun Viewer", + ) + + +def init_rerun_args( + headless=True, + connect=True, + serve=False, + url="rerun+http://127.0.0.1:2004/proxy", + save=None, + stdout=False, +) -> Namespace: + """ + Initialize common Rerun script arguments. 
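+
+    For example (a hypothetical offline-recording setup),
+    init_rerun_args(connect=False, save="debug.rrd") yields the same Namespace
+    that parsing "--connect false --save debug.rrd" via script_add_rerun_args
+    would produce.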
+ + Parameters + ---------- + headless : bool, optional + Don't show GUI, by default True + connect : bool, optional + Connect to an external viewer, by default True + serve : bool, optional + Serve a web viewer (WARNING: experimental feature), by default False + url : str, optional + Connect to this HTTP(S) URL, by default rerun+http://127.0.0.1:2004/proxy + save : str, optional + Save data to a .rrd file at this path, by default None + stdout : bool, optional + Log data to standard output, to be piped into a Rerun Viewer, by default False + + Returns + ------- + Namespace + The parsed arguments. + """ + rerun_args = Namespace() + rerun_args.headless = headless + rerun_args.connect = connect + rerun_args.serve = serve + rerun_args.url = url + rerun_args.save = save + rerun_args.stdout = stdout + + return rerun_args + + +def predictions_to_glb( + predictions, + as_mesh=True, +) -> trimesh.Scene: + """ + Converts predictions to a 3D scene represented as a GLB file. + + Args: + predictions (dict): Dictionary containing model predictions with keys: + - world_points: 3D point coordinates (V, H, W, 3) + - images: Input images (V, H, W, 3) + - final_masks: Validity masks (V, H, W) + as_mesh (bool): Represent the data as a mesh instead of point cloud (default: True) + + Returns: + trimesh.Scene: Processed 3D scene containing point cloud/mesh and cameras + + Raises: + ValueError: If input predictions structure is invalid + """ + if not isinstance(predictions, dict): + raise ValueError("predictions must be a dictionary") + + # Get the world frame points and images from the predictions + pred_world_points = predictions["world_points"] + images = predictions["images"] + + # Get the points and rgb + vertices_3d = pred_world_points.reshape(-1, 3) + # Handle different image formats - check if images need transposing + if images.ndim == 4 and images.shape[1] == 3: # NCHW format + colors_rgb = np.transpose(images, (0, 2, 3, 1)) + else: # Assume already in NHWC format + colors_rgb = images + colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8) + + # Initialize a 3D scene + scene_3d = trimesh.Scene() + + # Add point cloud data to the scene + if as_mesh: + # Multi-frame case - create separate meshes for each frame + for frame_idx in range(pred_world_points.shape[0]): + H, W = pred_world_points.shape[1:3] + + # Get data for this frame + frame_points = pred_world_points[frame_idx] + frame_final_mask = predictions["final_masks"][frame_idx] + + # Get frame image + if images.ndim == 4 and images.shape[1] == 3: # NCHW format + frame_image = np.transpose(images[frame_idx], (1, 2, 0)) + else: # Assume already in HWC format + frame_image = images[frame_idx] + frame_image *= 255 + + # Create mesh for this frame + faces, vertices, vertex_colors = image_mesh( + frame_points * np.array([1, -1, 1], dtype=np.float32), + frame_image / 255.0, + mask=frame_final_mask, + tri=True, + return_indices=False, + ) + vertices = vertices * np.array([1, -1, 1], dtype=np.float32) + + # Create trimesh object for this frame + frame_mesh = trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_colors=(vertex_colors * 255).astype(np.uint8), + process=False, + ) + scene_3d.add_geometry(frame_mesh) + else: + final_masks = predictions["final_masks"].reshape(-1) + vertices_3d = vertices_3d[final_masks].copy() + colors_rgb = colors_rgb[final_masks].copy() + point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb) + scene_3d.add_geometry(point_cloud_data) + + # Apply 180° rotation around X-axis to fix 
orientation (upside-down issue) + rotation_matrix_x = trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]) + scene_3d.apply_transform(rotation_matrix_x) + + return scene_3d diff --git a/mapanything/utils/wai/__init__.py b/mapanything/utils/wai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8ded41113638f6e7bf5d83b5f5086b31be52e3b --- /dev/null +++ b/mapanything/utils/wai/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +This utils module contains PORTAGE of wai-core scripts/methods for MapAnything. +""" diff --git a/mapanything/utils/wai/basic_dataset.py b/mapanything/utils/wai/basic_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fde2eca784e7a168e8647e44ac6268f540133fde --- /dev/null +++ b/mapanything/utils/wai/basic_dataset.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from pathlib import Path +from typing import Any + +import torch +from box import Box + +from mapanything.utils.wai.core import get_frame_index, load_data, load_frame +from mapanything.utils.wai.ops import stack +from mapanything.utils.wai.scene_frame import get_scene_frame_names + + +class BasicSceneframeDataset(torch.utils.data.Dataset): + """Basic wai dataset to iterative over frames of scenes""" + + @staticmethod + def collate_fn(batch: list[dict[str, Any]]) -> dict[str, Any]: + return stack(batch) + + def __init__( + self, + cfg: Box, + ): + """ + Initialize the BasicSceneframeDataset. + + Args: + cfg (Box): Configuration object containing dataset parameters including: + - root: Root directory containing scene data + - frame_modalities: List of modalities to load for each frame + - key_remap: Optional dictionary mapping original keys to new keys + """ + super().__init__() + self.cfg = cfg + self.root = cfg.root + keyframes = cfg.get("use_keyframes", True) + self.scene_frame_names = get_scene_frame_names(cfg, keyframes=keyframes) + self.scene_frame_list = [ + (scene_name, frame_name) + for scene_name, frame_names in self.scene_frame_names.items() + for frame_name in frame_names + ] + self._scene_cache = {} + + def __len__(self): + """ + Get the total number of scene-frame pairs in the dataset. + + Returns: + int: The number of scene-frame pairs. + """ + return len(self.scene_frame_list) + + def _load_scene(self, scene_name: str) -> dict[str, Any]: + """ + Load scene data for a given scene name. + + Args: + scene_name (str): The name of the scene to load. + + Returns: + dict: A dictionary containing scene data, including scene metadata. + """ + # load scene data + scene_data = {} + scene_data["meta"] = load_data( + Path( + self.root, + scene_name, + self.cfg.get("scene_meta_path", "scene_meta.json"), + ), + "scene_meta", + ) + + return scene_data + + def _load_scene_frame( + self, scene_name: str, frame_name: str | float + ) -> dict[str, Any]: + """ + Load data for a specific frame from a specific scene. + + This method loads scene data if not already cached, then loads the specified frame + from that scene with the modalities specified in the configuration. + + Args: + scene_name (str): The name of the scene containing the frame. 
+ frame_name (str or float): The name/timestamp of the frame to load. + + Returns: + dict: A dictionary containing the loaded frame data with requested modalities. + """ + scene_frame_data = {} + if not (scene_data := self._scene_cache.get(scene_name)): + scene_data = self._load_scene(scene_name) + # for now only cache the last scene + self._scene_cache = {} + self._scene_cache[scene_name] = scene_data + + frame_idx = get_frame_index(scene_data["meta"], frame_name) + + scene_frame_data["scene_name"] = scene_name + scene_frame_data["frame_name"] = frame_name + scene_frame_data["scene_path"] = str(Path(self.root, scene_name)) + scene_frame_data["frame_idx"] = frame_idx + scene_frame_data.update( + load_frame( + Path(self.root, scene_name), + frame_name, + modalities=self.cfg.frame_modalities, + scene_meta=scene_data["meta"], + ) + ) + # Remap key names + for key, new_key in self.cfg.get("key_remap", {}).items(): + if key in scene_frame_data: + scene_frame_data[new_key] = scene_frame_data.pop(key) + + return scene_frame_data + + def __getitem__(self, index: int) -> dict[str, Any]: + """ + Get a specific scene-frame pair by index. + + Args: + index (int): The index of the scene-frame pair to retrieve. + + Returns: + dict: A dictionary containing the loaded frame data with requested modalities. + """ + scene_frame = self._load_scene_frame(*self.scene_frame_list[index]) + return scene_frame diff --git a/mapanything/utils/wai/camera.py b/mapanything/utils/wai/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..8eda00f0f457ec8c53fff501ecde3995f9d006cf --- /dev/null +++ b/mapanything/utils/wai/camera.py @@ -0,0 +1,352 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +This utils script contains PORTAGE of wai-core camera methods for MapAnything. +""" + +from typing import Any + +import numpy as np +import torch +from scipy.spatial.transform import Rotation, Slerp + +from mapanything.utils.wai.ops import get_dtype_device + +# constants regarding camera models +PINHOLE_CAM_KEYS = ["fl_x", "fl_y", "cx", "cy", "h", "w"] +DISTORTION_PARAM_KEYS = [ + "k1", + "k2", + "k3", + "k4", + "p1", + "p2", +] # order corresponds to the OpenCV convention +CAMERA_KEYS = PINHOLE_CAM_KEYS + DISTORTION_PARAM_KEYS + +# Retrieve all available camera models and their associated parameters using pycolmap +CAM_STRS_TO_PARAMS = { + # For PINHOLE we use separate focal length for x and y, even though almost always we + # will have fx=fy. 
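+    # These four parameters map to the usual 3x3 pinhole intrinsics
+    #     K = [[fl_x, 0, cx], [0, fl_y, cy], [0, 0, 1]]
+    # (see convert_camera_coeffs_to_pinhole_matrix below).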
+ "PINHOLE": ["fl_x", "fl_y", "cx", "cy"], + # Undistortion supported by OPenCV + "OPENCV": ["fl_x", "fl_y", "cx", "cy", "k1", "k2", "p1", "p2"], + "OPENCV_FISHEYE": ["fl_x", "fl_y", "cx", "cy", "k1", "k2", "k3", "k4"], + # Undistortion supported by pycolmap + "FULL_OPENCV": [ + "fl_x", + "fl_y", + "cx", + "cy", + "k1", + "k2", + "p1", + "p2", + "k3", + "k4", + "k5", + "k6", + ], + "FOV": ["fl_x", "fl_y", "cx", "cy", "omega"], + "THIN_PRISM_FISHEYE": [ + "fl_x", + "fl_y", + "cx", + "cy", + "k1", + "k2", + "p1", + "p2", + "k3", + "k4", + "sx1", + "sy1", + ], + "RAD_TAN_THIN_PRISM_FISHEYE": [ + "fl_x", + "fl_y", + "cx", + "cy", + "k0", + "k1", + "k2", + "k3", + "k4", + "k5", + "p0", + "p1", + "s0", + "s1", + "s2", + "s3", + ], + # Non-OpenCV and non-pycolmap camera models + "EQUIRECTANGULAR": [], # Only width and height needed +} +# This is just an unordered helper list for all cam params and for distortion parameters +# which should never occur for a pinhole camera +ALL_CAM_PARAMS = list(set().union(*CAM_STRS_TO_PARAMS.values())) + ["w", "h"] + + +def interpolate_intrinsics( + frame1: dict[str, Any], + frame2: dict[str, Any], + alpha: float, +) -> dict[str, Any]: + """ + Interpolate camera intrinsics linearly. + Args: + frame1: The first frame dictionary. + frame2: The second frame dictionary. + alpha: Interpolation parameter. alpha = 0 for frame1, alpha = 1 for frame2. + Returns: + frame_inter: dictionary with new intrinsics. + """ + frame_inter = {} + for key in CAMERA_KEYS: + if key in frame1 and key in frame2: + p1 = frame1[key] + p2 = frame2[key] + frame_inter[key] = (1 - alpha) * p1 + alpha * p2 + return frame_inter + + +def interpolate_extrinsics( + matrix1: list | np.ndarray | torch.Tensor, + matrix2: list | np.ndarray | torch.Tensor, + alpha: float, +) -> list | np.ndarray | torch.Tensor: + """ + Interpolate camera extrinsics 4x4 matrices using SLERP. + Args: + matrix1: The first matrix. + matrix2: The second matrix. + alpha: Interpolation parameter. alpha = 0 for matrix1, alpha = 1 for matrix2. + Returns: + matrix: 4x4 interpolated matrix, same type. + Raises: + ValueError: If different type. + """ + if not isinstance(matrix1, type(matrix2)): + raise ValueError("Both matrices should have the same type.") + + dtype, device = get_dtype_device(matrix1) + if isinstance(matrix1, list): + mtype = "list" + matrix1 = np.array(matrix1) + matrix2 = np.array(matrix2) + elif isinstance(matrix1, np.ndarray): + mtype = "numpy" + elif isinstance(matrix1, torch.Tensor): + mtype = "torch" + matrix1 = matrix1.numpy() + matrix2 = matrix2.numpy() + else: + raise ValueError( + "Only list, numpy array and torch tensors are supported as inputs." 
+ ) + + R1 = matrix1[:3, :3] + t1 = matrix1[:3, 3] + R2 = matrix2[:3, :3] + t2 = matrix2[:3, 3] + + # interpolate translation + t = (1 - alpha) * t1 + alpha * t2 + + # interpolate rotations with SLERP + R1_quat = Rotation.from_matrix(R1).as_quat() + R2_quat = Rotation.from_matrix(R2).as_quat() + rotation_slerp = Slerp([0, 1], Rotation(np.stack([R1_quat, R2_quat]))) + R = rotation_slerp(alpha).as_matrix() + matrix_inter = np.eye(4) + + # combine together + matrix_inter[:3, :3] = R + matrix_inter[:3, 3] = t + + if mtype == "list": + matrix_inter = matrix_inter.tolist() + elif mtype == "torch": + matrix_inter = torch.from_numpy(matrix_inter).to(dtype).to(device) + elif mtype == "numpy": + matrix_inter = matrix_inter.astype(dtype) + + return matrix_inter + + +def convert_camera_coeffs_to_pinhole_matrix( + scene_meta, frame, fmt="torch" +) -> torch.Tensor | np.ndarray | list: + """ + Convert camera intrinsics from NeRFStudio format to a 3x3 intrinsics matrix. + + Args: + scene_meta: Scene metadata containing camera parameters + frame: Frame-specific camera parameters that override scene_meta + + Returns: + torch.Tensor: 3x3 camera intrinsics matrix + + Raises: + ValueError: If camera model is not PINHOLE or if distortion coefficients are present + """ + # Check if camera model is supported + camera_model = frame.get("camera_model", scene_meta.get("camera_model")) + if camera_model != "PINHOLE": + raise ValueError("Only PINHOLE camera model supported") + + # Check for unsupported distortion coefficients + if any( + (frame.get(coeff, 0) != 0) or (scene_meta.get(coeff, 0) != 0) + for coeff in DISTORTION_PARAM_KEYS + ): + raise ValueError( + "Pinhole camera does not support radial/tangential distortion -> Undistort first" + ) + + # Extract camera intrinsic parameters + camera_coeffs = {} + for coeff in ["fl_x", "fl_y", "cx", "cy"]: + camera_coeffs[coeff] = frame.get(coeff, scene_meta.get(coeff)) + if camera_coeffs[coeff] is None: + raise ValueError(f"Missing required camera parameter: {coeff}") + + # Create intrinsics matrix + intrinsics = [ + [camera_coeffs["fl_x"], 0.0, camera_coeffs["cx"]], + [0.0, camera_coeffs["fl_y"], camera_coeffs["cy"]], + [0.0, 0.0, 1.0], + ] + if fmt == "torch": + intrinsics = torch.tensor(intrinsics) + elif fmt == "np": + intrinsics = np.array(intrinsics) + + return intrinsics + + +def rotate_pinhole_90degcw( + W: int, H: int, fx: float, fy: float, cx: float, cy: float +) -> tuple[int, int, float, float, float, float]: + """Rotates the intrinsics of a pinhole camera model by 90 degrees clockwise.""" + W_new = H + H_new = W + fx_new = fy + fy_new = fx + cy_new = cx + cx_new = H - 1 - cy + return W_new, H_new, fx_new, fy_new, cx_new, cy_new + + +def _gl_cv_cmat() -> np.ndarray: + cmat = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) + return cmat + + +def _apply_transformation( + c2ws: torch.Tensor | np.ndarray, cmat: np.ndarray +) -> torch.Tensor | np.ndarray: + """ + Convert camera poses using a provided conversion matrix. 
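+
+    Note that only the rotation block is right-multiplied by cmat[:3, :3]; the
+    translation (camera position) is left unchanged. With the diag(1, -1, -1, 1)
+    matrix from _gl_cv_cmat this flips the camera's local Y and Z axes, i.e. it
+    converts between the OpenGL convention (+Y up, looking down -Z) and the
+    OpenCV convention (+Y down, looking down +Z).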
+ + Args: + c2ws (torch.Tensor or np.ndarray): Camera poses (batch_size, 4, 4) or (4, 4) + cmat (torch.Tensor or np.ndarray): Conversion matrix (4, 4) + + Returns: + torch.Tensor or np.ndarray: Transformed camera poses (batch_size, 4, 4) or (4, 4) + """ + if isinstance(c2ws, torch.Tensor): + # Clone the input tensor to avoid modifying it in-place + c2ws_transformed = c2ws.clone() + # Apply the conversion matrix to the rotation part of the camera poses + if len(c2ws.shape) == 3: + c2ws_transformed[:, :3, :3] = c2ws_transformed[ + :, :3, :3 + ] @ torch.from_numpy(cmat[:3, :3]).to(c2ws).unsqueeze(0) + else: + c2ws_transformed[:3, :3] = c2ws_transformed[:3, :3] @ torch.from_numpy( + cmat[:3, :3] + ).to(c2ws) + + elif isinstance(c2ws, np.ndarray): + # Clone the input array to avoid modifying it in-place + c2ws_transformed = c2ws.copy() + if len(c2ws.shape) == 3: # batched + # Apply the conversion matrix to the rotation part of the camera poses + c2ws_transformed[:, :3, :3] = np.einsum( + "ijk,lk->ijl", c2ws_transformed[:, :3, :3], cmat[:3, :3] + ) + else: # single 4x4 matrix + # Apply the conversion matrix to the rotation part of the camera pose + c2ws_transformed[:3, :3] = np.dot(c2ws_transformed[:3, :3], cmat[:3, :3]) + + else: + raise ValueError("Input data type not supported.") + + return c2ws_transformed + + +def gl2cv( + c2ws: torch.Tensor | np.ndarray, + return_cmat: bool = False, +) -> torch.Tensor | np.ndarray | tuple[torch.Tensor | np.ndarray, np.ndarray]: + """ + Convert camera poses from OpenGL to OpenCV coordinate system. + + Args: + c2ws (torch.Tensor or np.ndarray): Camera poses (batch_size, 4, 4) or (4, 4) + return_cmat (bool): If True, return the conversion matrix along with the transformed poses + + Returns: + torch.Tensor or np.ndarray: Transformed camera poses (batch_size, 4, 4) or (4, 4) + np.ndarray (optional): Conversion matrix if return_cmat is True + """ + cmat = _gl_cv_cmat() + if return_cmat: + return _apply_transformation(c2ws, cmat), cmat + return _apply_transformation(c2ws, cmat) + + +def intrinsics_to_fov( + fx: torch.Tensor, fy: torch.Tensor, h: torch.Tensor, w: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute the horizontal and vertical fields of view in radians from camera intrinsics. + + Args: + fx (torch.Tensor): focal x + fy (torch.Tensor): focal y + h (torch.Tensor): Image height(s) with shape (B,). + w (torch.Tensor): Image width(s) with shape (B,). + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing the horizontal and vertical fields + of view in radians, both with shape (N,). + """ + return 2 * torch.atan((w / 2) / fx), 2 * torch.atan((h / 2) / fy) + + +def cv2gl( + c2ws: torch.Tensor | np.ndarray, + return_cmat: bool = False, +) -> torch.Tensor | np.ndarray | tuple[torch.Tensor | np.ndarray, np.ndarray]: + """ + Convert camera poses from OpenCV to OpenGL coordinate system. 
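+
+    Because diag(1, -1, -1, 1) is its own inverse, this applies exactly the same
+    transformation as gl2cv; the separate name only documents the intended
+    direction at the call site.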
+ + Args: + c2ws (torch.Tensor or np.ndarray): Camera poses (batch_size, 4, 4) or (4, 4) + return_cmat (bool): If True, return the conversion matrix along with the transformed poses + + Returns: + torch.Tensor or np.ndarray: Transformed camera poses (batch_size, 4, 4) or (4, 4) + np.ndarray (optional): Conversion matrix if return_cmat is True + """ + cmat = _gl_cv_cmat() + if return_cmat: + return _apply_transformation(c2ws, cmat), cmat + return _apply_transformation(c2ws, cmat) diff --git a/mapanything/utils/wai/colormaps/colors_fps_5k.npz b/mapanything/utils/wai/colormaps/colors_fps_5k.npz new file mode 100644 index 0000000000000000000000000000000000000000..7f259f39eb571a3c0848cc6a82d28a384e2254e6 --- /dev/null +++ b/mapanything/utils/wai/colormaps/colors_fps_5k.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fae94fe5fb565ff40d1c556ae2640d00fc068e732cb4af5bb64eef034790e07c +size 9478 diff --git a/mapanything/utils/wai/core.py b/mapanything/utils/wai/core.py new file mode 100644 index 0000000000000000000000000000000000000000..fa8b27f6184d1682c6b58abd433f61246e53bf9e --- /dev/null +++ b/mapanything/utils/wai/core.py @@ -0,0 +1,497 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +This utils script contains PORTAGE of wai-core core methods for MapAnything. +""" + +import logging +import re +from pathlib import Path +from typing import Any + +import numpy as np +import torch + +from mapanything.utils.wai.camera import ( + CAMERA_KEYS, + convert_camera_coeffs_to_pinhole_matrix, + interpolate_extrinsics, + interpolate_intrinsics, +) +from mapanything.utils.wai.io import _get_method, _load_scene_meta +from mapanything.utils.wai.ops import crop + +logger = logging.getLogger(__name__) + +WAI_COLORMAP_PATH = Path(__file__).parent / "colormaps" + + +def load_data(fname: str | Path, format_type: str | None = None, **kwargs) -> Any: + """ + Loads data from a file using the appropriate method based on the file format. + + Args: + fname (str or Path): The filename or path to load data from. + format_type (str, optional): The format type of the data. If None, it will be inferred from the file extension if possible. + Supported formats include: 'readable', 'scalar', 'image', 'binary', 'depth', 'normals', + 'numpy', 'ptz', 'mmap', 'scene_meta', 'labeled_image', 'mesh', 'labeled_mesh', 'caption', "latents". + **kwargs: Additional keyword arguments to pass to the loading method. + + Returns: + The loaded data in the format returned by the specific loading method. + + Raises: + ValueError: If the format cannot be inferred from the file extension. + NotImplementedError: If the specified format is not supported. + FileExistsError: If the file does not exist. + """ + load_method = _get_method(fname, format_type, load=True) + return load_method(fname, **kwargs) + + +def store_data( + fname: str | Path, + data: Any, + format_type: str | None = None, + **kwargs, +) -> Any: + """ + Stores data to a file using the appropriate method based on the file format. + + Args: + fname (str or Path): The filename or path to store data to. + data: The data to be stored. + format_type (str, optional): The format type of the data. If None, it will be inferred from the file extension. + **kwargs: Additional keyword arguments to pass to the storing method. 
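+
+        Example (hypothetical paths; "depth" is one of the format types listed
+        in load_data):
+
+            depthmap = load_data("scene_0/depth/000000.exr", "depth")
+            store_data("out/depth/000000.exr", depthmap, "depth")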
+ + Returns: + The result of the storing method, which may vary depending on the method used. + """ + store_method = _get_method(fname, format_type, load=False) + Path(fname).parent.mkdir(parents=True, exist_ok=True) + return store_method(fname, data, **kwargs) + + +def get_frame( + scene_meta: dict[str, Any], + frame_key: int | str | float, +) -> dict[str, Any]: + """ + Get a frame from scene_meta based on name or index. + + Args: + scene_meta: Dictionary containing scene metadata + frame_key: Either a string (frame name) or integer (frame index) or float (video timestamp) + + Returns: + The frame data (dict) + """ + frame_idx = get_frame_index(scene_meta, frame_key) + if isinstance(frame_idx, int): + frame = scene_meta["frames"][frame_idx] + frame["_is_interpolated"] = False + else: + frame = {} + frame["frame_name"] = frame_key + left = int(frame_idx) # it's floor operation + assert left >= 0 and left < (len(scene_meta["frames"]) - 1), "Wrong index" + frame_left = scene_meta["frames"][left] + frame_right = scene_meta["frames"][left + 1] + # Interpolate intrinsics and extrinsics + frame["transform_matrix"] = interpolate_extrinsics( + frame_left["transform_matrix"], + frame_right["transform_matrix"], + frame_idx - left, + ) + frame.update( + interpolate_intrinsics( + frame_left, + frame_right, + frame_idx - left, + ) + ) + frame["_is_interpolated"] = True + return frame + + +def get_intrinsics( + scene_meta, + frame_key, + fmt: str = "torch", +) -> torch.Tensor | np.ndarray | list: + frame = get_frame(scene_meta, frame_key) + return convert_camera_coeffs_to_pinhole_matrix(scene_meta, frame, fmt=fmt) + + +def get_extrinsics( + scene_meta, + frame_key, + fmt: str = "torch", +) -> torch.Tensor | np.ndarray | list | None: + frame = get_frame(scene_meta, frame_key) + if "transform_matrix" in frame: + if fmt == "torch": + return torch.tensor(frame["transform_matrix"]).reshape(4, 4).float() + elif fmt == "np": + return np.array(frame["transform_matrix"]).reshape(4, 4) + return frame["transform_matrix"] + else: + # TODO: should not happen if we enable interpolation + return None + + +def get_frame_index( + scene_meta: dict[str, Any], + frame_key: int | str | float, + frame_index_threshold_sec: float = 1e-4, + distance_threshold_sec: float = 2.0, +) -> int | float: + """ + Returns the frame index from scene_meta based on name (str) or index (int) or sub-frame index (float). + + Args: + scene_meta: Dictionary containing scene metadata + frame_key: Either a string (frame name) or integer (frame index) or float (sub-frame index) + frame_index_threshold_sec: A threshold for nearest neighbor clipping for indexes (in seconds). + Default is 1e-4, which is 10000 fps. + distance_th: A threshold for maximum distance between interpolated frames (in seconds). 
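+
+        Worked example (hypothetical timestamps): with frames at 0.0 s and 0.5 s,
+        frame_key=0.125 falls between them, so the function returns the index of
+        the 0.0 s frame plus alpha = 0.125 / 0.5 = 0.25 as a float sub-frame index.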
+ + Returns: + Frame index (int) + + Raises: + ValueError: If frame_key is not a string or integer or float + """ + if isinstance(frame_key, str): + try: + return scene_meta["frame_names"][frame_key] + except KeyError as err: + error_message = ( + f"Frame name not found: {frame_key} - " + f"Please verify scene_meta.json of scene: {scene_meta['dataset_name']}/{scene_meta['scene_name']}" + ) + logger.error(error_message) + raise KeyError(error_message) from err + + if isinstance(frame_key, int): + return frame_key + + if isinstance(frame_key, float): + # If exact hit + if frame_key in scene_meta["frame_names"]: + return scene_meta["frame_names"][frame_key] + + frame_names = sorted(list(scene_meta["frame_names"].keys())) + distances = np.array([frm - frame_key for frm in frame_names]) + left = int(np.nonzero(distances <= 0)[0][-1]) + right = left + 1 + + # The last frame or rounding errors + if ( + left == distances.shape[0] - 1 + or abs(distances[left]) < frame_index_threshold_sec + ): + return scene_meta["frame_names"][frame_names[int(left)]] + if abs(distances[right]) < frame_index_threshold_sec: + return scene_meta["frame_names"][frame_names[int(right)]] + + interpolation_distance = distances[right] - distances[left] + if interpolation_distance > distance_threshold_sec: + raise ValueError( + f"Frame interpolation is forbidden for distances larger than {distance_threshold_sec}." + ) + alpha = -distances[left] / interpolation_distance + + return scene_meta["frame_names"][frame_names[int(left)]] + alpha + + raise ValueError(f"Frame key type not supported: {frame_key} ({type(frame_key)}).") + + +def load_modality_data( + scene_root: Path | str, + results: dict[str, Any], + modality_dict: dict[str, Any], + modality: str, + frame: dict[str, Any] | None = None, + fmt: str = "torch", +) -> dict[str, Any]: + """ + Processes a modality by loading data from a specified path and updating the results dictionary. + This function extracts the format and path from the given modality dictionary, loads the data + from the specified path, and updates the results dictionary with the loaded data. + + Args: + scene_root (str or Path): The root directory of the scene where the data is located. + results (dict): A dictionary to store the loaded modality data and optional frame path. + modality_dict (dict): A dictionary containing the modality information, including 'format' + and the path to the data. + modality (str): The key under which the loaded modality data will be stored in the results. + frame (dict, optional): A dictionary containing frame information. If provided, that means we are loading + frame modalities, otherwise it is scene modalities. + + Returns: + dict: The updated results dictionary containing the loaded modality data. 
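+
+    Note:
+        A modality whose format contains "video" is decoded from the referenced
+        video file (optionally split into chunks whose start/end timestamps are
+        matched against the frame's float frame_name); any other modality is
+        loaded from the single file path stored next to the format key.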
+ """ + modality_format = modality_dict["format"] + + # The modality is stored as a video + if "video" in modality_format: + assert isinstance(frame["frame_name"], float), "frame_name should be float" + video_file = None + if "chunks" in modality_dict: + video_list = modality_dict["chunks"] + # Get the correct chunk of the video + for video_chunk in video_list: + if video_chunk["start"] <= frame["frame_name"] <= video_chunk["end"]: + video_file = video_chunk + break + else: + # There is only one video (no chunks) + video_file = modality_dict + if "start" not in video_file: + video_file["start"] = 0 + if "end" not in video_file: + video_file["end"] = float("inf") + if not (video_file["start"] <= frame["frame_name"] <= video_file["end"]): + video_file = None + + # This timestamp is not available in any of the chunks + if video_file is None: + frame_name = frame["frame_name"] + logger.warning( + f"Modality {modality} ({modality_format}) is not available at time {frame_name}" + ) + return results + + # Load the modality from the video + loaded_modality = load_data( + Path(scene_root, video_file["file"]), + modality_format, + frame_key=frame["frame_name"] - video_file["start"], + ) + + if "bbox" in video_file: + loaded_modality = crop(loaded_modality, video_file["bbox"]) + + if loaded_modality is not None: + results[modality] = loaded_modality + + if frame: + results[f"{modality}_fname"] = video_file["file"] + else: + modality_path = [v for k, v in modality_dict.items() if k != "format"][0] + if frame: + if modality_path in frame: + fname = frame[modality_path] + else: + fname = None + else: + fname = modality_path + if fname is not None: + loaded_modality = load_data( + Path(scene_root, fname), + modality_format, + frame_key=frame["frame_name"] if frame else None, + fmt=fmt, + ) + results[modality] = loaded_modality + if frame: + results[f"{modality}_fname"] = frame[modality_path] + return results + + +def load_modality( + scene_root: Path | str, + modality_meta: dict[str, Any], + modality: str, + frame: dict[str, Any] | None = None, + fmt: str = "torch", +) -> dict[str, Any]: + """ + Loads modality data based on the provided metadata and updates the results dictionary. + This function navigates through the modality metadata to find the specified modality, + then loads the data for each modality found. + + Args: + scene_root (str or Path): The root directory of the scene where the data is located. + modality_meta (dict): A nested dictionary containing metadata for various modalities. + modality (str): A string representing the path to the desired modality within the metadata, + using '/' as a separator for nested keys. + frame (dict, optional): A dictionary containing frame information. If provided, we are operating + on frame modalities, otherwise it is scene modalities. + + Returns: + dict: A dictionary containing the loaded modality data. + """ + results = {} + # support for nested modalities like "pred_depth/metric3dv2" + modality_keys = modality.split("/") + current_modality = modality_meta + for key in modality_keys: + try: + current_modality = current_modality[key] + except KeyError as err: + error_message = ( + f"Modality '{err.args[0]}' not found in modalities metadata. " + f"Please verify the scene_meta.json and the provided modalities in {scene_root}." 
+ ) + logger.error(error_message) + raise KeyError(error_message) from err + if "format" in current_modality: + results = load_modality_data( + scene_root, results, current_modality, modality, frame, fmt=fmt + ) + else: + # nested modality, return last by default + logger.warning("Nested modality, returning last by default") + key = next(reversed(current_modality.keys())) + results = load_modality_data( + scene_root, results, current_modality[key], modality, frame, fmt=fmt + ) + return results + + +def load_frame( + scene_root: Path | str, + frame_key: int | str | float, + modalities: str | list[str] | None = None, + scene_meta: dict[str, Any] | None = None, + load_intrinsics: bool = True, + load_extrinsics: bool = True, + fmt: str = "torch", + interpolate: bool = False, +) -> dict[str, Any]: + """ + Load a single frame from a scene with specified modalities. + + Args: + scene_root (str or Path): The root directory of the scene where the data is located. + frame_key (int or str or float): Either a string (frame name) or integer (frame index) or float (video timestamp). + modalities (str or list[str], optional): The modality or list of modalities to load. + If None, only basic frame information is loaded. + scene_meta (dict, optional): Dictionary containing scene metadata. If None, it will be loaded + from scene_meta.json in the scene_root. + interpolate (bool, optional): Allow interpolating frames? + + Returns: + dict: A dictionary containing the loaded frame data with the requested modalities. + """ + scene_root = Path(scene_root) + if scene_meta is None: + scene_meta = _load_scene_meta(scene_root / "scene_meta.json") + frame = get_frame(scene_meta, frame_key) + # compact, standardized frame representation + wai_frame = {} + if load_extrinsics: + extrinsics = get_extrinsics( + scene_meta, + frame_key, + fmt=fmt, + ) + if extrinsics is not None: + wai_frame["extrinsics"] = extrinsics + if load_intrinsics: + camera_model = frame.get("camera_model", scene_meta.get("camera_model")) + wai_frame["camera_model"] = camera_model + if camera_model == "PINHOLE": + wai_frame["intrinsics"] = get_intrinsics(scene_meta, frame_key, fmt=fmt) + elif camera_model in ["OPENCV", "OPENCV_FISHEYE"]: + # optional per-frame intrinsics + for camera_key in CAMERA_KEYS: + if camera_key in frame: + wai_frame[camera_key] = float(frame[camera_key]) + elif camera_key in scene_meta: + wai_frame[camera_key] = float(scene_meta[camera_key]) + else: + error_message = ( + f"Camera model not supported: {camera_model} - " + f"Please verify scene_meta.json of scene: {scene_meta['dataset_name']}/{scene_meta['scene_name']}" + ) + logger.error(error_message) + raise NotImplementedError(error_message) + wai_frame["w"] = frame.get("w", scene_meta["w"] if "w" in scene_meta else None) + wai_frame["h"] = frame.get("h", scene_meta["h"] if "h" in scene_meta else None) + wai_frame["frame_name"] = frame["frame_name"] + wai_frame["frame_idx"] = get_frame_index(scene_meta, frame_key) + wai_frame["_is_interpolated"] = frame["_is_interpolated"] + + if modalities is not None: + if isinstance(modalities, str): + modalities = [modalities] + for modality in modalities: + # Handle regex patterns in modality + if any(char in modality for char in ".|*+?()[]{}^$\\"): + # This is a regex pattern + pattern = re.compile(modality) + matching_modalities = [ + m for m in scene_meta["frame_modalities"] if pattern.match(m) + ] + if not matching_modalities: + raise ValueError( + f"No modalities match the pattern: {modality} in scene: {scene_root}" + ) + # Use 
the first matching modality + modality = matching_modalities[0] + current_modalities = load_modality( + scene_root, scene_meta["frame_modalities"], modality, frame, fmt=fmt + ) + wai_frame.update(current_modalities) + + return wai_frame + + +def set_frame( + scene_meta: dict[str, Any], + frame_key: int | str, + new_frame: dict[str, Any], + sort: bool = False, +) -> dict[str, Any]: + """ + Replace a frame in scene_meta with a new frame. + + Args: + scene_meta: Dictionary containing scene metadata. + frame_key: Either a string (frame name) or integer (frame index). + new_frame: New frame data to replace the existing frame. + sort: If True, sort the keys in the new_frame dictionary. + + Returns: + Updated scene_meta dictionary. + """ + frame_idx = get_frame_index(scene_meta, frame_key) + if isinstance(frame_idx, float): + raise ValueError( + f"Setting frame for sub-frame frame_key is not supported: {frame_key} ({type(frame_key)})." + ) + if sort: + new_frame = {k: new_frame[k] for k in sorted(new_frame)} + scene_meta["frames"][frame_idx] = new_frame + return scene_meta + + +def nest_modality( + frame_modalities: dict[str, Any], + modality_name: str, +) -> dict[str, Any]: + """ + Converts a flat modality structure into a nested one based on the modality name. + + Args: + frame_modalities (dict): Dictionary containing frame modalities. + modality_name (str): The name of the modality to nest. + + Returns: + dict: A dictionary with the nested modality structure. + """ + frame_modality = {} + if modality_name in frame_modalities: + frame_modality = frame_modalities[modality_name] + if "frame_key" in frame_modality: + # required for backwards compatibility + # converting non-nested format into nested one based on name + modality_name = frame_modality["frame_key"].split("_")[0] + frame_modality = {modality_name: frame_modality} + return frame_modality diff --git a/mapanything/utils/wai/intersection_check.py b/mapanything/utils/wai/intersection_check.py new file mode 100644 index 0000000000000000000000000000000000000000..cb7d1b088feae874d08a20108c369cf3e4902cf4 --- /dev/null +++ b/mapanything/utils/wai/intersection_check.py @@ -0,0 +1,467 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +from einops import rearrange, repeat +from tqdm import tqdm + + +def create_frustum_from_intrinsics( + intrinsics: torch.Tensor, + near: torch.Tensor | float, + far: torch.Tensor | float, +) -> torch.Tensor: + r""" + Create a frustum from camera intrinsics. + + Args: + intrinsics (torch.Tensor): Bx3x3 Intrinsics of cameras. + near (torch.Tensor or float): [B] Near plane distance. + far (torch.Tensor or float): [B] Far plane distance. 
+ + Returns: + frustum (torch.Tensor): Bx8x3 batch of frustum points following the order: + 5 ---------- 4 + |\ /| + 6 \ / 7 + \ 1 ---- 0 / + \| |/ + 2 ---- 3 + """ + + fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1] + cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2] + + # Calculate the offsets at the near plane + near_x = near * (cx / fx) + near_y = near * (cy / fy) + far_x = far * (cx / fx) + far_y = far * (cy / fy) + + # Define frustum vertices in camera space + near_plane = torch.stack( + [ + torch.stack([near_x, near_y, near * torch.ones_like(near_x)], dim=-1), + torch.stack([-near_x, near_y, near * torch.ones_like(near_x)], dim=-1), + torch.stack([-near_x, -near_y, near * torch.ones_like(near_x)], dim=-1), + torch.stack([near_x, -near_y, near * torch.ones_like(near_x)], dim=-1), + ], + dim=1, + ) + + far_plane = torch.stack( + [ + torch.stack([far_x, far_y, far * torch.ones_like(far_x)], dim=-1), + torch.stack([-far_x, far_y, far * torch.ones_like(far_x)], dim=-1), + torch.stack([-far_x, -far_y, far * torch.ones_like(far_x)], dim=-1), + torch.stack([far_x, -far_y, far * torch.ones_like(far_x)], dim=-1), + ], + dim=1, + ) + + return torch.cat([near_plane, far_plane], dim=1) + + +def _frustum_to_triangles(frustum: torch.Tensor) -> torch.Tensor: + """ + Convert frustum to triangles. + + Args: + frustums (torch.Tensor): Bx8 batch of frustum points. + + Returns: + frustum_triangles (torch.Tensor): Bx3x3 batch of frustum triangles. + """ + + triangle_inds = torch.tensor( + [ + [0, 1, 2], + [0, 2, 3], + [0, 3, 7], + [0, 7, 4], + [1, 2, 6], + [1, 6, 5], + [1, 4, 5], + [1, 0, 4], + [2, 6, 7], + [2, 3, 7], + [6, 7, 4], + [6, 5, 4], + ] + ) + frustum_triangles = frustum[:, triangle_inds] + return frustum_triangles + + +def segment_triangle_intersection_check( + start_points: torch.Tensor, + end_points: torch.Tensor, + triangles: torch.Tensor, +) -> torch.Tensor: + """ + Check if segments (lines with starting and end point) intersect triangles in 3D using the + Moller-Trumbore algorithm. + + Args: + start_points (torch.Tensor): Bx3 Starting points of the segment. + end_points (torch.Tensor): Bx3 End points of the segment. + triangles (torch.Tensor): Bx3x3 Vertices of the triangles. + + Returns: + intersects (torch.Tensor): B Boolean tensor indicating if each ray intersects its + corresponding triangle. + """ + vertex0 = triangles[:, 0, :] + vertex1 = triangles[:, 1, :] + vertex2 = triangles[:, 2, :] + edge1 = vertex1 - vertex0 + edge2 = vertex2 - vertex0 + ray_vectors = end_points - start_points + max_lengths = torch.norm(ray_vectors, dim=1) + ray_vectors = ray_vectors / max_lengths[:, None] + h = torch.cross(ray_vectors, edge2, dim=1) + a = (edge1 * h).sum(dim=1) + + epsilon = 1e-6 + mask = torch.abs(a) > epsilon + f = torch.zeros_like(a) + f[mask] = 1.0 / a[mask] + + s = start_points - vertex0 + u = f * (s * h).sum(dim=1) + q = torch.cross(s, edge1, dim=1) + v = f * (ray_vectors * q).sum(dim=1) + + t = f * (edge2 * q).sum(dim=1) + + # Check conditions + intersects = ( + (u >= 0) + & (u <= 1) + & (v >= 0) + & (u + v <= 1) + & (t >= epsilon) + & (t <= max_lengths) + ) + + return intersects + + +def triangle_intersection_check( + triangles1: torch.Tensor, + triangles2: torch.Tensor, +) -> torch.Tensor: + """ + Check if two triangles intersect. + + Args: + triangles1 (torch.Tensor): Bx3x3 Vertices of the first batch of triangles. + triangles2 (torch.Tensor): Bx3x3 Vertices of the first batch of triangles. 
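+
+        The check is symmetric: each edge of triangles1 is tested against the face
+        of triangles2 with the segment/triangle (Moller-Trumbore) test above, and
+        vice versa; a pair counts as intersecting if any edge-face test succeeds.
+        Degenerate coplanar overlaps may not be detected by this formulation.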
+ + Returns: + triangle_intersection (torch.Tensor): B Boolean tensor indicating if triangles intersect. + """ + n = triangles1.shape[1] + start_points1 = rearrange(triangles1, "B N C -> (B N) C") + end_points1 = rearrange( + triangles1[:, torch.arange(1, n + 1) % n], "B N C -> (B N) C" + ) + + start_points2 = rearrange(triangles2, "B N C -> (B N) C") + end_points2 = rearrange( + triangles2[:, torch.arange(1, n + 1) % n], "B N C -> (B N) C" + ) + intersection_1_2 = segment_triangle_intersection_check( + start_points1, end_points1, repeat(triangles2, "B N C -> (B N2) N C", N2=3) + ) + intersection_2_1 = segment_triangle_intersection_check( + start_points2, end_points2, repeat(triangles1, "B N C -> (B N2) N C", N2=3) + ) + triangle_intersection = torch.any( + rearrange(intersection_1_2, "(B N N2) -> B (N N2)", B=triangles1.shape[0], N=n), + dim=1, + ) | torch.any( + rearrange(intersection_2_1, "(B N N2) -> B (N N2)", B=triangles1.shape[0], N=n), + dim=1, + ) + return triangle_intersection + + +def frustum_intersection_check( + frustums: torch.Tensor, + check_inside: bool = True, + chunk_size: int = 500, + device: str | None = None, +) -> torch.Tensor: + """ + Check if any pair of the frustums intersect with each other. + + Args: + frustums (torch.Tensor): Bx8 batch of frustum points. + check_inside (bool): If True, also checks if one frustum is inside another. + Defaults to True. + chunk_size (Optional[int]): Number of chunks to split the computation into. + Defaults to 500. + device (Optional[str]): Device to store exhaustive frustum intersection matrix on. + Defaults to None. + + Returns: + frustum_intersection (torch.Tensor): BxB tensor of Booleans indicating if any pair + of frustums intersect with each other. + """ + B = frustums.shape[0] + if device is None: + device = frustums.device + frustum_triangles = _frustum_to_triangles(frustums) + T = frustum_triangles.shape[1] + + # Perform frustum in frustum check if required + if check_inside: + frustum_intersection = frustums_in_frustum_check( + frustums=frustums, chunk_size=chunk_size, device=device + ) + else: + frustum_intersection = torch.zeros((B, B), dtype=torch.bool, device=device) + + # Check triangle intersections in chunks + for i in tqdm(range(0, B, chunk_size), desc="Checking triangle intersections"): + i_end = min(i + chunk_size, B) + chunk_i_size = i_end - i + + for j in range(0, B, chunk_size): + j_end = min(j + chunk_size, B) + chunk_j_size = j_end - j + + # Process all triangle pairs between the two chunks in a vectorized way + triangles_i = frustum_triangles[i:i_end] # [chunk_i, T, 3, 3] + triangles_j = frustum_triangles[j:j_end] # [chunk_j, T, 3, 3] + + # Reshape to process all triangle pairs at once + tri_i = triangles_i.reshape(chunk_i_size * T, 3, 3) + tri_j = triangles_j.reshape(chunk_j_size * T, 3, 3) + + # Expand for all pairs - explicitly specify dimensions instead of using ... 
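+            # tri_i_exp repeats each triangle of the i-chunk (chunk_i_size * T of
+            # them) chunk_j_size * T times in a row, while tri_j_exp tiles the whole
+            # j-chunk chunk_i_size * T times, so matching rows enumerate every
+            # (i, j) triangle pair exactly once before the pairwise test below.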
+ tri_i_exp = repeat(tri_i, "bt i j -> (bt bj_t) i j", bj_t=chunk_j_size * T) + tri_j_exp = repeat(tri_j, "bt i j -> (bi_t bt) i j", bi_t=chunk_i_size * T) + + # Check intersection + batch_intersect = triangle_intersection_check(tri_i_exp, tri_j_exp) + + # Reshape and check if any triangle pair intersects + batch_intersect = batch_intersect.reshape(chunk_i_size, T, chunk_j_size, T) + batch_intersect = batch_intersect.any(dim=(1, 3)) + + # Update result + frustum_intersection[i:i_end, j:j_end] |= batch_intersect.to(device) + + return frustum_intersection + + +def ray_triangle_intersection_check( + ray_origins: torch.Tensor, + ray_vectors: torch.Tensor, + triangles: torch.Tensor, + max_lengths: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Check if rays intersect triangles in 3D using the Moller-Trumbore algorithm, considering the + finite length of rays. + + Args: + ray_origins (torch.Tensor): Bx3 Origins of the rays. + ray_vectors (torch.Tensor): Bx3 Direction vectors of the rays. + triangles (torch.Tensor): Bx3x3 Vertices of the triangles. + max_lengths Optional[torch.Tensor]: B Maximum lengths of the rays. + + Returns: + intersects (torch.Tensor): B Boolean tensor indicating if each ray intersects its + corresponding triangle. + """ + vertex0 = triangles[:, 0, :] + vertex1 = triangles[:, 1, :] + vertex2 = triangles[:, 2, :] + edge1 = vertex1 - vertex0 + edge2 = vertex2 - vertex0 + h = torch.cross(ray_vectors, edge2, dim=1) + a = (edge1 * h).sum(dim=1) + + epsilon = 1e-6 + mask = torch.abs(a) > epsilon + f = torch.zeros_like(a) + f[mask] = 1.0 / a[mask] + + s = ray_origins - vertex0 + u = f * (s * h).sum(dim=1) + q = torch.cross(s, edge1, dim=1) + v = f * (ray_vectors * q).sum(dim=1) + + t = f * (edge2 * q).sum(dim=1) + + # Check conditions + intersects = (u >= 0) & (u <= 1) & (v >= 0) & (u + v <= 1) & (t >= epsilon) + if max_lengths is not None: + intersects &= t <= max_lengths + + return intersects + + +#### Checks for frustums +def _frustum_to_planes(frustums: torch.Tensor) -> torch.Tensor: + r""" + Converts frustum parameters to plane representation. + + Args: + frustums (torch.Tensor): Bx8 batch of frustum points following the order: + 5 ---------- 4 + |\ /| + 6 \ / 7 + \ 1 ---- 0 / + \| |/ + 2 ---- 3 + + Returns: + planes (torch.Tensor): Bx6x4 where 6 represents the six frustum planes and + 4 represents plane parameters [a, b, c, d]. + """ + planes = [] + for inds in [[0, 1, 3], [1, 6, 2], [0, 3, 7], [2, 6, 3], [0, 5, 1], [6, 5, 4]]: + normal = torch.cross( + frustums[:, inds[1]] - frustums[:, inds[0]], + frustums[:, inds[2]] - frustums[:, inds[0]], + dim=1, + ) + normal = normal / torch.norm(normal, dim=1, keepdim=True) + d = -torch.sum(normal * frustums[:, inds[0]], dim=1, keepdim=True) + planes.append(torch.cat([normal, d], -1)) + return torch.stack(planes, 1) + + +def points_in_frustum_check( + frustums: torch.Tensor, + points: torch.Tensor, + chunk_size: int | None = None, + device: str | None = None, +): + """ + Check if points are inside frustums. + + Args: + frustums (torch.Tensor): Bx8 batch of frustum points. + points (torch.Tensor): BxNx3 batch of points. + chunk_size (Optional[int]): Number of chunks to split the computation into. Defaults to None. + device (Optional[str]): Device to perform computation on. Defaults to None. + + Returns: + inside (torch.Tensor): BxN batch of Booleans indicating if points are inside frustums. 
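To make the pairwise semantics of `frustum_intersection_check` concrete, the sketch below builds three toy frustums by hand (following the documented corner order, near plane first) and computes the BxB overlap matrix. `make_frustum` is a hypothetical helper written only for this example, and the expected pattern is a hedged guess from the geometry rather than a tested result.

```python
import torch

def make_frustum(shift_x: float, near: float = 0.1, far: float = 2.0, half: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: build 8 frustum corners (near plane 0-3, far plane 4-7)."""
    corners = []
    for z, s in [(near, half * near / far), (far, half)]:
        corners += [[s, s, z], [-s, s, z], [-s, -s, z], [s, -s, z]]
    pts = torch.tensor(corners)
    pts[:, 0] += shift_x  # shift along x to control overlap
    return pts

# Frustums 0 and 1 overlap; frustum 2 sits far away on the x-axis.
frustums = torch.stack([make_frustum(0.0), make_frustum(0.5), make_frustum(50.0)])  # (3, 8, 3)
overlap = frustum_intersection_check(frustums, check_inside=False, chunk_size=8)
print(overlap)  # (0, 1) / (1, 0) should be True; entries involving frustum 2 should be False
```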
+ """ + if device is None: + device = frustums.device + + if chunk_size is not None: + # Split computation into chunks to avoid OOM errors for large batch sizes + point_plane_direction = [] + for chunk_idx in range(0, frustums.shape[0], chunk_size): + chunk_frustum_planes = _frustum_to_planes( + frustums[chunk_idx : chunk_idx + chunk_size] + ) + # Bx8x4 tensor of plane parameters [a, b, c, d] + chunk_points = points[chunk_idx : chunk_idx + chunk_size] + chunk_point_plane_direction = torch.einsum( + "bij,bnj->bni", (chunk_frustum_planes[:, :, :-1], chunk_points) + ) + repeat( + chunk_frustum_planes[:, :, -1], "B P -> B N P", N=chunk_points.shape[1] + ) # BxMxN tensor + point_plane_direction.append(chunk_point_plane_direction.to(device)) + point_plane_direction = torch.cat(point_plane_direction) + else: + # Convert frustums to planes + frustum_planes = _frustum_to_planes( + frustums + ) # Bx8x4 tensor of plane parameters [a, b, c, d] + # Compute dot product between each point and each plane + point_plane_direction = torch.einsum( + "bij,bnj->bni", (frustum_planes[:, :, :-1], points) + ) + repeat(frustum_planes[:, :, -1], "B P -> B N P", N=points.shape[1]).to( + device + ) # BxMxN tensor + + inside = (point_plane_direction >= 0).all(-1) + return inside + + +def frustums_in_frustum_check( + frustums: torch.Tensor, + chunk_size: int, + device: str | None = None, + use_double_chunking: bool = True, +): + """ + Check if frustums are contained within other frustums. + + Args: + frustums (torch.Tensor): Bx8 batch of frustum points. + chunk_size (Optional[int]): Number of chunks to split the computation into. + Defaults to None. + device (Optional[str]): Device to store exhaustive frustum containment matrix on. + Defaults to None. + use_double_chunking (bool): If True, use double chunking to avoid OOM errors. + Defaults to True. + + Returns: + frustum_contained (torch.Tensor): BxB batch of Booleans indicating if frustums are inside + other frustums. + """ + B = frustums.shape[0] + if device is None: + device = frustums.device + + if use_double_chunking: + frustum_contained = torch.zeros((B, B), dtype=torch.bool, device=device) + # Check if frustums are containing each other by processing in chunks + for i in tqdm(range(0, B, chunk_size), desc="Checking frustum containment"): + i_end = min(i + chunk_size, B) + chunk_i_size = i_end - i + + for j in range(0, B, chunk_size): + j_end = min(j + chunk_size, B) + chunk_j_size = j_end - j + + # Process a chunk of frustums against another chunk + frustums_i = frustums[i:i_end] + frustums_j_vertices = frustums[ + j:j_end, :1 + ] # Just need one vertex to check containment + + # Perform points in frustum check + contained = rearrange( + points_in_frustum_check( + repeat(frustums_i, "B ... -> (B B2) ...", B2=chunk_j_size), + repeat( + frustums_j_vertices, "B ... -> (B2 B) ...", B2=chunk_i_size + ), + )[:, 0], + "(B B2) -> B B2", + B=chunk_i_size, + ).to(device) + + # Map results back to the full matrix + frustum_contained[i:i_end, j:j_end] |= contained + frustum_contained[j:j_end, i:i_end] |= contained.transpose( + 0, 1 + ) # Symmetric relation + else: + # Perform points in frustum check with a single chunked loop + frustum_contained = rearrange( + points_in_frustum_check( + repeat(frustums, "B ... -> (B B2) ...", B2=B), + repeat(frustums[:, :1], "B ... 
-> (B2 B) ...", B2=B), + chunk_size=chunk_size, + )[:, 0], + "(B B2) -> B B2", + B=B, + ).to(device) + frustum_contained = frustum_contained | frustum_contained.T + + return frustum_contained diff --git a/mapanything/utils/wai/io.py b/mapanything/utils/wai/io.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4e0a8e1640a92e9b89bf50ca65111440d4a5f6 --- /dev/null +++ b/mapanything/utils/wai/io.py @@ -0,0 +1,1378 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +This utils script contains PORTAGE of wai-core io methods for MapAnything. +""" + +import gzip +import io +import json +import logging +import os +from datetime import datetime +from pathlib import Path +from typing import Any, Callable, cast, IO, Literal, overload + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 +import numpy as np +import torch +import trimesh +import yaml +from PIL import Image, PngImagePlugin +from plyfile import PlyData, PlyElement +from safetensors.torch import load_file as load_sft, save_file as save_sft +from torchvision.io import decode_image +from yaml import CLoader + +from mapanything.utils.wai.ops import ( + to_numpy, +) +from mapanything.utils.wai.semantics import ( + apply_id_to_color_mapping, + INVALID_ID, + load_semantic_color_mapping, +) + +# Try to use orjson for faster JSON processing +try: + import orjson +except ImportError: + orjson = None + +logger = logging.getLogger(__name__) + + +@overload +def _load_readable( + fname: Path | str, load_as_string: Literal[True], **kwargs +) -> str: ... +@overload +def _load_readable( + fname: Path | str, load_as_string: Literal[False] = False, **kwargs +) -> dict: ... + + +def _load_readable( + fname: Path | str, + load_as_string: bool = False, + **kwargs, +) -> Any | str: + """ + Loads data from a human-readable file and will try to parse JSON or YAML files as a dict, list, + int, float, str, bool, or None object. Can optionally return the file contents as a string. + + Args: + fname (str or Path): The filename to load data from. + load_as_string (bool, optional): Whether to return the loaded data as a string. + Defaults to False. + + Returns: + The loaded data, which can be any type of object that can be represented in JSON or YAML. + + Raises: + NotImplementedError: If the file suffix is not supported (i.e., not .json, .yaml, or .yml). + """ + if load_as_string: + return _load_readable_string(fname, **kwargs) + else: + return _load_readable_structured(fname, **kwargs) + + +def _load_readable_structured( + fname: Path | str, + **kwargs, +) -> Any: + """ + Loads data from a human-readable file and will try to parse JSON or YAML files as a dict, list, + int, float, str, bool, or None object. + + Args: + fname (str or Path): The filename to load data from. + + Returns: + The loaded data, which can be any type of object that can be represented in JSON or YAML. + + Raises: + NotImplementedError: If the file suffix is not supported (i.e., not .json, .yaml, or .yml). 
+ """ + fname = Path(fname) + if not fname.exists(): + raise FileNotFoundError(f"File does not exist: {fname}") + + if fname.suffix == ".json": + # Use binary mode for JSON files + with open(fname, mode="rb") as f: + # Use orjson if available, otherwise use standard JSON + if orjson: + return orjson.loads(f.read()) + return json.load(f) + + if fname.suffix in [".yaml", ".yml"]: + # Use text mode with UTF-8 encoding for YAML files + with open(fname, mode="r", encoding="utf-8") as f: + return yaml.load(f, Loader=CLoader) + + raise NotImplementedError(f"Readable format not supported: {fname.suffix}") + + +def _load_readable_string( + fname: Path | str, + **kwargs, +) -> str: + """ + Loads data from a human-readable file as a string. + + Args: + fname (str or Path): The filename to load data from. + + Returns: + The file's contents, as a string. + """ + fname = Path(fname) + if not fname.exists(): + raise FileNotFoundError(f"File does not exist: {fname}") + + with open(fname, mode="r", encoding="utf-8") as f: + contents = f.read() + + return contents + + +def _store_readable( + fname: Path | str, + data: Any, + **kwargs, +) -> int: + """ + Stores data in a human-readable file (JSON or YAML). + + Args: + fname (str or Path): The filename to store data in. + data: The data to store, which can be any type of object that can be represented in JSON or YAML. + + Returns: + The number of bytes written to the file. + + Raises: + NotImplementedError: If the file suffix is not supported (i.e., not .json, .yaml, or .yml). + """ + fname = Path(fname) + + # Create parent directory if it doesn't exist + os.makedirs(fname.parent, exist_ok=True) + + if fname.suffix == ".json": + if orjson: + # Define the operation for orjson + with open(fname, mode="wb") as f: + return f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2)) + else: + # Define the operation for standard json + with open(fname, mode="w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + return f.tell() + + elif fname.suffix in [".yaml", ".yml"]: + # Define the operation for YAML files + with open(fname, mode="w", encoding="utf-8") as f: + yaml.dump(data, f) + return f.tell() + else: + raise NotImplementedError(f"Writable format not supported: {fname.suffix}") + + +def get_processing_state(scene_root: Path | str) -> dict: + """ + Retrieves the processing state of a scene. + + Args: + scene_root (Path or str): The root directory of the scene. + + Returns: + dict: A dictionary containing the processing state of the scene. + If no processing log exists, or reading it fails, an empty + dictionary is returned. + """ + process_log_path = Path(scene_root) / "_process_log.json" + + try: + return _load_readable_structured(process_log_path) + except FileNotFoundError: + logger.debug(f"Log file not found, returning empty dict: {process_log_path}") + return {} + except Exception: + logger.error( + f"Could not parse, returning empty dict: {process_log_path}", exc_info=True + ) + return {} + + +def _write_exr( + fname: str | Path, + data: np.ndarray | torch.Tensor, + params: list | None = None, + **kwargs, +) -> bool: + """ + Writes an image as an EXR file using OpenCV. + + Args: + fname (str or Path): The filename to save the image to. + data (numpy.ndarray, torch.Tensor): The image data to save. Must be a 2D or 3D array. + params (list, optional): A list of parameters to pass to OpenCV's imwrite function. + Defaults to None, which uses 32-bit with zip compression. + + Returns: + bool: True if the image was saved successfully, False otherwise. 
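The readable helpers round-trip plain dict/list data through JSON or YAML purely based on the file suffix. A minimal usage sketch, with placeholder paths and assuming the helpers are imported from `mapanything.utils.wai.io`:

```python
meta = {"scene": "demo", "frames": [{"frame_name": "000000"}, {"frame_name": "000001"}]}

_store_readable("/tmp/example_meta.json", meta)              # indented JSON (orjson if available)
loaded = _load_readable("/tmp/example_meta.json")            # parsed back into a dict
raw = _load_readable("/tmp/example_meta.json", load_as_string=True)  # same file, as a str
assert loaded == meta and isinstance(raw, str)

_store_readable("/tmp/example_meta.yaml", meta)              # same API, YAML picked by suffix
assert _load_readable("/tmp/example_meta.yaml") == meta
```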
+ + Raises: + ValueError: If the input data has less than two or more than three dimensions. + + Notes: + Only 32-bit float (CV_32F) images can be saved. + For comparison of different compression methods, see P1732924327. + """ + if Path(fname).suffix != ".exr": + raise ValueError( + f"Only filenames with suffix .exr allowed but received: {fname}" + ) + + ## Note: only 32-bit float (CV_32F) images can be saved + data_np = to_numpy(data, dtype=np.float32) + if (data_np.ndim > 3) or (data_np.ndim < 2): + raise ValueError( + f"Image needs to contain two or three dims but received: {data_np.shape}" + ) + + return cv2.imwrite(str(fname), data_np, params if params else []) + + +@overload +def _read_exr(fname: str | Path, fmt: Literal["np"], **kwargs) -> np.ndarray: ... +@overload +def _read_exr(fname: str | Path, fmt: Literal["PIL"], **kwargs) -> Image.Image: ... +@overload +def _read_exr( + fname: str | Path, fmt: Literal["torch"] = "torch", **kwargs +) -> torch.Tensor: ... + + +def _read_exr( + fname: str | Path, fmt: Literal["np", "PIL", "torch"] = "torch", **kwargs +) -> np.ndarray | torch.Tensor | Image.Image: + """ + Reads an EXR image file using OpenCV. + + Args: + fname (str or Path): The filename of the EXR image to read. + fmt (str): The format of the output data. Can be one of: + - "torch": Returns a PyTorch tensor. + - "np": Returns a NumPy array. + - "PIL": Returns a PIL Image object. + Defaults to "torch". + + Returns: + The EXR image data in the specified output format. + + Raises: + NotImplementedError: If the specified output format is not supported. + ValueError: If data shape is not supported, e.g. multi-channel PIL float images. + + Notes: + The EXR image is read in its original format, without any conversion or rescaling. + """ + data = cv2.imread(str(fname), cv2.IMREAD_UNCHANGED) + if data is None: + raise FileNotFoundError(f"Failed to read EXR file: {fname}") + if fmt == "torch": + # Convert to PyTorch tensor with float32 dtype + data = torch.from_numpy(data).float() + elif fmt == "np": + # Convert to NumPy array with float32 dtype + data = np.array(data, dtype=np.float32) + elif fmt == "PIL": + if data.ndim != 2: + raise ValueError("PIL does not support multi-channel EXR images") + + # Convert to PIL Image object + data = Image.fromarray(data) + else: + raise NotImplementedError(f"fmt not supported: {fmt}") + return data + + +@overload +def _load_image( + fname: str | Path, + fmt: Literal["np"], + resize: tuple[int, int] | None = None, + **kwargs, +) -> np.ndarray: ... +@overload +def _load_image( + fname: str | Path, + fmt: Literal["pil"], + resize: tuple[int, int] | None = None, + **kwargs, +) -> Image.Image: ... +@overload +def _load_image( + fname: str | Path, + fmt: Literal["torch"] = "torch", + resize: tuple[int, int] | None = None, + **kwargs, +) -> torch.Tensor: ... + + +def _load_image( + fname: str | Path, + fmt: Literal["np", "pil", "torch"] = "torch", + resize: tuple[int, int] | None = None, + **kwargs, +) -> np.ndarray | torch.Tensor | Image.Image: + """ + Loads an image from a file. + + Args: + fname (str or Path): The filename to load the image from. + fmt (str): The format of the output data. Can be one of: + - "torch": Returns a PyTorch tensor with shape (C, H, W). + - "np": Returns a NumPy array with shape (H, W, C). + - "pil": Returns a PIL Image object. + Defaults to "torch". + resize (tuple, optional): A tuple of two integers representing the desired width and height of the image. + If None, the image is not resized. Defaults to None. 
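With the default 32-bit settings, `_write_exr` / `_read_exr` give a lossless float round trip, provided the local OpenCV build has OpenEXR enabled (the module sets `OPENCV_IO_ENABLE_OPENEXR` for exactly that reason). A hedged sketch with a placeholder path:

```python
import numpy as np

scalar_map = np.random.rand(480, 640).astype(np.float32)
assert _write_exr("/tmp/example.exr", scalar_map)            # default: 32-bit float, zip

as_torch = _read_exr("/tmp/example.exr")                     # torch.FloatTensor, HxW
as_np = _read_exr("/tmp/example.exr", fmt="np")              # np.float32 array, HxW
assert tuple(as_torch.shape) == (480, 640)
np.testing.assert_allclose(as_np, scalar_map)
```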
+ + Returns: + The loaded image in the specified output format. + + Raises: + NotImplementedError: If the specified output format is not supported. + + Notes: + This function loads non-binary images in RGB mode and normalizes pixel values to the range [0, 1]. + """ + + # Fastest way to load into torch tensor + if resize is None and fmt == "torch": + return decode_image(str(fname)).float() / 255.0 + + # Load using PIL + with open(fname, "rb") as f: + pil_image = Image.open(f) + pil_image.load() + + if pil_image.mode not in ["RGB", "RGBA"]: + raise OSError( + f"Expected a RGB or RGBA image in {fname}, but instead found an image with mode {pil_image.mode}" + ) + + if resize is not None: + pil_image = pil_image.resize(resize) + + if fmt == "torch": + return ( + torch.from_numpy(np.array(pil_image)).permute(2, 0, 1).float() / 255.0 + ) + elif fmt == "np": + return np.array(pil_image, dtype=np.float32) / 255.0 + elif fmt == "pil": + return pil_image + else: + raise NotImplementedError(f"Image format not supported: {fmt}") + + +def _store_image( + fname: str | Path, img_data: np.ndarray | torch.Tensor | Image.Image, **kwargs +) -> None: + """ + Stores an image in a file. + + Args: + fname (str or Path): The filename to store the image in. + img_data (numpy.ndarray, torch.tensor or PIL.Image.Image): The image data to store. + + Notes (for numpy.ndarray or torch.tensor inputs): + This function assumes that the input image data is in the range [0, 1], and has shape + (H, W, C), or (C, H, W) for PyTorch tensors, with C being 3 or 4. + It converts the image data to uint8 format and saves it as a compressed image file. + """ + if isinstance(img_data, torch.Tensor): + if img_data.ndim != 3: + raise ValueError(f"Tensor needs to be 3D but received: {img_data.shape=}") + + if img_data.shape[0] in [3, 4]: + # Convert to HWC format expected by pillow `Image.save` below + img_data = img_data.permute(1, 2, 0) + + img_data = img_data.contiguous() + + if isinstance(img_data, (np.ndarray, torch.Tensor)): + if img_data.shape[-1] not in [3, 4]: + raise ValueError( + f"Image must have 3 or 4 channels, but received: {img_data.shape=}" + ) + + img_data_np = to_numpy(img_data, dtype=np.float32) + img_data = Image.fromarray((255 * img_data_np).round().astype(np.uint8)) + + with open(fname, "wb") as f: + pil_kwargs = { + # Make PNGs faster to save using minimal compression + "optimize": False, + "compress_level": 1, + # Higher JPEG image quality + "quality": "high", + } + pil_kwargs.update(kwargs) + img_data.save(cast(IO[bytes], f), **pil_kwargs) + + +def _load_binary_mask( + fname: str | Path, + fmt: str = "torch", + resize: tuple[int, int] | None = None, + **kwargs, +) -> np.ndarray | torch.Tensor | Image.Image: + """ + Loads a binary image from a file. + + Args: + fname (str or Path): The filename to load the binary image from. + fmt (str): The format of the output data. Can be one of: + - "torch": Returns a PyTorch Boolean tensor with shape H x W. + - "np": Returns a NumPy Boolean array with shape H x W. + - "pil": Returns a PIL Image object. + Defaults to "torch". + resize (tuple, optional): A tuple of two integers representing the desired width and height of the binary image. + If None, the image is not resized. Defaults to None. + + Returns: + The loaded binary image in the specified output format. + + Raises: + NotImplementedError: If the specified output format is not supported. 
+ """ + if fmt not in ["pil", "np", "torch"]: + raise NotImplementedError(f"Image format not supported: {fmt}") + + with open(fname, "rb") as f: + pil_image = Image.open(f) + pil_image.load() + + if pil_image.mode == "L": + pil_image = pil_image.convert("1") + + elif pil_image.mode != "1": + raise OSError( + f"Expected a binary or grayscale image in {fname}, but instead found an image with mode {pil_image.mode}" + ) + + if resize is not None: + pil_image = pil_image.resize(resize) + + if fmt == "pil": + return pil_image + + mask = np.array(pil_image, copy=True) + return mask if fmt == "np" else torch.from_numpy(mask) + + +def _store_binary_mask( + fname: str | Path, img_data: np.ndarray | torch.Tensor | Image.Image, **kwargs +) -> None: + """ + Stores a binary image in a compressed image file. + + Args: + fname (str or Path): The filename to store the binary image in. + img_data (numpy.ndarray, torch.tensor or PIL.Image.Image): The binary image data to store. + """ + if isinstance(img_data, Image.Image): + if img_data.mode not in ["1", "L"]: + raise RuntimeError( + f'Expected a PIL image with mode "1" or "L", but instead got a PIL image with mode {img_data.mode}' + ) + elif isinstance(img_data, np.ndarray) or isinstance(img_data, torch.Tensor): + if len(img_data.squeeze().shape) != 2: + raise RuntimeError( + f"Expected a PyTorch tensor or NumPy array with shape (H, W, 1), (1, H, W) or (H, W), but the shape is {img_data.shape}" + ) + img_data = img_data.squeeze() + else: + raise NotImplementedError(f"Input format not supported: {type(img_data)}") + + if not isinstance(img_data, Image.Image): + img_data = to_numpy(img_data, dtype=bool) + img_data = Image.fromarray(img_data) + + img_data = img_data.convert("1") + with open(fname, "wb") as f: + img_data.save(f, compress_level=1, optimize=False) + + +def _load_sft( + fname: str | Path, + fmt: str = "torch", + **kwargs, +) -> torch.Tensor: + """ + Loads a tensor from a safetensor file. + + Args: + fname (str | Path): The filename of the safetensor file to load. + fmt (str, optional): The format of the output data. Currently only "torch" is supported. + **kwargs: Additional keyword arguments (unused). + + Returns: + torch.Tensor: The loaded tensor. + + Raises: + AssertionError: If the file extension is not .sft or if fmt is not "torch". + """ + assert Path(fname).suffix == ".sft", "Only .sft (safetensor) is supported" + assert fmt == "torch", "Only torch format is supported for latent" + out = load_sft(str(fname)) + return out["latent"] + + +def _store_sft(fname: str | Path, data: torch.Tensor, **kwargs) -> None: + """ + Stores a tensor to a safetensor file. + + Args: + fname (str | Path): The filename to store the latent in. + data (torch.Tensor): The latent tensor to store. + **kwargs: Additional keyword arguments (unused). + + Raises: + AssertionError: If the file extension is not .sft or if data is not a torch.Tensor. + """ + assert Path(fname).suffix == ".sft", "Only .sft (safetensor) is supported" + assert isinstance(data, torch.Tensor) + save_sft(tensors={"latent": data}, filename=str(fname)) + + +def _store_depth(fname: str | Path, data: np.ndarray | torch.Tensor, **kwargs) -> bool: + """ + Stores a depth map in an EXR file. + + Args: + fname (str or Path): The filename to save the depth map to. + data (numpy.ndarray, torch.tensor): The depth map to save. + + Returns: + bool: True if the depth map was saved successfully, False otherwise. 
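Latents are stored under a fixed `"latent"` key inside `.sft` (safetensors) files, so the pair below round-trips exactly. The path is a placeholder and the tensor is arbitrary:

```python
import torch

latent = torch.randn(4, 32, 32)
_store_sft("/tmp/frame_000000.sft", latent)
assert torch.equal(_load_sft("/tmp/frame_000000.sft"), latent)
```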
+ + Raises: + ValueError: If the input data does not have two dimensions after removing singleton dimensions. + """ + data_np = to_numpy(data, dtype=np.float32) + data_np = data_np.squeeze() # remove all 1-dim entries + if data_np.ndim != 2: + raise ValueError(f"Depth image needs to be 2d, but received: {data_np.shape}") + + if "params" in kwargs: + params = kwargs["params"] + else: + # use 16-bit with zip compression for depth maps + params = [ + cv2.IMWRITE_EXR_TYPE, + cv2.IMWRITE_EXR_TYPE_HALF, + cv2.IMWRITE_EXR_COMPRESSION, + cv2.IMWRITE_EXR_COMPRESSION_ZIP, + ] + + return _write_exr(fname, data_np, params=params) + + +def _load_depth( + fname: str | Path, fmt: str = "torch", **kwargs +) -> np.ndarray | torch.Tensor | Image.Image: + """ + Loads a depth image from an EXR file. + + Args: + fname (str or Path): The filename of the EXR file to load. + fmt (str): The format of the output data. Can be one of: + - "torch": Returns a PyTorch tensor. + - "np": Returns a NumPy array. + - "PIL": Returns a PIL Image object. + Defaults to "torch". + + Returns: + The loaded depth image in the specified output format. + + Raises: + ValueError: If the loaded depth image does not have two dimensions. + + Notes: + This function assumes that the EXR file contains a single-channel depth image. + """ + data = _read_exr(fname, fmt) + if (fmt != "PIL") and (data.ndim != 2): + raise ValueError(f"Depth image needs to be 2D, but loaded: {data.shape}") + return data + + +def _store_normals( + fname: str | Path, data: np.ndarray | torch.Tensor, **kwargs +) -> bool: + """ + Stores a normals image in an EXR file. + + Args: + fname (str or Path): The filename to save the normals image to. + data (numpy.ndarray): The normals image data to save. Will be converted to a 32-bit float array. + + Returns: + bool: True if the normals image was saved successfully, False otherwise. + + Raises: + ValueError: If the input data has more than three dimensions after removing singleton dimensions. + ValueError: If the input data does not have exactly three channels. + ValueError: If the input data is not normalized (i.e., maximum absolute value exceeds 1). + + Notes: + This function assumes that the input data is in HWC (height, width, channels) format. + If the input data is in CHW (channels, height, width) format, it will be automatically transposed to HWC. + """ + data_np = to_numpy(data, dtype=np.float32) + data_np = data_np.squeeze() # remove all singleton dimensions + + if data_np.ndim != 3: + raise ValueError( + f"Normals image needs to be 3-dim but received: {data_np.shape}" + ) + + if (data_np.shape[0] == 3) and (data_np.shape[2] != 3): + # ensure HWC format + data_np = data_np.transpose(1, 2, 0) + + if data_np.shape[2] != 3: + raise ValueError( + f"Normals image needs have 3 channels but received: {data_np.shape}" + ) + + # We want to check that the norm values are either 1 (valid) or 0 (invalid values are 0s) + norm = np.linalg.norm(data_np, axis=-1) + is_one = np.isclose(norm, 1.0, atol=1e-3) + is_zero = np.isclose(norm, 0.0) + if not np.all([is_one | is_zero]): + raise ValueError("Normals image must be normalized") + + return _write_exr(fname, data_np) + + +def _load_normals( + fname: str | Path, fmt: str = "torch", **kwargs +) -> np.ndarray | torch.Tensor | Image.Image: + """ + Loads a normals image from an EXR file. + + Args: + fname (str or Path): The filename of the EXR file to load. + fmt (str): The format of the output data. Can be one of: + - "torch": Returns a PyTorch tensor. + - "np": Returns a NumPy array. 
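Because `_store_depth` defaults to half-float EXR with zip compression, a round trip through `_load_depth` is compact but slightly lossy, so a small tolerance is expected. Hedged sketch with a placeholder path:

```python
import torch

depth = torch.rand(240, 320) * 10.0                  # metric depth, HxW
assert _store_depth("/tmp/depth_000000.exr", depth)  # 16-bit half, zip compression

reloaded = _load_depth("/tmp/depth_000000.exr")      # torch.FloatTensor, HxW
assert reloaded.shape == (240, 320)
assert torch.allclose(reloaded, depth, atol=1e-2)    # allow for half-precision quantization
```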
+ - "PIL": Returns a PIL Image object. + Defaults to "torch". + + Returns: + The loaded normals image in the specified output format. + + Raises: + Warning: If the loaded normals image has more than two dimensions. + + Notes: + This function assumes that the EXR file contains a 3-channel normals image. + """ + data = _read_exr(fname, fmt) + + if data.ndim != 3: + raise ValueError(f"Normals image needs to be 3-dim but received: {data.shape}") + + if data.shape[2] != 3: + raise ValueError( + f"Normals image needs have 3 channels but received: {data.shape}" + ) + + return data + + +def _load_numpy(fname: str | Path, allow_pickle: bool = False, **kwargs) -> np.ndarray: + """ + Loads a NumPy array from a file. + + Args: + fname (str or Path): The filename to load the NumPy array from. + allow_pickle (bool, optional): Whether to allow pickled objects in the NumPy file. + Defaults to False. + + Returns: + numpy.ndarray: The loaded NumPy array. + + Raises: + NotImplementedError: If the file suffix is not supported (i.e., not .npy or .npz). + + Notes: + This function supports loading NumPy arrays from .npy and .npz files. + For .npz files, it assumes that the array is stored under the key "arr_0". + """ + fname = Path(fname) + with open(fname, "rb") as fid: + if fname.suffix == ".npy": + return np.load(fid, allow_pickle=allow_pickle) + elif fname.suffix == ".npz": + return np.load(fid, allow_pickle=allow_pickle).get("arr_0") + else: + raise NotImplementedError(f"Numpy format not supported: {fname.suffix}") + + +def _store_numpy(fname: str | Path, data: np.ndarray, **kwargs) -> None: + """ + Stores a NumPy array in a file. + + Args: + fname (str or Path): The filename to store the NumPy array in. + data (numpy.ndarray): The NumPy array to store. + + Raises: + NotImplementedError: If the file suffix is not supported (i.e., not .npy or .npz). + + Notes: + This function supports storing NumPy arrays in .npy and .npz files. + For .npz files, it uses compression to reduce the file size. + """ + fname = Path(fname) + with open(fname, "wb") as fid: + if fname.suffix == ".npy": + np.save(fid, data) + elif fname.suffix == ".npz": + np.savez_compressed(fid, arr_0=data) + else: + raise NotImplementedError(f"Numpy format not supported: {fname.suffix}") + + +def _load_ptz(fname: str | Path, **kwargs) -> torch.Tensor: + """ + Loads a PyTorch tensor from a PTZ file. + + Args: + fname (str or Path): The filename to load the tensor from. + + Returns: + torch.Tensor: The loaded PyTorch tensor. + + Notes: + This function assumes that the PTZ file contains a PyTorch tensor saved using `torch.save`. + If the tensor was saved in a different format, this function may fail. + """ + with open(fname, "rb") as fid: + data = gzip.decompress(fid.read()) + ## Note: if the following line fails, save PyTorch tensors in PTZ instead of NumPy + return torch.load(io.BytesIO(data), map_location="cpu", weights_only=True) + + +def _store_ptz(fname: str | Path, data: torch.Tensor, **kwargs) -> None: + """ + Stores a PyTorch tensor in a PTZ file. + + Args: + fname (str or Path): The filename to store the tensor in. + data (torch.Tensor): The PyTorch tensor to store. + + Notes: + This function saves the tensor using `torch.save` and compresses it using gzip. + """ + with open(fname, "wb") as fid: + with gzip.open(fid, "wb") as gfid: + torch.save(data, gfid) + + +def _store_mmap(fname: str | Path, data: np.ndarray | torch.Tensor, **kwargs) -> str: + """ + Stores matrix-shaped data in a memory-mapped file. 
+ + Args: + fname (str or Path): The filename to store the data in. + data (numpy.ndarray): The matrix-shaped data to store. + + Returns: + str: The name of the stored memory-mapped file. + + Notes: + This function stores the data in a .npy file with a modified filename that includes the shape of the data. + The data is converted to float32 format before storing. + """ + fname = Path(fname) + # add dimensions to the file name for loading + data_np = to_numpy(data, dtype=np.float32) + shape_string = "x".join([str(dim) for dim in data_np.shape]) + mmap_name = f"{fname.stem}--{shape_string}.npy" + with open(fname.parent / mmap_name, "wb") as fid: + np.save(fid, data_np) + return mmap_name + + +def _load_mmap(fname: str | Path, **kwargs) -> np.memmap: + """ + Loads matrix-shaped data from a memory-mapped file. + + Args: + fname (str or Path): The filename of the memory-mapped file to load. + + Returns: + numpy.memmap: A memory-mapped array containing the loaded data. + + Notes: + This function assumes that the filename contains the shape of the data, separated by 'x' or ','. + It uses this information to create a memory-mapped array with the correct shape. + """ + shape_string = Path(Path(fname).name.split("--")[1]).stem + shape = [int(dim) for dim in shape_string.replace(",", "x").split("x")] + with open(fname, "rb") as fid: + return np.memmap(fid, dtype=np.float32, mode="r", shape=shape, offset=128) + + +def _store_scene_meta(fname: Path | str, scene_meta: dict[str, Any], **kwargs) -> None: + """ + Stores scene metadata in a readable file. + + Args: + fname (str or Path): The filename to store the scene metadata in. + scene_meta (dict): The scene metadata to store. + + Notes: + This function updates the "last_modified" field of the scene metadata to the current date and time before storing it. + It also removes the "frame_names" field from the scene metadata, as it is not necessary to store this information. + Creates a backup of the existing file before overwriting it. + """ + # update the modified date + scene_meta["last_modified"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + if "frame_names" in scene_meta: + del scene_meta["frame_names"] + + # create/overwrite backup + fname_path = Path(fname) + if fname_path.exists(): + backup_fname = fname_path.parent / f"_{fname_path.stem}_backup.json" + if backup_fname.exists(): + backup_fname.unlink() + fname_path.rename(backup_fname) + + _store_readable(fname, scene_meta) + + +def _load_scene_meta(fname: Path | str, **kwargs) -> dict[str, Any]: + """ + Loads scene metadata from a readable file. + + Args: + fname (str or Path): The filename to load the scene metadata from. + + Returns: + dict: The loaded scene metadata, including an additional "frame_names" field that maps frame names to their indices. + + Notes: + This function creates the "frame_names" field in the scene metadata for efficient lookup of frame indices by name. + """ + scene_meta = _load_readable_structured(fname) + # create the frame_name -> frame_idx for efficiency + scene_meta["frame_names"] = { + frame["frame_name"]: frame_idx + for frame_idx, frame in enumerate(scene_meta["frames"]) + } + return scene_meta + + +def _load_labeled_image( + fname: str | Path, + fmt: str = "torch", + resize: tuple[int, int] | None = None, + **kwargs, +) -> np.ndarray | torch.Tensor | Image.Image: + """ + Loads a labeled image from a PNG file. + + Args: + fname (str or Path): The filename to load the image from. + fmt (str): The format of the output data. 
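The mmap pair encodes the array shape into the file name (`<stem>--<d0>x<d1>.npy`), and `_load_mmap` parses that name back, so the string returned by `_store_mmap` is the one to load. Illustrative sketch with placeholder paths:

```python
import numpy as np

feats = np.random.rand(1000, 256).astype(np.float32)
mmap_name = _store_mmap("/tmp/feats.npy", feats)     # e.g. "feats--1000x256.npy"

view = _load_mmap(f"/tmp/{mmap_name}")               # read-only np.memmap of shape (1000, 256)
assert view.shape == (1000, 256)
np.testing.assert_allclose(view[0], feats[0])
```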
Can be one of: + - "torch": Returns a PyTorch int32 tensor with shape (H, W). + - "np": Returns a NumPy int32 array with shape (H, W). + - "pil": Returns a PIL Image object. + Defaults to "torch". + resize (tuple, optional): A tuple of two integers representing the desired width and height of the image. + If None, the image is not resized. Defaults to None. + + Returns: + The loaded image in the specified output format. + + Raises: + NotImplementedError: If the specified output format is not supported. + RuntimeError: If the 'id_to_color_mapping' is missing in the PNG metadata. + + Notes: + The function expects the PNG file to contain metadata with a key 'id_to_color_mapping', + which maps from label ids to tuples of RGB values. + """ + with open(fname, "rb") as f: + pil_image = Image.open(f) + pil_image.load() + if pil_image.mode != "RGB": + raise OSError( + f"Expected a RGB image in {fname}, but instead found an image with mode {pil_image.mode}" + ) + + # Load id to RGB mapping + color_palette_json = pil_image.info.get("id_to_color_mapping", None) + if color_palette_json is None: + raise RuntimeError("'id_to_color_mapping' is missing in the PNG metadata.") + color_palette = json.loads(color_palette_json) + color_to_id_mapping = { + tuple(color): int(id) for id, color in color_palette.items() + } + + if resize is not None: + pil_image = pil_image.resize(resize, Image.NEAREST) + + if fmt == "pil": + return pil_image + + # Reverse the color mapping: map from RGB colors to ids + img_data = np.array(pil_image) + + # Create a lookup table for fast mapping + max_color_value = 256 # Assuming 8-bit per channel + lookup_table = np.full( + (max_color_value, max_color_value, max_color_value), + INVALID_ID, + dtype=np.int32, + ) + for color, index in color_to_id_mapping.items(): + lookup_table[color] = index + # Map colors to ids using the lookup table + img_data = lookup_table[img_data[..., 0], img_data[..., 1], img_data[..., 2]] + + if fmt == "np": + return img_data + elif fmt == "torch": + return torch.from_numpy(img_data) + else: + raise NotImplementedError(f"Image format not supported: {fmt}") + + +def _store_labeled_image( + fname: str | Path, + img_data: np.ndarray | torch.Tensor | Image.Image, + semantic_color_mapping: np.ndarray | None = None, + **kwargs, +) -> None: + """ + Stores a labeled image as a uint8 RGB PNG file. + + Args: + fname (str or Path): The filename to store the image in. + img_data (numpy.ndarray, torch.Tensor or PIL.Image.Image): The per-pixel label ids to store. + semantic_color_mapping (np.ndarray): Optional, preloaded NumPy array of semantic colors. + + Raises: + ValueError: If the file suffix is not supported (i.e., not .png). + RuntimeError: If the type of the image data is different from uint16, int16 or int32. + + Notes: + The function takes an image with per-pixel label ids and converts it into an RGB image + using a specified mapping from label ids to RGB colors. The resulting image is saved as + a PNG file, with the mapping stored as metadata. + """ + if Path(fname).suffix != ".png": + raise ValueError( + f"Only filenames with suffix .png allowed but received: {fname}" + ) + + if isinstance(img_data, Image.Image) and img_data.mode != "I;16": + raise RuntimeError( + f"The provided image does not seem to be a labeled image. The provided PIL image has mode {img_data.mode}." 
+ ) + + if isinstance(img_data, np.ndarray) and img_data.dtype not in [ + np.uint16, + np.int16, + np.int32, + ]: + raise RuntimeError( + f"The provided NumPy array has type {img_data.dtype} but the expected type is np.uint16, np.int16 or np.int32." + ) + + if isinstance(img_data, torch.Tensor): + if img_data.dtype not in [torch.uint16, torch.int16, torch.int32]: + raise RuntimeError( + f"The provided PyTorch tensor has type {img_data.dtype} but the expected type is torch.uint16, torch.int16 or torch.int32." + ) + img_data = img_data.numpy() + + if semantic_color_mapping is None: + # Mapping from ids to colors not provided, load it now + semantic_color_mapping = load_semantic_color_mapping() + + img_data, color_palette = apply_id_to_color_mapping( + img_data, semantic_color_mapping + ) + pil_image = Image.fromarray(img_data, "RGB") + + # Create a PngInfo object to store metadata + meta = PngImagePlugin.PngInfo() + meta.add_text("id_to_color_mapping", json.dumps(color_palette)) + + pil_image.save(fname, pnginfo=meta) + + +def _load_generic_mesh(mesh_path: str | Path, **kwargs) -> trimesh.Trimesh: + """Load mesh with the trimesh library. + + Args: + mesh_path (str): Path to the mesh file + + Returns: + The trimesh object from trimesh.load(). + + Raises: + ValueError: If the file format is not supported. + """ + + # needed to load big texture files + Image.MAX_IMAGE_PIXELS = None + + # load mesh with trimesh + mesh_data = trimesh.load(mesh_path, process=False) + + return mesh_data + + +def _store_generic_mesh( + file_path: str | Path, mesh_data: dict | trimesh.Trimesh, **kwargs +) -> None: + """ + Dummy function for storing generic mesh data. + + Args: + file_path (str): The filename to store the mesh in. + mesh_data (dict): Dictionary containing mesh data. + **kwargs: Additional keyword arguments. + + Raises: + NotImplementedError: This function is not implemented yet. + """ + raise NotImplementedError("Storing generic meshes is not implemented yet.") + + +def _load_labeled_mesh( + file_path: str | Path, + fmt: str = "torch", + palette: str = "rgb", + **kwargs, +) -> dict | trimesh.Trimesh: + """ + Loads a mesh from a labeled mesh file (PLY binary format). + + Args: + file_path (str): The path to the labeled mesh file (.ply). + fmt (str): Output format of the mesh data. Can be one of: + - "torch": Returns a dict of PyTorch tensors containing mesh data. + - "np": Returns a dict of NumPy arrays containing mesh data. + - "trimesh": Returns a trimesh mesh object. + Defaults to "torch". + palette (str): Output color of the trimesh mesh data. Can be one of: + - "rgb": Colors the mesh with original rgb colors + - "semantic_class": Colors the mesh with semantic class colors + - "instance": Colors the mesh with semantic instance colors + Applied only when fmt is "trimesh". + + Returns: + The loaded mesh in the specified output format. + + Raises: + NotImplementedError: If the specified output format is not supported. + + Notes: + This function reads a binary PLY file with vertex position, color, and optional + semantic class and instance IDs. The faces are stored as lists of vertex indices. 
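Labeled images are written as ordinary RGB PNGs with the id-to-color palette tucked into the PNG metadata, which is what lets `_load_labeled_image` recover integer ids without any side file. The sketch below assumes the packaged semantic color mapping (via `load_semantic_color_mapping`) is available and assigns distinct colors to the ids used; the path and ids are placeholders.

```python
import numpy as np

labels = np.zeros((64, 64), dtype=np.int32)
labels[16:48, 16:48] = 7                              # one labeled square
_store_labeled_image("/tmp/labels_000000.png", labels)

recovered = _load_labeled_image("/tmp/labels_000000.png", fmt="np")  # int32, HxW
assert (recovered == labels).all()                    # ids should round-trip exactly
```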
+ """ + # load data (NOTE: define known_list_len to enable faster read) + ply_data = PlyData.read(file_path, known_list_len={"face": {"vertex_indices": 3}}) + + # get vertices + vertex_data = ply_data["vertex"].data + vertices = np.column_stack( + (vertex_data["x"], vertex_data["y"], vertex_data["z"]) + ).astype(np.float32) + + # initialize output data + mesh_data = {} + mesh_data["is_labeled_mesh"] = True + mesh_data["vertices"] = vertices + + # get faces if available + if "face" in ply_data: + faces = np.asarray(ply_data["face"].data["vertex_indices"]).astype(np.int32) + mesh_data["faces"] = faces + + # get rgb colors if available + if all(color in vertex_data.dtype.names for color in ["red", "green", "blue"]): + vertices_color = np.column_stack( + (vertex_data["red"], vertex_data["green"], vertex_data["blue"]) + ).astype(np.uint8) + mesh_data["vertices_color"] = vertices_color + + # get vertices class and instance if available + if "semantic_class_id" in vertex_data.dtype.names: + vertices_class = vertex_data["semantic_class_id"].astype(np.int32) + mesh_data["vertices_semantic_class_id"] = vertices_class + + if "instance_id" in vertex_data.dtype.names: + vertices_instance = vertex_data["instance_id"].astype(np.int32) + mesh_data["vertices_instance_id"] = vertices_instance + + # get class colors if available + if all( + color in vertex_data.dtype.names + for color in [ + "semantic_class_red", + "semantic_class_green", + "semantic_class_blue", + ] + ): + vertices_semantic_class_color = np.column_stack( + ( + vertex_data["semantic_class_red"], + vertex_data["semantic_class_green"], + vertex_data["semantic_class_blue"], + ) + ).astype(np.uint8) + mesh_data["vertices_semantic_class_color"] = vertices_semantic_class_color + + # get instance colors if available + if all( + color in vertex_data.dtype.names + for color in ["instance_red", "instance_green", "instance_blue"] + ): + vertices_instance_color = np.column_stack( + ( + vertex_data["instance_red"], + vertex_data["instance_green"], + vertex_data["instance_blue"], + ) + ).astype(np.uint8) + mesh_data["vertices_instance_color"] = vertices_instance_color + + # convert data into output format (if needed) + if fmt == "np": + return mesh_data + elif fmt == "torch": + return {k: torch.tensor(v) for k, v in mesh_data.items()} + elif fmt == "trimesh": + trimesh_mesh = trimesh.Trimesh( + vertices=mesh_data["vertices"], faces=mesh_data["faces"] + ) + # color the mesh according to the palette + if palette == "rgb": + # original rgb colors + if "vertices_color" in mesh_data: + trimesh_mesh.visual.vertex_colors = mesh_data["vertices_color"] + else: + raise ValueError( + f"Palette {palette} could not be applied. Missing vertices_color in mesh data." + ) + elif palette == "semantic_class": + # semantic class colors + if "vertices_semantic_class_color" in mesh_data: + trimesh_mesh.visual.vertex_colors = mesh_data[ + "vertices_semantic_class_color" + ] + else: + raise ValueError( + f"Palette {palette} could not be applied. Missing vertices_semantic_class_color in mesh data." + ) + elif palette == "instance": + # semantic instance colors + if "vertices_instance_color" in mesh_data: + trimesh_mesh.visual.vertex_colors = mesh_data["vertices_instance_color"] + else: + raise ValueError( + f"Palette {palette} could not be applied. Missing vertices_instance_color in mesh data." 
+ ) + else: + raise ValueError(f"Invalid palette: {palette}.") + return trimesh_mesh + else: + raise NotImplementedError(f"Labeled mesh format not supported: {fmt}") + + +def _store_labeled_mesh(file_path: str | Path, mesh_data: dict, **kwargs) -> None: + """ + Stores a mesh in WAI format (PLY binary format). + + Args: + file_path (str): The filename to store the mesh in. + mesh_data (dict): Dictionary containing mesh data with keys: + - 'vertices' (numpy.ndarray): Array of vertex coordinates with shape (N, 3). + - 'faces' (numpy.ndarray, optional): Array of face indices. + - 'vertices_color' (numpy.ndarray, optional): Array of vertex colors with shape (N, 3). + - 'vertices_semantic_class_id' (numpy.ndarray, optional): Array of semantic classes for each vertex with shape (N). + - 'vertices_instance_id' (numpy.ndarray, optional): Array of instance IDs for each vertex with shape (N). + - 'vertices_semantic_class_color' (numpy.ndarray, optional): Array of vertex semantic class colors with shape (N, 3). + - 'vertices_instance_color' (numpy.ndarray, optional): Array of vertex instance colors with shape (N, 3). + + Notes: + This function writes a binary PLY file with vertex position, color, and optional + semantic class and instance IDs. The faces are stored as lists of vertex indices. + """ + # Validate input data + if "vertices" not in mesh_data: + raise ValueError("Mesh data must contain 'vertices'") + + # create vertex data with properties + vertex_dtype = [("x", "f4"), ("y", "f4"), ("z", "f4")] + if "vertices_color" in mesh_data: + vertex_dtype.extend([("red", "u1"), ("green", "u1"), ("blue", "u1")]) + if "vertices_semantic_class_id" in mesh_data: + vertex_dtype.append(("semantic_class_id", "i4")) + if "vertices_instance_id" in mesh_data: + vertex_dtype.append(("instance_id", "i4")) + if "vertices_semantic_class_color" in mesh_data: + vertex_dtype.extend( + [ + ("semantic_class_red", "u1"), + ("semantic_class_green", "u1"), + ("semantic_class_blue", "u1"), + ] + ) + if "vertices_instance_color" in mesh_data: + vertex_dtype.extend( + [("instance_red", "u1"), ("instance_green", "u1"), ("instance_blue", "u1")] + ) + vertex_count = len(mesh_data["vertices"]) + vertex_data = np.zeros(vertex_count, dtype=vertex_dtype) + + # vertex positions + vertex_data["x"] = mesh_data["vertices"][:, 0] + vertex_data["y"] = mesh_data["vertices"][:, 1] + vertex_data["z"] = mesh_data["vertices"][:, 2] + + # vertex colors + if "vertices_color" in mesh_data: + vertex_data["red"] = mesh_data["vertices_color"][:, 0] + vertex_data["green"] = mesh_data["vertices_color"][:, 1] + vertex_data["blue"] = mesh_data["vertices_color"][:, 2] + + # vertex class + if "vertices_semantic_class_id" in mesh_data: + vertex_data["semantic_class_id"] = mesh_data["vertices_semantic_class_id"] + + # vertex instance + if "vertices_instance_id" in mesh_data: + vertex_data["instance_id"] = mesh_data["vertices_instance_id"] + + # vertex class colors + if "vertices_semantic_class_color" in mesh_data: + vertex_data["semantic_class_red"] = mesh_data["vertices_semantic_class_color"][ + :, 0 + ] + vertex_data["semantic_class_green"] = mesh_data[ + "vertices_semantic_class_color" + ][:, 1] + vertex_data["semantic_class_blue"] = mesh_data["vertices_semantic_class_color"][ + :, 2 + ] + + # vertex instance colors + if "vertices_instance_color" in mesh_data: + vertex_data["instance_red"] = mesh_data["vertices_instance_color"][:, 0] + vertex_data["instance_green"] = mesh_data["vertices_instance_color"][:, 1] + vertex_data["instance_blue"] = 
mesh_data["vertices_instance_color"][:, 2] + + # initialize data to save + vertex_element = PlyElement.describe(vertex_data, "vertex") + data_to_save = [vertex_element] + + # faces data + if "faces" in mesh_data: + face_dtype = [("vertex_indices", "i4", (3,))] + face_data = np.zeros(len(mesh_data["faces"]), dtype=face_dtype) + face_data["vertex_indices"] = mesh_data["faces"] + face_element = PlyElement.describe(face_data, "face") + data_to_save.append(face_element) + + # Create and write a binary PLY file + ply_data = PlyData(data_to_save, text=False) + ply_data.write(file_path) + + +def _get_method( + fname: Path | str, format_type: str | None = None, load: bool = True +) -> Callable: + """ + Returns a method for loading or storing data in a specific format. + + Args: + fname (str or Path): The filename to load or store data from/to. + format_type (str, optional): The format of the data. If None, it will be inferred from the file extension. + Defaults to None. + load (bool, optional): Whether to return a method for loading or storing data. + Defaults to True. + + Returns: + callable: A method for loading or storing data in the specified format. + + Raises: + ValueError: If the format cannot be inferred from the file extension. + NotImplementedError: If the specified format is not supported. + + Notes: + This function supports various formats, including readable files (JSON, YAML), images, NumPy arrays, + PyTorch tensors, memory-mapped files, and scene metadata. + """ + fname = Path(fname) + if format_type is None: + # use default formats + if fname.suffix in [".json", ".yaml", ".yml"]: + format_type = "readable" + elif fname.suffix in [".jpg", ".jpeg", ".png", ".webp"]: + format_type = "image" + elif fname.suffix in [".npy", ".npz"]: + format_type = "numpy" + elif fname.suffix == ".ptz": + format_type = "ptz" + elif fname.suffix == ".sft": + format_type = "sft" + elif fname.suffix == ".exr": + format_type = "scalar" + elif fname.suffix in [".glb", ".obj", ".ply"]: + format_type = "mesh" + else: + raise ValueError(f"Cannot infer format for {fname}") + methods = { + "readable": (_load_readable, _store_readable), + "scalar": (_read_exr, _write_exr), + "image": (_load_image, _store_image), + "binary": (_load_binary_mask, _store_binary_mask), + "latent": (_load_sft, _store_sft), + "depth": (_load_depth, _store_depth), + "normals": (_load_normals, _store_normals), + "numpy": (_load_numpy, _store_numpy), + "ptz": (_load_ptz, _store_ptz), + "sft": (_load_sft, _store_sft), + "mmap": (_load_mmap, _store_mmap), + "scene_meta": (_load_scene_meta, _store_scene_meta), + "labeled_image": (_load_labeled_image, _store_labeled_image), + "mesh": (_load_generic_mesh, _store_generic_mesh), + "labeled_mesh": (_load_labeled_mesh, _store_labeled_mesh), + } + try: + return methods[format_type][0 if load else 1] + except KeyError as e: + raise NotImplementedError(f"Format not supported: {format_type}") from e diff --git a/mapanything/utils/wai/m_ops.py b/mapanything/utils/wai/m_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..7469c4843a0bee1360cacce5d25ce825b4603305 --- /dev/null +++ b/mapanything/utils/wai/m_ops.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
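`_get_method` (defined just above, in io.py) is the single dispatch point for all of the loaders and storers: it infers a format from the suffix, or takes an explicit `format_type` override for formats like `depth`, `normals`, or `scene_meta` that share a suffix with a more generic default. For example:

```python
load_json = _get_method("poses.json", load=True)                      # -> _load_readable
store_img = _get_method("rgb_000000.png", load=False)                 # -> _store_image
load_depth = _get_method("depth_000000.exr", format_type="depth")     # .exr alone would mean "scalar"
load_meta = _get_method("scene_meta.json", format_type="scene_meta")  # .json alone would mean "readable"

assert load_json is _load_readable and store_img is _store_image
assert load_depth is _load_depth and load_meta is _load_scene_meta
```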
+ +import numpy as np +import torch + + +def m_dot( + transform: torch.Tensor, + points: torch.Tensor | list, + maintain_shape: bool = False, +) -> torch.Tensor | list: + """ + Apply batch matrix multiplication between transform matrices and points. + + Args: + transform: Batch of transformation matrices [..., 3/4, 3/4] + points: Batch of points [..., N, 3] or a list of points + maintain_shape: If True, preserves the original shape of points + + Returns: + Transformed points with shape [..., N, 3] or a list of transformed points + """ + if isinstance(points, list): + return [m_dot(t, p, maintain_shape) for t, p in zip(transform, points)] + + # Store original shape and flatten batch dimensions + orig_shape = points.shape + batch_dims = points.shape[:-3] + + # Reshape to standard batch format + transform_flat = transform.reshape(-1, transform.shape[-2], transform.shape[-1]) + points_flat = points.reshape(transform_flat.shape[0], -1, points.shape[-1]) + + # Apply transformation + pts = torch.bmm( + transform_flat[:, :3, :3], + points_flat[..., :3].permute(0, 2, 1).to(transform_flat.dtype), + ).permute(0, 2, 1) + + if transform.shape[-1] == 4: + pts = pts + transform_flat[:, :3, 3].unsqueeze(1) + + # Restore original shape + if maintain_shape: + return pts.reshape(orig_shape) + else: + return pts.reshape(*batch_dims, -1, 3) + + +def m_unproject( + depth: torch.Tensor, + intrinsic: torch.Tensor, + cam2world: torch.Tensor = None, + img_grid: torch.Tensor = None, + valid: torch.Tensor = None, + H: int | None = None, + W: int | None = None, + img_feats: torch.Tensor = None, + maintain_shape: bool = False, +) -> torch.Tensor: + """ + Unproject 2D image points with depth values to 3D points in camera or world space. + + Args: + depth: Depth values, either a tensor of shape ...xHxW or a float value + intrinsic: Camera intrinsic matrix of shape ...x3x3 + cam2world: Optional camera-to-world transformation matrix of shape ...x4x4 + img_grid: Optional pre-computed image grid. 
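A quick shape check for `m_dot`: with `maintain_shape=True` the input layout is preserved, which is how `m_unproject` uses it internally. Minimal sketch, assuming `m_dot` is imported from `mapanything.utils.wai.m_ops`:

```python
import torch

transform = torch.eye(4).repeat(2, 1, 1)
transform[:, :3, 3] = torch.tensor([1.0, 2.0, 3.0])       # two identical pure translations

points = torch.zeros(2, 5, 3)                              # two sets of five points at the origin
moved = m_dot(transform, points, maintain_shape=True)      # same layout as `points`
assert moved.shape == (2, 5, 3)
assert torch.allclose(moved[0, 0], torch.tensor([1.0, 2.0, 3.0]))
```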
If None, will be created + valid: Optional mask for valid depth values or minimum depth threshold + H: Image height (required if depth is a scalar) + W: Image width (required if depth is a scalar) + img_feats: Optional image features to append to 3D points + maintain_shape: If True, preserves the original shape of points + + Returns: + 3D points in camera or world space, with optional features appended + """ + # Get device and shape information from intrinsic matrix + device = intrinsic.device + pre_shape = intrinsic.shape[:-2] # Batch dimensions + + # Validate inputs + if isinstance(depth, (int, float)) and H is None: + raise ValueError("H must be provided if depth is a scalar") + + # Determine image dimensions from depth if not provided + if isinstance(depth, torch.Tensor) and H is None: + H, W = depth.shape[-2:] + + # Create image grid if not provided + if img_grid is None: + # Create coordinate grid with shape HxWx3 (last dimension is homogeneous) + img_grid = _create_image_grid(H, W, device) + # Add homogeneous coordinate + img_grid = torch.cat([img_grid, torch.ones_like(img_grid[..., :1])], -1) + + # Expand img_grid to match batch dimensions of intrinsic + if img_grid.dim() <= intrinsic.dim(): + img_grid = img_grid.unsqueeze(0) + img_grid = img_grid.expand(*pre_shape, *img_grid.shape[-3:]) + + # Handle valid mask or minimum depth threshold + depth_mask = None + if valid is not None: + if isinstance(valid, float): + # Create mask for minimum depth value + depth_mask = depth > valid + elif isinstance(valid, torch.Tensor): + depth_mask = valid + + # Apply mask to image grid and other inputs + img_grid = masking(img_grid, depth_mask, dim=intrinsic.dim()) + if not isinstance(depth, (int, float)): + depth = masking(depth, depth_mask, dim=intrinsic.dim() - 1) + if img_feats is not None: + img_feats = masking(img_feats, depth_mask, dim=intrinsic.dim() - 1) + + # Unproject 2D points to 3D camera space + cam_pts: torch.Tensor = m_dot( + m_inverse_intrinsics(intrinsic), + img_grid[..., [1, 0, 2]], + maintain_shape=True, + ) + # Scale by depth values + cam_pts = mult(cam_pts, depth.unsqueeze(-1)) + + # Transform to world space if cam2world is provided + if cam2world is not None: + cam_pts = m_dot(cam2world, cam_pts, maintain_shape=True) + + # Append image features if provided + if img_feats is not None: + if isinstance(cam_pts, list): + if isinstance(cam_pts[0], list): + # Handle nested list case + result = [] + for batch_idx, batch in enumerate(cam_pts): + batch_result = [] + for view_idx, view in enumerate(batch): + batch_result.append( + torch.cat([view, img_feats[batch_idx][view_idx]], -1) + ) + result.append(batch_result) + cam_pts = result + else: + # Handle single list case + cam_pts = [ + torch.cat([pts, feats], -1) + for pts, feats in zip(cam_pts, img_feats) + ] + else: + # Handle tensor case + cam_pts = torch.cat([cam_pts, img_feats], -1) + + if maintain_shape: + return cam_pts + + # Flatten last dimension + return cam_pts.reshape(*pre_shape, -1, 3) + + +def m_project( + world_pts: torch.Tensor, + intrinsic: torch.Tensor, + world2cam: torch.Tensor | None = None, + maintain_shape: bool = False, +) -> torch.Tensor: + """ + Project 3D world points to 2D image coordinates. 
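End to end, `m_unproject` takes a ...xHxW depth map plus ...x3x3 intrinsics and returns ...x(H*W)x3 points, in camera space or in world space when `cam2world` is given. A small constant-depth sketch with made-up intrinsics:

```python
import torch

H, W = 4, 6
intrinsic = torch.tensor([[[100.0, 0.0, W / 2],
                           [0.0, 100.0, H / 2],
                           [0.0, 0.0, 1.0]]])              # (1, 3, 3)
depth = torch.full((1, H, W), 2.0)                         # (1, H, W): 2 m everywhere
cam2world = torch.eye(4).unsqueeze(0)                      # identity pose

pts = m_unproject(depth, intrinsic, cam2world=cam2world)   # (1, H*W, 3)
assert pts.shape == (1, H * W, 3)
assert torch.allclose(pts[..., 2], torch.full((1, H * W), 2.0))  # z equals the input depth
```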
+ + Args: + world_pts: 3D points in world coordinates + intrinsic: Camera intrinsic matrix + world2cam: Optional transformation from world to camera coordinates + maintain_shape: If True, preserves the original shape of points + + Returns: + Image points with coordinates in img_y,img_x,z order + """ + # Transform points from world to camera space if world2cam is provided + cam_pts: torch.Tensor = world_pts + if world2cam is not None: + cam_pts = m_dot(world2cam, world_pts, maintain_shape=maintain_shape) + + # Get shapes to properly expand intrinsics + shared_dims = intrinsic.shape[:-2] + extra_dims = cam_pts.shape[len(shared_dims) : -1] + + # Expand intrinsics to match cam_pts shape + expanded_intrinsic = intrinsic.view(*shared_dims, *([1] * len(extra_dims)), 3, 3) + expanded_intrinsic = expanded_intrinsic.expand(*shared_dims, *extra_dims, 3, 3) + + # Project points from camera space to image space + depth_abs = cam_pts[..., 2].abs().clamp(min=1e-5) + return torch.stack( + [ + expanded_intrinsic[..., 1, 1] * cam_pts[..., 1] / depth_abs + + expanded_intrinsic[..., 1, 2], + expanded_intrinsic[..., 0, 0] * cam_pts[..., 0] / depth_abs + + expanded_intrinsic[..., 0, 2], + cam_pts[..., 2], + ], + -1, + ) + + +def in_image( + image_pts: torch.Tensor | list, + H: int, + W: int, + min_depth: float = 0.0, +) -> torch.Tensor | list: + """ + Check if image points are within the image boundaries. + + Args: + image_pts: Image points in pixel coordinates + H: Image height + W: Image width + min_depth: Minimum valid depth + + Returns: + Boolean mask indicating which points are within the image + """ + is_list = isinstance(image_pts, list) + if is_list: + return [in_image(pts, H, W, min_depth=min_depth) for pts in image_pts] + + in_image_mask = ( + torch.all(image_pts >= 0, -1) + & (image_pts[..., 0] < H) + & (image_pts[..., 1] < W) + ) + if (min_depth is not None) and image_pts.shape[-1] == 3: + in_image_mask &= image_pts[..., 2] > min_depth + return in_image_mask + + +def _create_image_grid(H: int, W: int, device: torch.device) -> torch.Tensor: + """ + Create a coordinate grid for image pixels. + + Args: + H: Image height + W: Image width + device: Computation device + + Returns: + Image grid with shape HxWx3 (last dimension is homogeneous) + """ + y_coords = torch.arange(H, device=device) + x_coords = torch.arange(W, device=device) + + # Use meshgrid with indexing="ij" for correct orientation + y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij") + + # Stack coordinates and add homogeneous coordinate + img_grid = torch.stack([y_grid, x_grid, torch.ones_like(y_grid)], dim=-1) + + return img_grid + + +def masking( + X: torch.Tensor | list, + mask: torch.Tensor | list, + dim: int = 3, +) -> torch.Tensor | list: + """ + Apply a Boolean mask to tensor or list elements. + Handles nested structures by recursively applying the mask. + + Args: + X: Input tensor or list to be masked + mask: Boolean mask to apply + dim: Dimension threshold for recursive processing + + Returns: + Masked tensor or list with the same structure as input + """ + if isinstance(X, list) or (isinstance(X, torch.Tensor) and X.dim() >= dim): + return [masking(x, m, dim) for x, m in zip(X, mask)] + return X[mask] + + +def m_inverse_intrinsics(intrinsics: torch.Tensor) -> torch.Tensor: + """ + Compute the inverse of camera intrinsics matrices analytically. + This is much faster than using torch.inverse() for intrinsics matrices. 
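Note that `m_project` returns points in (row, col, z) order, which is the convention `in_image` expects when comparing against H and W. A small sketch with made-up intrinsics:

```python
import torch

H, W = 4, 6
intrinsic = torch.tensor([[[100.0, 0.0, W / 2],
                           [0.0, 100.0, H / 2],
                           [0.0, 0.0, 1.0]]])
cam_pts = torch.tensor([[[0.0, 0.0, 2.0],      # on the optical axis -> image centre
                         [1.0, 1.0, 2.0]]])    # far off-axis -> projects outside
img_pts = m_project(cam_pts, intrinsic)        # (1, 2, 3), (row, col, z) order
print(img_pts[0, 0])                           # ~ (H/2, W/2, 2.0)
print(in_image(img_pts, H=H, W=W))             # tensor([[ True, False]])
```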
+ + The intrinsics matrix has the form: + K = [fx s cx] + [0 fy cy] + [0 0 1] + + And its inverse is: + K^-1 = [1/fx -s/(fx*fy) (s*cy-cx*fy)/(fx*fy)] + [0 1/fy -cy/fy ] + [0 0 1 ] + + Args: + intrinsics: Camera intrinsics matrices of shape [..., 3, 3] + + Returns: + Inverse intrinsics matrices of shape [..., 3, 3] + """ + # Extract the components of the intrinsics matrix + fx = intrinsics[..., 0, 0] + s = intrinsics[..., 0, 1] # skew, usually 0 + cx = intrinsics[..., 0, 2] + fy = intrinsics[..., 1, 1] + cy = intrinsics[..., 1, 2] + + # Create output tensor with same shape and device + inv_intrinsics = torch.zeros_like(intrinsics) + + # Compute the inverse analytically + inv_intrinsics[..., 0, 0] = 1.0 / fx + inv_intrinsics[..., 0, 1] = -s / (fx * fy) + inv_intrinsics[..., 0, 2] = (s * cy - cx * fy) / (fx * fy) + inv_intrinsics[..., 1, 1] = 1.0 / fy + inv_intrinsics[..., 1, 2] = -cy / fy + inv_intrinsics[..., 2, 2] = 1.0 + + return inv_intrinsics + + +def mult( + A: torch.Tensor | np.ndarray | list | float | int, + B: torch.Tensor | np.ndarray | list | float | int, +) -> torch.Tensor | np.ndarray | list | float | int: + """ + Multiply two objects with support for lists, tensors, arrays, and scalars. + Handles nested structures by recursively applying multiplication. + + Args: + A: First operand (tensor, array, list, or scalar) + B: Second operand (tensor, array, list, or scalar) + + Returns: + Result of multiplication with the same structure as inputs + """ + if isinstance(A, list) and isinstance(B, (int, float)): + return [mult(a, B) for a in A] + if isinstance(B, list) and isinstance(A, (int, float)): + return [mult(A, b) for b in B] + if isinstance(A, list) and isinstance(B, list): + return [mult(a, b) for a, b in zip(A, B)] + return A * B diff --git a/mapanything/utils/wai/ops.py b/mapanything/utils/wai/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fc50da51513473dec9a0512085d8ac03b97cece3 --- /dev/null +++ b/mapanything/utils/wai/ops.py @@ -0,0 +1,485 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +This utils script contains PORTAGE of wai-core ops methods for MapAnything. +""" + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + + +def to_numpy( + data: torch.Tensor | np.ndarray | int | float, + dtype: np.dtype | str | type = np.float32, +) -> np.ndarray: + """ + Convert data to a NumPy array with the specified dtype (default: float32). + + This function handles conversion from NumPy arrays and PyTorch tensors to a NumPy array. + + Args: + data: Input data (torch.Tensor, np.ndarray, or scalar) + dtype: Target data type (NumPy dtype, str, or type). Default: np.float32. + + Returns: + Converted data as NumPy array with specified dtype. 
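    Example (illustrative; torch and numpy are imported at the top of this module):

        to_numpy(torch.tensor([1, 2, 3])).dtype   # -> dtype('float32'), detached and moved to CPU
        to_numpy(2.5, dtype=np.float16)           # -> array(2.5, dtype=float16)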
+ """ + # Set default dtype if not defined + assert dtype is not None, "dtype cannot be None" + dtype = np.dtype(dtype) + + # Handle torch.Tensor + if isinstance(data, torch.Tensor): + return data.detach().cpu().numpy().astype(dtype) + + # Handle numpy.ndarray + if isinstance(data, np.ndarray): + return data.astype(dtype) + + # Handle scalar values + if isinstance(data, (int, float)): + return np.array(data, dtype=dtype) + + raise NotImplementedError(f"Unsupported data type: {type(data)}") + + +def get_dtype_device( + data: torch.Tensor | np.ndarray | dict | list, +) -> tuple[torch.dtype | np.dtype | None, torch.device | str | type | None]: + """ + Determine the data type and device of the input data. + + This function recursively inspects the input data and determines its data type + and device. It handles PyTorch tensors, NumPy arrays, dictionaries, and lists. + + Args: + data: Input data (torch.Tensor, np.ndarray, dict, list, or other) + + Returns: + tuple: (dtype, device) where: + - dtype: The data type (torch.dtype or np.dtype) + - device: The device (torch.device, 'cpu', 'cuda:X', or np.ndarray) + + Raises: + ValueError: If tensors in a dictionary are on different CUDA devices + """ + if isinstance(data, torch.Tensor): + return data.dtype, data.device + + if isinstance(data, np.ndarray): + return data.dtype, np.ndarray + + if isinstance(data, dict): + dtypes = {get_dtype_device(v)[0] for v in data.values()} + devices = {get_dtype_device(v)[1] for v in data.values()} + cuda_devices = {device for device in devices if str(device).startswith("cuda")} + cpu_devices = {device for device in devices if str(device).startswith("cpu")} + if (len(cuda_devices) > 0) or (len(cpu_devices) > 0): + # torch.tensor + dtype = torch.float + if all(dtype == torch.half for dtype in dtypes): + dtype = torch.half + device = None + if len(cuda_devices) > 1: + raise ValueError("All tensors must be on the same device") + if len(cuda_devices) == 1: + device = list(cuda_devices)[0] + if (device is None) and (len(cpu_devices) == 1): + device = list(cpu_devices)[0] + else: + dtype = np.float32 + # Fix typo in numpy float16 check + if all(dtype == np.float16 for dtype in dtypes): + dtype = np.float16 + device = np.ndarray + + elif isinstance(data, list): + if not data: # Handle empty list case + return None, None + dtype, device = get_dtype_device(data[0]) + + else: + return np.float32, np.ndarray + + return dtype, device + + +def to_dtype_device( + data: torch.Tensor | np.ndarray | dict | list | int | float, + dtype: torch.dtype | np.dtype | str | None = None, + device: torch.device | str | type | None = None, + convert_scalar: bool = False, +) -> torch.Tensor | np.ndarray | dict | list | int | float: + """ + Convert data to specified dtype and device. + + This function handles conversion between numpy arrays and PyTorch tensors, + as well as recursive conversion for nested data structures like dictionaries and lists. 
+ + Args: + data: Input data (torch.Tensor, np.ndarray, dict, list, or scalar) + dtype: Target data type (torch dtype, numpy dtype, or None) + device: Target device (torch device, 'cuda', 'cpu', np.ndarray, torch.Tensor, or None) + convert_scalar: Whether to convert scalar values (int, float) to tensors/arrays + + Returns: + Converted data with specified dtype and on specified device + """ + # Handle case where device is passed in dtype parameter + if device is None: + if dtype is None: + raise ValueError("Either `dtype` or `device` must be provided.") + + if str(dtype).startswith("cuda") or str(dtype).startswith("cpu"): + device = dtype + dtype = None + else: + raise NotImplementedError() + + # Set default dtype based on device + if dtype is None: + if device is not None and ( + str(device).startswith("cuda") or str(device).startswith("cpu") + ): + dtype = torch.float + else: + dtype = np.float32 + + # Handle torch.Tensor + if isinstance(data, torch.Tensor): + if device == np.ndarray: + return data.detach().cpu().numpy().astype(dtype) + return data.to(device=device, dtype=dtype) + + # Handle numpy.ndarray + elif isinstance(data, np.ndarray): + if device == torch.Tensor: + return torch.from_numpy(data).to(dtype=dtype, device=device) + return data.astype(dtype) + + # Handle dictionary (recursively) + elif isinstance(data, dict): + return { + k: to_dtype_device(v, dtype, device, convert_scalar=convert_scalar) + for k, v in data.items() + } + + # Handle list (recursively) + elif isinstance(data, list): + return [ + to_dtype_device(x, dtype, device, convert_scalar=convert_scalar) + for x in data + ] + + # Handle scalar values + else: + if convert_scalar and isinstance(data, (int, float)): + if device == np.ndarray: + # Fix: scalars don't have astype method + return np.array(data, dtype=dtype) + else: + return torch.tensor(data, dtype=dtype, device=device) + + # Return original data if no conversion was applied + return data + + +def crop( + data: np.ndarray | torch.Tensor | Image.Image, + bbox: tuple[int, int, int, int] | tuple[int, int], +) -> np.ndarray | torch.Tensor | Image.Image: + """ + Crop data of different formats (numpy arrays, PyTorch tensors, PIL Images) to a target size. + + Args: + data: Input data to resize (numpy.ndarray, torch.Tensor, or PIL.Image.Image) + size: Target size as tuple (offset_height, offset_width, height, width) or tuple (height, width) + + Returns: + Cropped data in the same format as the input + """ + if len(bbox) == 4: + offset_height, offset_width, target_height, target_width = bbox + elif len(bbox) == 2: + target_height, target_width = bbox + offset_height, offset_width = 0, 0 + else: + raise ValueError(f"Unsupported size length {len(bbox)}.") + + end_height = offset_height + target_height + end_width = offset_width + target_width + + if any([sz < 0 for sz in bbox]): + raise ValueError("Bounding box can't have negative values.") + + if isinstance(data, np.ndarray): + if ( + max(offset_height, end_height) > data.shape[0] + or max(offset_width, end_width) > data.shape[1] + ): + raise ValueError("Invalid bounding box.") + cropped_data = data[offset_height:end_height, offset_width:end_width, ...] 
+ return cropped_data + + # Handle PIL images + elif isinstance(data, Image.Image): + if ( + max(offset_height, end_height) > data.size[1] + or max(offset_width, end_width) > data.size[0] + ): + raise ValueError("Invalid bounding box.") + return data.crop((offset_width, offset_height, end_width, end_height)) + + # Handle PyTorch tensors + elif isinstance(data, torch.Tensor): + if data.is_nested: + # special handling for nested tensors + return torch.stack([crop(nested_tensor, bbox) for nested_tensor in data]) + if ( + max(offset_height, end_height) > data.shape[-2] + or max(offset_width, end_width) > data.shape[-1] + ): + raise ValueError("Invalid bounding box.") + cropped_data = data[..., offset_height:end_height, offset_width:end_width] + return cropped_data + else: + raise TypeError(f"Unsupported data type '{type(data)}'.") + + +def to_torch_device_contiguous( + data_dict: dict[str, dict | np.ndarray | torch.Tensor], + device: torch.device | str, + contiguous: bool = False, +) -> dict[str, dict | torch.Tensor]: + """ + This function handles conversion between a dict of heterogeneous numpy arrays and torch tensors, + supporting recursion and creation of torch contiguous tensors. + + Args: + data: Input data (torch.Tensor, np.ndarray, dict, list, or scalar) + device: Target device (torch device, 'cuda', 'cpu') + + Returns: + A dict of torch tensors, optionally contiguous in memory and loaded on the specified device. + """ + + result_dict = {} + for k, v in data_dict.items(): + if isinstance(v, dict): + result_dict[k] = to_torch_device_contiguous(v, device, contiguous) + elif isinstance(v, np.ndarray): + result_dict[k] = torch.from_numpy(v).to(device) + if contiguous: + result_dict[k] = result_dict[k].contiguous() + elif isinstance(v, torch.Tensor): + result_dict[k] = v.to(device).contiguous() + else: + raise ValueError(f"Found an unsupported value type {type(v)=} for key {k}.") + return result_dict + + +def stack( + data: list[ + dict[str, torch.Tensor | np.ndarray] + | list[torch.Tensor | np.ndarray] + | tuple[torch.Tensor | np.ndarray] + ], +) -> dict[str, torch.Tensor | np.ndarray] | list[torch.Tensor | np.ndarray]: + """ + Stack a list of dictionaries into a single dictionary with stacked values. + Or when given a list of sublists, stack the sublists using torch or numpy stack + if the items are of equal size, or nested tensors if the items are PyTorch tensors + of different size. + + This utility function is similar to PyTorch's collate function, but specifically + designed for stacking dictionaries of numpy arrays or PyTorch tensors. + + Args: + data (list): A list of dictionaries with the same keys, where values are + either numpy arrays or PyTorch tensors. + OR + A list of sublist, where the values of sublists are PyTorch tensors + or np arrays. + + Returns: + dict: A dictionary with the same keys as input dictionaries, but with values + stacked along a new first dimension. + OR + list: If the input was a list with sublists, it returns a list with a stacked + output for each original input sublist. + + Raises: + ValueError: If dictionaries in the list have inconsistent keys. + NotImplementedError: If input is not a list or contains non-dictionary elements. 
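    Example (illustrative; shapes chosen arbitrarily):

        a = {"rgb": torch.zeros(3, 4), "id": np.array([0])}
        b = {"rgb": torch.zeros(3, 4), "id": np.array([1])}
        out = stack([a, b])
        # out["rgb"].shape == torch.Size([2, 3, 4])   (equal shapes -> torch.stack)
        # out["id"].shape == (2, 1)                    (numpy values -> np.stack)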
+ """ + if not isinstance(data, list): + raise NotImplementedError(f"Stack: Data type not supported: {data}") + + if len(data) == 0: + return data + + if all(isinstance(entry, dict) for entry in data): + stacked_data = {} + keys = list(data[0].keys()) + if any(set(entry.keys()) != set(keys) for entry in data): + raise ValueError("Data not consistent for stacking") + + for key in keys: + stacked_data[key] = [] + for entry in data: + stacked_data[key].append(entry[key]) + + # stack it according to data format + if all(isinstance(v, np.ndarray) for v in stacked_data[key]): + stacked_data[key] = np.stack(stacked_data[key]) + elif all(isinstance(v, torch.Tensor) for v in stacked_data[key]): + # Check if all tensors have the same shape + first_shape = stacked_data[key][0].shape + if all(tensor.shape == first_shape for tensor in stacked_data[key]): + stacked_data[key] = torch.stack(stacked_data[key]) + else: + # Use nested tensors if shapes are not consistent + stacked_data[key] = torch.nested.nested_tensor(stacked_data[key]) + return stacked_data + + if all(isinstance(entry, list) for entry in data): + # new stacked data will be a list with all of the sublist + stacked_data = [] + for sublist in data: + # stack it according to data format + if all(isinstance(v, np.ndarray) for v in sublist): + stacked_data.append(np.stack(sublist)) + elif all(isinstance(v, torch.Tensor) for v in sublist): + # Check if all tensors have the same shape + first_shape = sublist[0].shape + if all(tensor.shape == first_shape for tensor in sublist): + stacked_data.append(torch.stack(sublist)) + else: + # Use nested tensors if shapes are not consistent + stacked_data.append(torch.nested.nested_tensor(sublist)) + return stacked_data + + raise NotImplementedError(f"Stack: Data type not supported: {data}") + + +def resize( + data: np.ndarray | torch.Tensor | Image.Image, + size: tuple[int, int] | int | None = None, + scale: float | None = None, + modality_format: str | None = None, +) -> np.ndarray | torch.Tensor | Image.Image: + """ + Resize data of different formats (numpy arrays, PyTorch tensors, PIL Images) to a target size. 
+ + Args: + data: Input data to resize (numpy.ndarray, torch.Tensor, or PIL.Image.Image) + size: Target size as tuple (height, width) or single int for long-side scaling + scale: Scale factor to apply to the original dimensions + modality_format: Type of data being resized ('depth', 'normals', or None) + Affects interpolation method used + + Returns: + Resized data in the same format as the input + + Raises: + ValueError: If neither size nor scale is provided, or if both are provided + TypeError: If data is not a supported type + """ + # Validate input parameters + if size is not None and scale is not None: + raise ValueError("Only one of size or scale should be provided.") + + # Calculate size from scale if needed + if size is None: + if scale is None: + raise ValueError("Either size or scale must be provided.") + + size = (1, 1) + if isinstance(data, (np.ndarray, torch.Tensor)): + size = (int(data.shape[-2] * scale), int(data.shape[-1] * scale)) + elif isinstance(data, Image.Image): + size = (int(data.size[1] * scale), int(data.size[0] * scale)) + else: + raise TypeError(f"Unsupported data type '{type(data)}'.") + + # Handle long-side scaling when size is a single integer + elif isinstance(size, int): + long_side = size + if isinstance(data, (np.ndarray, torch.Tensor)): + if isinstance(data, torch.Tensor) and data.is_nested: + raise ValueError( + "Long-side scaling not support for nested tensors, use fixed size instead." + ) + h, w = data.shape[-2], data.shape[-1] + elif isinstance(data, Image.Image): + w, h = data.size + else: + raise TypeError(f"Unsupported data type '{type(data)}'.") + if h > w: + size = (long_side, int(w * long_side / h)) + else: + size = (int(h * long_side / w), long_side) + + target_height, target_width = size + + # Set interpolation method based on modality + if modality_format in ["depth", "normals"]: + interpolation = Image.Resampling.NEAREST + torch_interpolation = "nearest" + else: + interpolation = Image.Resampling.LANCZOS + torch_interpolation = "bilinear" + + # Handle numpy arrays + if isinstance(data, np.ndarray): + pil_image = Image.fromarray(data) + resized_image = pil_image.resize((target_width, target_height), interpolation) + return np.array(resized_image) + + # Handle PIL images + elif isinstance(data, Image.Image): + return data.resize((target_width, target_height), interpolation) + + # Handle PyTorch tensors + elif isinstance(data, torch.Tensor): + if data.is_nested: + # special handling for nested tensors + return torch.stack( + [ + resize(nested_tensor, size, scale, modality_format) + for nested_tensor in data + ] + ) + original_dim = data.ndim + if original_dim == 2: # (H, W) + data = data.unsqueeze(0).unsqueeze(0) # Add channel and batch dimensions + elif original_dim == 3: # (C/B, H W) + if modality_format == "depth": + data = data.unsqueeze(1) # channel batch dimension + else: + data = data.unsqueeze(0) # Add batch dimension + resized_tensor = F.interpolate( + data, + size=(target_height, target_width), + mode=torch_interpolation, + align_corners=False if torch_interpolation != "nearest" else None, + ) + if original_dim == 2: + return resized_tensor.squeeze(0).squeeze( + 0 + ) # Remove batch and channel dimensions + elif original_dim == 3: + if modality_format == "depth": + return resized_tensor.squeeze(1) # Remove channel dimension + + return resized_tensor.squeeze(0) # Remove batch dimension + else: + return resized_tensor + + else: + raise TypeError(f"Unsupported data type '{type(data)}'.") diff --git a/mapanything/utils/wai/scene_frame.py 
b/mapanything/utils/wai/scene_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..091e82ed071a185789f42c4b7a5f5cdbda5d577c --- /dev/null +++ b/mapanything/utils/wai/scene_frame.py @@ -0,0 +1,436 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import re +from pathlib import Path +from typing import Any + +import numpy as np + +from mapanything.utils.wai.io import ( + _load_readable, + _load_scene_meta, + get_processing_state, +) + +logger = logging.getLogger(__name__) + + +def get_scene_frame_names( + cfg: dict | object, + root: Path | str | None = None, + scene_frames_fn: str | None = None, + keyframes: bool = True, +) -> dict[str, list[str | float]] | None: + """ + Retrieve scene frame names based on configuration and optional parameters. + + This function determines the scene frame names by resolving the scene frame file + and applying any necessary filters based on the provided configuration. + + Args: + cfg: Configuration object containing settings and parameters. + root: Optional root directory path. If not provided, it will be fetched from cfg. + scene_frames_fn: Optional scene frames file name. If not provided, it will be fetched from cfg. + keyframes: Optional, used only for a video. If True (default), return only keyframes (with camera poses). + + Returns: + A dictionary mapping scene names to their respective frame names. + """ + scene_frames_fn = ( + cfg.get("scene_frames_fn") if scene_frames_fn is None else scene_frames_fn + ) + scene_frame_names = None + if scene_frames_fn is not None: + # load scene_frames based on scene_frame file + scene_frame_names = _resolve_scene_frames_fn(scene_frames_fn) + + scene_names = get_scene_names( + cfg, + root=root, + scene_names=( + list(scene_frame_names.keys()) if scene_frame_names is not None else None + ), + ) + scene_frame_names = _resolve_scene_frame_names( + cfg, + scene_names, + root=root, + scene_frame_names=scene_frame_names, + keyframes=keyframes, + ) + return scene_frame_names + + +def get_scene_names( + cfg: dict | object, + root: Path | str | None = None, + scene_names: list[str] | None = None, + shuffle: bool = False, +) -> list[str]: + """ + Retrieve scene names based on the provided configuration and optional parameters. + + This function determines the scene names by checking the root directory for subdirectories + and applying any necessary filters based on the provided configuration. + + Args: + cfg: Configuration object containing settings and parameters. + root: Optional root directory path. If not provided, it will be fetched from cfg. + scene_names: Optional list of scene names. If not provided, it will be determined from the root directory. + shuffle: Optional bool. Default to False. If True, it will return the list of scene names in random order. + + Returns: + A list of scene names after applying any filters specified in the configuration. 
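    Example (illustrative; the root path and filter values are hypothetical):

        cfg = {"root": "/data/wai_scenes", "scene_filters": [[0, 10]]}
        get_scene_names(cfg)   # first 10 scene sub-directories under the root, sorted by name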
+ """ + root = cfg.get("root") if root is None else root + if root is not None: + # Check if the root exists + if not Path(root).exists(): + raise IOError(f"Root directory does not exist: {root}") + + # Check if the root is a directory + if not Path(root).is_dir(): + raise IOError(f"Root directory is not a directory: {root}") + + if scene_names is None: + scene_filters = cfg.get("scene_filters") + if ( + scene_filters + and len(scene_filters) == 1 + and isinstance(scene_filters[0], list) + and all(isinstance(entry, str) for entry in scene_filters[0]) + ): + # Shortcut the scene_names if the scene_filters is only a list of scene names + scene_names = scene_filters[0] + else: + # List all subdirectories in the root as scenes + scene_names = sorted( + [entry.name for entry in os.scandir(root) if entry.is_dir()] + ) + # Filter scenes based on scene_filters + scene_names = _filter_scenes(root, scene_names, cfg.get("scene_filters")) + + # shuffle the list if needed (in place) + if shuffle: + random.shuffle(scene_names) + + return scene_names + + +def _filter_scenes( + root: Path | str, + scene_names: list[str], + scene_filters: tuple | list | None, +) -> list[str]: + if scene_filters is None: + return scene_names + + if not isinstance(scene_filters, (tuple, list)): + raise ValueError("scene_filters must be a list or tuple") + + for scene_filter in scene_filters: + if scene_filter in [None, "all"]: + pass + + elif isinstance(scene_filter, (tuple, list)): + if len(scene_filter) == 0: + raise ValueError("scene_filter cannot be empty") + + elif all(isinstance(x, int) for x in scene_filter): + if len(scene_filter) == 2: + # start/end index + scene_names = scene_names[scene_filter[0] : scene_filter[1]] + elif len(scene_filter) == 3: + # start/end/step + scene_names = scene_names[ + scene_filter[0] : scene_filter[1] : scene_filter[2] + ] + else: + # omegaconf conversion issue (converts strings to integers whenever possible) + if str(scene_filter[0]) in scene_names: + scene_names = [str(s) for s in scene_filter] + else: + raise ValueError( + "scene_filter format [start_idx, end_idx] or [start_idx, end_idx, step_size] or [scene_name1, scene_name2, ...]" + ) + + elif all(isinstance(x, str) for x in scene_filter): + # explicit scene names + if set(scene_filter).issubset(set(scene_names)): + scene_names = list(scene_filter) + else: + logger.warning( + f"Scene(s) not available: {set(scene_filter) - set(scene_names)}" + ) + scene_names = list(set(scene_names) & set(scene_filter)) + else: + raise TypeError( + f"Scene filter type not supported: {type(scene_filter)}" + ) + + elif isinstance(scene_filter, dict): + # reserved key words + if modality := scene_filter.get("exists"): + scene_names = [ + scene_name + for scene_name in scene_names + if Path(root, scene_name, modality).exists() + ] + + elif modality := scene_filter.get("exists_not"): + scene_names = [ + scene_name + for scene_name in scene_names + if not Path(root, scene_name, modality).exists() + ] + + elif process_filter := scene_filter.get("process_state"): + # filter for where has + (process_key, process_state) = process_filter + filtered_scene_names = [] + for scene_name in scene_names: + # load processing state and check for + processing_state = get_processing_state(Path(root, scene_name)) + if "*" in process_key: # regex matching + for process_name in processing_state: + if re.match(process_key, process_name): + process_key = process_name + break + if process_key not in processing_state: + continue + if processing_state[process_key]["state"] == 
process_state: + filtered_scene_names.append(scene_name) + scene_names = filtered_scene_names + + elif process_filter := scene_filter.get("process_state_not"): + # filter for where does not have + (process_key, process_state) = process_filter + filtered_scene_names = [] + for scene_name in scene_names: + # load processing state and check for + try: + processing_state = get_processing_state(Path(root, scene_name)) + except Exception: + filtered_scene_names.append(scene_name) + continue + if "*" in process_key: # regex matching + for process_name in processing_state: + if re.match(process_key, process_name): + process_key = process_name + break + if (process_key not in processing_state) or ( + processing_state[process_key]["state"] != process_state + ): + filtered_scene_names.append(scene_name) + scene_names = filtered_scene_names + + else: + raise ValueError(f"Scene filter not supported: {scene_filter}") + + elif isinstance(scene_filter, str): + # regex + scene_names = [ + scene_name + for scene_name in scene_names + if re.fullmatch(scene_filter, scene_name) + ] + else: + raise ValueError(f"Scene filter not supported: {scene_filter}") + + return scene_names + + +def _resolve_scene_frames_fn(scene_frames_fn: str) -> dict[str, list[str] | None]: + # support for file list in forms of lists or dicts + # containing scene_names [-> frames] + scene_frames_list = _load_readable(scene_frames_fn) + scene_frame_names = {} + + # TODO: The following code seems unreachable as scene_frames_list is always a dict + if isinstance(scene_frames_list, (list, tuple)): + for entry in scene_frames_list: + if isinstance(entry, (tuple, list)): + if ( + (len(entry) != 2) + or (not isinstance(entry[0], str)) + or (not isinstance(entry[1], list)) + ): + raise NotImplementedError( + "Only supports lists of [, [frame_names]]" + ) + scene_frame_names[entry[0]] = entry[1] + elif isinstance(entry, str): + scene_frame_names[entry] = None + elif isinstance(entry, dict): + # scene_name -> frames + raise NotImplementedError("Dict entry not supported yet") + else: + raise IOError(f"File list contains an entry of wrong format: {entry}") + + elif isinstance(scene_frames_list, dict): + # scene_name -> frames + for scene_name, frame in scene_frames_list.items(): + if isinstance(frame, (tuple, list)): + scene_frame_names[scene_name] = frame + elif isinstance(frame, dict): + if "frame_names" in frame: + scene_frame_names[scene_name] = frame["frame_names"] + else: + raise IOError(f"Scene frames format not supported: {frame}") + elif frame is None: + scene_frame_names[scene_name] = frame + else: + raise IOError(f"Scene frames format not supported: {frame}") + + else: + raise IOError(f"Scene frames format not supported: {scene_frames_list}") + + return scene_frame_names + + +def _resolve_scene_frame_names( + cfg: dict | object, + scene_names: list[str], + root: Path | str | None = None, + scene_frame_names: dict[str, list[str | float] | None] | None = None, + keyframes: bool = True, +) -> dict[str, list[str]]: + root = cfg.get("root") if root is None else root + if scene_frame_names is not None: + # restrict to the additional scene-level prefiltering + scene_frame_names = { + scene_name: scene_frame_names[scene_name] for scene_name in scene_names + } + # dict already loaded, apply additional filters + for scene_name, frame_names in scene_frame_names.items(): + if frame_names is None: + scene_meta = _load_scene_meta( + Path( + root, scene_name, cfg.get("scene_meta_path", "scene_meta.json") + ) + ) + frame_names = [frame["frame_name"] for 
frame in scene_meta["frames"]] + # TODO: add some logic for video keyframes + + scene_frame_names[scene_name] = _filter_frame_names( + root, frame_names, scene_name, cfg.get("frame_filters") + ) + else: + scene_frame_names = {} + for scene_name in scene_names: + scene_meta = _load_scene_meta( + Path(root, scene_name, cfg.get("scene_meta_path", "scene_meta.json")) + ) + if not keyframes: + frame_names = get_video_frames(scene_meta) + if frame_names is None: + keyframes = True + if keyframes: + frame_names = [frame["frame_name"] for frame in scene_meta["frames"]] + frame_names = _filter_frame_names( + root, frame_names, scene_name, cfg.get("frame_filters") + ) + scene_frame_names[scene_name] = frame_names + return scene_frame_names + + +def _filter_frame_names( + root: Path | str, + frame_names: list[str], + scene_name: str, + frame_filters: list | tuple | None, +) -> list[str]: + if frame_filters is None: + return frame_names + + if not isinstance(frame_filters, (tuple, list)): + raise ValueError("frame_filters must be a list or tuple") + + for frame_filter in frame_filters: + if frame_filter in [None, "all"]: + pass + + elif isinstance(frame_filter, (tuple, list)): + if len(frame_filter) == 0: + raise ValueError("frame_filter cannot be empty") + + if isinstance(frame_filter[0], int): + if len(frame_filter) == 2: + # start/end index + frame_names = frame_names[frame_filter[0] : frame_filter[1]] + + elif len(frame_filter) == 3: + # start/end/step + frame_names = frame_names[ + frame_filter[0] : frame_filter[1] : frame_filter[2] + ] + + else: + raise ValueError( + "frame_filter format [start_idx, end_idx] or [start_idx, end_idx,step_size]" + ) + else: + raise TypeError( + f"frame_filter[0] type not supported: {type(frame_filter[0])}" + ) + + elif isinstance(frame_filter, str): + # reserved key words + if match := re.match("exists: (.+)", frame_filter): + modality = match.group(1) + frame_names = [ + frame_name + for frame_name in frame_names + if any(Path(root, scene_name, modality).glob(f"{frame_name}.*")) + ] + + elif match := re.match("!exists: (.+)", frame_filter): + modality = match.group(1) + frame_names = [ + frame_name + for frame_name in frame_names + if not any(Path(root, scene_name, modality).glob(f"{frame_name}.*")) + ] + + else: # general regex + frame_names = [ + frame_name + for frame_name in frame_names + if re.match(frame_filter, frame_name) + ] + + else: + raise ValueError(f"frame_filter type not supported: {type(frame_filter)}") + + return frame_names + + +def get_video_frames(scene_meta: dict[str, Any]): + """ + Return names of video frames. + Args: + scene_meta: dictionary with scene_meat data. + + Returns: + A list of video frame names. + """ + image_modality = [mod for mod in scene_meta["frame_modalities"] if "image" in mod] + if len(image_modality) > 0: + image_modality = scene_meta["frame_modalities"][image_modality[0]] + if "chunks" in image_modality: + file_list = image_modality["chunks"] + else: + file_list = [image_modality] + frame_names = [] + for chunk in file_list: + start, end, fps = chunk["start"], chunk["end"], chunk["fps"] + chunk_frame_names = np.arange(start, end, 1.0 / fps).tolist() + frame_names += chunk_frame_names + return frame_names + return None diff --git a/mapanything/utils/wai/semantics.py b/mapanything/utils/wai/semantics.py new file mode 100644 index 0000000000000000000000000000000000000000..9508d961a61bc245a5c681fb900d182d2a849289 --- /dev/null +++ b/mapanything/utils/wai/semantics.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +This utils script contains PORTAGE of wai-core semantics methods for MapAnything. +""" + +import numpy as np +from PIL import Image + +INVALID_ID = 0 +INVALID_COLOR = (0, 0, 0) + + +def load_semantic_color_mapping(filename: str = "colors_fps_5k.npz") -> np.ndarray: + """Loads a precomputed colormap.""" + from mapanything.utils.wai.core import WAI_COLORMAP_PATH + + return np.load(WAI_COLORMAP_PATH / filename).get("arr_0") + + +def apply_id_to_color_mapping( + data_id: np.ndarray | Image.Image, + semantic_color_mapping: np.ndarray, +) -> tuple[np.ndarray, dict[int, tuple[int, int, int]]]: + """Maps semantic class/instance IDs to RGB colors.""" + if isinstance(data_id, Image.Image): + data_id = np.array(data_id) + + max_color_id = semantic_color_mapping.shape[0] - 1 + max_data_id = data_id.max() + if max_data_id > max_color_id: + raise ValueError("The provided color palette does not have enough colors!") + + # Create palette containing the id->color mappings of the input data IDs + unique_indices = np.unique(data_id).tolist() + color_palette = { + index: semantic_color_mapping[index, :].tolist() for index in unique_indices + } + + data_colors = semantic_color_mapping[data_id] + + return data_colors, color_palette diff --git a/mapanything/utils/warnings.py b/mapanything/utils/warnings.py new file mode 100644 index 0000000000000000000000000000000000000000..86e14562c6ef13a8048280056f8002e31a291824 --- /dev/null +++ b/mapanything/utils/warnings.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Wrapper utilities for warnings. 
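A short usage sketch of the two helpers defined below (the warning category is arbitrary):

    @no_warnings(category=DeprecationWarning)
    def quiet_call():
        warnings.warn("legacy API", DeprecationWarning)  # silenced by the decorator

    with no_warnings():  # also usable as a context manager
        quiet_call()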
+""" + +import warnings +from functools import wraps + + +def suppress_traceback(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + e.__traceback__ = e.__traceback__.tb_next.tb_next + raise + + return wrapper + + +class no_warnings: + def __init__(self, action: str = "ignore", **kwargs): + self.action = action + self.filter_kwargs = kwargs + + def __call__(self, fn): + @wraps(fn) + def wrapper(*args, **kwargs): + with warnings.catch_warnings(): + warnings.simplefilter(self.action, **self.filter_kwargs) + return fn(*args, **kwargs) + + return wrapper + + def __enter__(self): + self.warnings_manager = warnings.catch_warnings() + self.warnings_manager.__enter__() + warnings.simplefilter(self.action, **self.filter_kwargs) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.warnings_manager.__exit__(exc_type, exc_val, exc_tb) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..360d064d7f0aa7ff964ae5377477842bdba37b19 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,27 @@ +# MapAnything 核心依赖 +torch>=2.0.0 +torchvision>=0.15.0 +numpy>=1.24.0 +opencv-python>=4.8.0 +pillow>=10.0.0 +pillow-heif>=0.13.0 + +# Gradio 和 UI +gradio>=4.0.0 +spaces>=0.19.0 + +# Transformers 和模型 +transformers>=4.35.0 +accelerate>=0.24.0 + +# 3D 处理 +trimesh>=4.0.0 +utils3d>=0.0.1 + +# 科学计算 +scipy>=1.11.0 +scikit-learn>=1.3.0 + +# 可视化 +matplotlib>=3.7.0 + diff --git a/scripts/gradio_app_v8.py b/scripts/gradio_app_v8.py new file mode 100644 index 0000000000000000000000000000000000000000..490c93aa9c2f69d343811fb27249c1be14a69c88 --- /dev/null +++ b/scripts/gradio_app_v8.py @@ -0,0 +1,2321 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ + +import gc +import os +import shutil +import sys +import time +from datetime import datetime +from pathlib import Path +from collections import defaultdict +from typing import List, Dict, Tuple + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + +import cv2 +import gradio as gr +import numpy as np +import spaces +import torch +import trimesh +from PIL import Image +from pillow_heif import register_heif_opener +from sklearn.cluster import DBSCAN + +from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals +from mapanything.utils.hf_utils.css_and_html import ( + get_gradio_theme, + GRADIO_CSS, +) +from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model, initialize_mapanything_local +from mapanything.utils.hf_utils.viz import predictions_to_glb +from mapanything.utils.image import load_images, rgb + +register_heif_opener() +sys.path.append("mapanything/") + +# ============================================================================ +# 全局配置 +# ============================================================================ + +# MapAnything Configuration +high_level_config = { + "path": "configs/train.yaml", + "hf_model_name": "facebook/map-anything", + "model_str": "mapanything", + "config_overrides": [ + "machine=aws", + "model=mapanything", + "model/task=images_only", + "model.encoder.uses_torch_hub=false", + ], + "checkpoint_name": "model.safetensors", + "config_name": "config.json", + "trained_with_amp": True, + "trained_with_amp_dtype": "bf16", + "data_norm_type": "dinov2", + "patch_size": 14, + "resolution": 518, +} + +# GroundingDINO 配置 - 从 HuggingFace 加载 +GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny" +GROUNDING_DINO_BOX_THRESHOLD = 0.25 +GROUNDING_DINO_TEXT_THRESHOLD = 0.2 + +# SAM 配置 - 使用 HuggingFace 的 SAM 模型 +SAM_MODEL_ID = "facebook/sam-vit-huge" # 或使用 "facebook/sam-vit-base" 更快更小 + +DEFAULT_TEXT_PROMPT = "window . table . sofa . tv . book . door" + +# 通用物体列表(GroundingDINO 会检测图像中存在的物体) +COMMON_OBJECTS_PROMPT = ( + "person . face . hand . " + "chair . sofa . couch . bed . table . desk . cabinet . shelf . drawer . " + "door . window . wall . floor . ceiling . curtain . " + "tv . monitor . screen . computer . laptop . keyboard . mouse . " + "phone . tablet . remote . " + "lamp . light . chandelier . " + "book . magazine . paper . pen . pencil . " + "bottle . cup . glass . mug . plate . bowl . fork . knife . spoon . " + "vase . plant . flower . pot . " + "clock . picture . frame . mirror . " + "pillow . cushion . blanket . towel . " + "bag . backpack . suitcase . " + "box . basket . container . " + "shoe . hat . coat . " + "toy . ball . " + "car . bicycle . motorcycle . bus . truck . " + "tree . grass . sky . cloud . sun . " + "dog . cat . bird . " + "building . house . bridge . road . street . " + "sign . pole . 
bench" +) + +# V8: DBSCAN聚类配置 +# 根据物体类型设置不同的聚类半径(eps) +DBSCAN_EPS_CONFIG = { + 'sofa': 1.5, # 沙发:1.5米半径(大物体,同一个沙发的检测可能相距较远) + 'bed': 1.5, + 'couch': 1.5, + 'desk': 0.8, # 桌子:0.8米半径(中等物体) + 'table': 0.8, + 'chair': 0.6, # 椅子:0.6米(较小) + 'cabinet': 0.8, + 'window': 0.5, # 窗户:0.5米(位置固定,聚类严格) + 'door': 0.6, + 'tv': 0.6, + 'default': 1.0 # 默认:1米 +} + +DBSCAN_MIN_SAMPLES = 1 # 最小样本数(设为1意味着单个检测也能成为一个簇) + +ENABLE_VISUAL_FEATURES = False + +# 分割质量控制 +MIN_DETECTION_CONFIDENCE = 0.35 # 最低检测置信度(过滤误检测) +MIN_MASK_AREA = 100 # 最小mask面积(像素) + +# 匹配分数计算配置(用于备用匹配算法) +MATCH_3D_DISTANCE_THRESHOLD = 2.5 # 3D距离阈值(米) + +# 全局模型变量 +model = None +grounding_dino_model = None +grounding_dino_processor = None +sam_predictor = None + + +# ============================================================================ +# 分割模型加载 +# ============================================================================ + +def load_grounding_dino_model(device): + """加载 GroundingDINO 模型 - 从 HuggingFace""" + global grounding_dino_model, grounding_dino_processor + + if grounding_dino_model is not None: + print("✅ GroundingDINO 已加载") + return + + try: + from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection + + print(f"📥 从 HuggingFace 加载 GroundingDINO: {GROUNDING_DINO_MODEL_ID}") + grounding_dino_processor = AutoProcessor.from_pretrained(GROUNDING_DINO_MODEL_ID) + grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained( + GROUNDING_DINO_MODEL_ID + ).to(device).eval() + + print("✅ GroundingDINO 加载成功") + + except Exception as e: + print(f"❌ GroundingDINO 加载失败: {e}") + import traceback + traceback.print_exc() + + +def load_sam_model(device): + """加载 SAM 模型 - 从 HuggingFace""" + global sam_predictor + + if sam_predictor is not None: + print("✅ SAM 已加载") + return + + try: + from transformers import SamModel, SamProcessor + + print(f"📥 从 HuggingFace 加载 SAM: {SAM_MODEL_ID}") + sam_model = SamModel.from_pretrained(SAM_MODEL_ID).to(device).eval() + sam_processor = SamProcessor.from_pretrained(SAM_MODEL_ID) + + # 将模型和处理器存储为全局变量 + sam_predictor = {'model': sam_model, 'processor': sam_processor} + print("✅ SAM 加载成功") + + except Exception as e: + print(f"❌ SAM 加载失败: {e}") + print(" SAM 功能将被禁用,将使用边界框作为mask") + import traceback + traceback.print_exc() + + +# ============================================================================ +# 分割功能 +# ============================================================================ + + +def generate_distinct_colors(n): + """生成 N 个视觉上区分度高的颜色(RGB,0-255)""" + import colorsys + if n == 0: + return [] + + colors = [] + for i in range(n): + hue = i / max(n, 1) + rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95) + rgb_color = tuple(int(c * 255) for c in rgb) + colors.append(rgb_color) + + return colors + + +def run_grounding_dino_detection(image_np, text_prompt, device): + """使用 GroundingDINO 进行检测""" + if grounding_dino_model is None or grounding_dino_processor is None: + print("⚠️ GroundingDINO 未加载") + return [] + + try: + print(f"🔍 GroundingDINO 检测: {text_prompt}") + + # 转换为 PIL Image + if image_np.dtype == np.uint8: + pil_image = Image.fromarray(image_np) + else: + pil_image = Image.fromarray((image_np * 255).astype(np.uint8)) + + # 预处理 + inputs = grounding_dino_processor(images=pil_image, text=text_prompt, return_tensors="pt") + inputs = {k: v.to(device) for k, v in inputs.items()} + + # 推理 + with torch.no_grad(): + outputs = grounding_dino_model(**inputs) + + # 后处理 + results = grounding_dino_processor.post_process_grounded_object_detection( + outputs, + inputs["input_ids"], + 
threshold=GROUNDING_DINO_BOX_THRESHOLD, + text_threshold=GROUNDING_DINO_TEXT_THRESHOLD, + target_sizes=[pil_image.size[::-1]] + )[0] + + # 转换为统一格式 + detections = [] + boxes = results["boxes"].cpu().numpy() + scores = results["scores"].cpu().numpy() + labels = results["labels"] + + print(f"✅ 检测到 {len(boxes)} 个物体") + + for box, score, label in zip(boxes, scores, labels): + detection = { + 'bbox': box.tolist(), # [x1, y1, x2, y2] + 'label': label, + 'confidence': float(score) + } + detections.append(detection) + print(f" - {label}: {score:.2f}") + + return detections + + except Exception as e: + print(f"❌ GroundingDINO 检测失败: {e}") + import traceback + traceback.print_exc() + return [] + + +def run_sam_refinement(image_np, boxes): + """使用 SAM 精确分割 - HuggingFace Transformers 版本""" + if sam_predictor is None: + print("⚠️ SAM 未加载,使用 bbox 作为 mask") + # 使用 bbox 创建简单的矩形 mask + masks = [] + h, w = image_np.shape[:2] + for box in boxes: + x1, y1, x2, y2 = map(int, box) + mask = np.zeros((h, w), dtype=bool) + mask[y1:y2, x1:x2] = True + masks.append(mask) + return masks + + try: + print(f"🎯 SAM 精确分割 {len(boxes)} 个区域...") + + from PIL import Image + sam_model = sam_predictor['model'] + sam_processor = sam_predictor['processor'] + device = sam_model.device + + # 转换为 PIL Image + if image_np.dtype == np.uint8: + pil_image = Image.fromarray(image_np) + else: + pil_image = Image.fromarray((image_np * 255).astype(np.uint8)) + + masks = [] + for box in boxes: + x1, y1, x2, y2 = map(int, box) + input_boxes = [[[x1, y1, x2, y2]]] # SAM 需要的格式 + + # 处理输入 + inputs = sam_processor(pil_image, input_boxes=input_boxes, return_tensors="pt") + inputs = {k: v.to(device) for k, v in inputs.items()} + + # 推理 + with torch.no_grad(): + outputs = sam_model(**inputs) + + # 后处理获取mask + pred_masks = sam_processor.image_processor.post_process_masks( + outputs.pred_masks.cpu(), + inputs["original_sizes"].cpu(), + inputs["reshaped_input_sizes"].cpu() + )[0][0][0] # 取第一个mask + + masks.append(pred_masks.numpy() > 0.5) + + print(f"✅ SAM 分割完成") + return masks + + except Exception as e: + print(f"❌ SAM 分割失败: {e}") + import traceback + traceback.print_exc() + # Fallback to bbox masks + masks = [] + h, w = image_np.shape[:2] + for box in boxes: + x1, y1, x2, y2 = map(int, box) + mask = np.zeros((h, w), dtype=bool) + mask[y1:y2, x1:x2] = True + masks.append(mask) + return masks + + +def normalize_label(label): + """规范化标签,提取主要类别 + + 例如: 'sofa bed' -> 'sofa', 'desk cabinet' -> 'desk', 'table desk' -> 'table' + 'windows' -> 'window', 'chairs' -> 'chair' (单复数转换) + """ + label = label.strip().lower() + + # 优先级顺序(从高到低) + priority_labels = ['sofa', 'bed', 'table', 'desk', 'chair', 'cabinet', 'window', 'door'] + + # 查找标签中是否包含优先级类别 + for priority in priority_labels: + if priority in label: + return priority + + # 如果没有匹配,返回第一个词 + first_word = label.split()[0] if label else label + + # 处理常见复数形式 -> 单数 + if first_word.endswith('s') and len(first_word) > 1: + singular = first_word[:-1] # 去掉末尾的 's' + # 特殊复数规则 + if first_word.endswith('sses'): # glasses -> glass + singular = first_word[:-2] + elif first_word.endswith('ies'): # cherries -> cherry + singular = first_word[:-3] + 'y' + elif first_word.endswith('ves'): # shelves -> shelf + singular = first_word[:-3] + 'f' + + # 返回单数形式 + return singular + + return first_word + + +def labels_match(label1, label2): + """判断两个标签是否匹配(支持模糊匹配) + + 例如: 'sofa' 和 'sofa bed' 匹配 + 'desk' 和 'table desk' 匹配 + """ + norm1 = normalize_label(label1) + norm2 = normalize_label(label2) + return norm1 == norm2 + + +def 
compute_object_3d_center(points, mask): + """计算物体的 3D 中心点""" + masked_points = points[mask] + if len(masked_points) == 0: + return None + return np.median(masked_points, axis=0) + + +def compute_3d_bbox_iou(center1, size1, center2, size2): + """计算两个3D边界框的IoU""" + try: + # 计算边界框范围 [min, max] + min1 = center1 - size1 / 2 + max1 = center1 + size1 / 2 + min2 = center2 - size2 / 2 + max2 = center2 + size2 / 2 + + # 计算交集 + inter_min = np.maximum(min1, min2) + inter_max = np.minimum(max1, max2) + inter_size = np.maximum(0, inter_max - inter_min) + inter_volume = np.prod(inter_size) + + # 计算并集 + volume1 = np.prod(size1) + volume2 = np.prod(size2) + union_volume = volume1 + volume2 - inter_volume + + if union_volume == 0: + return 0.0 + + return inter_volume / union_volume + except: + return 0.0 + + +def compute_2d_mask_iou(mask1, mask2): + """计算两个2D mask的IoU""" + try: + intersection = np.logical_and(mask1, mask2).sum() + union = np.logical_or(mask1, mask2).sum() + if union == 0: + return 0.0 + return intersection / union + except: + return 0.0 + + +def extract_visual_features(image, mask, encoder): + """提取mask区域的视觉特征(使用DINOv2) + + Args: + image: [H, W, 3] float32 in [0, 1] or uint8 in [0, 255] + mask: [H, W] bool + encoder: DINOv2 encoder model + + Returns: + feature vector (1D numpy array) or None if failed + """ + try: + # 将mask区域裁剪出来 + coords = np.argwhere(mask) + if len(coords) == 0: + return None + + y_min, x_min = coords.min(axis=0) + y_max, x_max = coords.max(axis=0) + + # 确保裁剪区域有效 + if y_max <= y_min or x_max <= x_min: + return None + + # 裁剪并resize到224x224 + cropped = image[y_min:y_max+1, x_min:x_max+1] + + # 确保是 uint8 格式 + if cropped.dtype == np.float32 or cropped.dtype == np.float64: + if cropped.max() <= 1.0: + cropped = (cropped * 255).astype(np.uint8) + else: + cropped = cropped.astype(np.uint8) + + from PIL import Image + import torchvision.transforms as T + + pil_img = Image.fromarray(cropped) + pil_img = pil_img.resize((224, 224), Image.BILINEAR) + + # 转换为tensor + transform = T.Compose([ + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # 获取encoder的设备 + try: + device = next(encoder.parameters()).device + except: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + img_tensor = transform(pil_img).unsqueeze(0).to(device) # [1, 3, 224, 224] + + # 提取特征 - 使用 encoder 的前向传播 + with torch.no_grad(): + # 不同的encoder可能有不同的调用方式 + if hasattr(encoder, 'forward_features'): + # 如果有 forward_features 方法(标准 DINOv2) + features = encoder.forward_features(img_tensor) + else: + # 否则直接调用(DINOv2Encoder 只需要 input tensor) + features = encoder(img_tensor) + + # 如果 features 不是 tensor,尝试转换 + if not isinstance(features, torch.Tensor): + if isinstance(features, dict): + # 如果返回字典,尝试获取 'x' 或 'last_hidden_state' + features = features.get('x', features.get('last_hidden_state', None)) + if features is None: + return None + elif hasattr(features, 'data'): + # 如果是某种包装对象,尝试获取 data 属性 + features = features.data + else: + # 无法处理,返回 None + return None + + # 确保 features 是 tensor + if not isinstance(features, torch.Tensor): + return None + + # 确保是 4D tensor: [B, C, H, W] 或 3D: [B, N, C] 或 2D: [B, C] + if len(features.shape) == 4: + # [B, C, H, W] -> Global average pooling + features = features.mean(dim=[2, 3]) # [B, C] + elif len(features.shape) == 3: + # [B, N, C] -> 取平均 or 取 CLS token + features = features.mean(dim=1) # [B, C] + elif len(features.shape) == 2: + # [B, C] -> 已经是我们需要的格式 + pass + else: + # 不支持的 shape + return None + + # L2 normalize + features = 
features / (features.norm(dim=1, keepdim=True) + 1e-8) + + return features.cpu().numpy()[0] + + except Exception as e: + import traceback + print(f" ⚠️ 特征提取失败: {type(e).__name__}: {e}") + print(f" 调用栈:\n{traceback.format_exc()}") # 显示完整堆栈 + return None + + +def compute_feature_similarity(feat1, feat2): + """计算特征相似度(余弦相似度)""" + if feat1 is None or feat2 is None: + return 0.0 + try: + return np.dot(feat1, feat2) + except: + return 0.0 + + +def compute_match_score(obj1, obj2, weights={'distance': 0.5, 'iou_3d': 0.25, 'iou_2d': 0.15, 'feature': 0.1}): + """计算综合匹配分数(0-1) + + 动态调整权重:如果某个准则不可用,将其权重重新分配给其他准则 + """ + scores = {} + available_criteria = [] + + # 1. 3D距离分数(距离越近,分数越高) + if obj1.get('center_3d') is not None and obj2.get('center_3d') is not None: + distance = np.linalg.norm(obj1['center_3d'] - obj2['center_3d']) + scores['distance'] = max(0, 1 - distance / MATCH_3D_DISTANCE_THRESHOLD) + available_criteria.append('distance') + else: + scores['distance'] = 0.0 + + # 2. 3D IoU分数 + if obj1.get('bbox_3d') is not None and obj2.get('bbox_3d') is not None: + scores['iou_3d'] = compute_3d_bbox_iou( + obj1['bbox_3d']['center'], obj1['bbox_3d']['size'], + obj2['bbox_3d']['center'], obj2['bbox_3d']['size'] + ) + available_criteria.append('iou_3d') + else: + scores['iou_3d'] = 0.0 + + # 3. 2D IoU分数 + if obj1.get('mask_2d') is not None and obj2.get('mask_2d') is not None: + scores['iou_2d'] = compute_2d_mask_iou(obj1['mask_2d'], obj2['mask_2d']) + available_criteria.append('iou_2d') + else: + scores['iou_2d'] = 0.0 + + # 4. 视觉特征相似度 + if obj1.get('visual_feature') is not None and obj2.get('visual_feature') is not None: + scores['feature'] = compute_feature_similarity(obj1['visual_feature'], obj2['visual_feature']) + available_criteria.append('feature') + else: + scores['feature'] = 0.0 + + # 动态调整权重:只使用可用的准则 + if len(available_criteria) == 0: + return 0.0, scores + + # 重新归一化权重 + total_available_weight = sum(weights[k] for k in available_criteria) + if total_available_weight == 0: + return 0.0, scores + + adjusted_weights = {k: weights[k] / total_available_weight for k in available_criteria} + + # 加权求和 + total_score = sum(scores[k] * adjusted_weights[k] for k in available_criteria) + + return total_score, scores + + +def compute_adaptive_eps(centers, base_eps): + """自适应计算eps值 + + 根据物体的3D位置分布自动调整eps: + - 如果物体很分散,增大eps(避免过度分割) + - 如果物体很集中,使用默认eps + """ + if len(centers) <= 1: + return base_eps + + # 计算所有点之间的距离 + from scipy.spatial.distance import pdist + distances = pdist(centers) + + if len(distances) == 0: + return base_eps + + # 使用中位数距离作为参考 + median_dist = np.median(distances) + + # 自适应策略:如果中位数距离很大,说明物体分散,增大eps + # 如果中位数距离很小,说明物体集中,保持或减小eps + if median_dist > base_eps * 2: + # 物体非常分散,大幅增大eps(可能是同一物体的多视图检测) + adaptive_eps = min(median_dist * 0.6, base_eps * 2.5) + elif median_dist > base_eps: + # 物体较分散,适度增大eps + adaptive_eps = median_dist * 0.5 + else: + # 物体集中,使用默认eps + adaptive_eps = base_eps + + return adaptive_eps + + +def match_objects_across_views(all_view_detections): + """跨视图匹配相同物体(V8增强版:自适应DBSCAN聚类) + + V8增强版改进: + - 自适应eps:根据物体分布自动调整聚类半径 + - 智能合并:聚类后再检查是否有明显重复的簇 + - 置信度加权:使用置信度加权计算簇中心 + + Args: + all_view_detections: List[List[Dict]], 每个视图的检测结果 + + Returns: + object_id_map: Dict[view_idx][det_idx] = global_object_id + unique_objects: List[Dict] - 唯一物体列表 + """ + print("\n🔗 V8增强版: 自适应DBSCAN聚类匹配物体...") + + # 收集所有检测,按标签分组 + objects_by_label = defaultdict(list) + + for view_idx, detections in enumerate(all_view_detections): + for det_idx, det in enumerate(detections): + # 只处理有3D中心的物体 + if 
det.get('center_3d') is None: + continue + + norm_label = normalize_label(det['label']) + objects_by_label[norm_label].append({ + 'view_idx': view_idx, + 'det_idx': det_idx, + 'label': det['label'], + 'norm_label': norm_label, + 'center_3d': det['center_3d'], + 'confidence': det['confidence'], + 'bbox_3d': det.get('bbox_3d'), + }) + + if len(objects_by_label) == 0: + return {}, [] + + # V8: 对每种物体类别分别进行DBSCAN聚类 + object_id_map = defaultdict(dict) + unique_objects = [] + next_global_id = 0 + + for norm_label, objects in objects_by_label.items(): + print(f"\n 📦 处理 {norm_label}: {len(objects)} 个检测") + + # 如果只有1个检测,直接作为1个物体 + if len(objects) == 1: + obj = objects[0] + unique_objects.append({ + 'global_id': next_global_id, + 'label': obj['label'], + 'views': [(obj['view_idx'], obj['det_idx'])], + 'center_3d': obj['center_3d'], + }) + object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id + next_global_id += 1 + print(f" → 1个簇(单独检测)") + continue + + # 提取3D中心点坐标 + centers = np.array([obj['center_3d'] for obj in objects]) + + # 获取该类型的基础聚类半径 + base_eps = DBSCAN_EPS_CONFIG.get(norm_label, DBSCAN_EPS_CONFIG.get('default', 1.0)) + + # 🔥 V8增强:自适应计算eps + eps = compute_adaptive_eps(centers, base_eps) + + # DBSCAN聚类 + clustering = DBSCAN(eps=eps, min_samples=DBSCAN_MIN_SAMPLES, metric='euclidean') + cluster_labels = clustering.fit_predict(centers) + + # 统计簇 + n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) + n_noise = list(cluster_labels).count(-1) + + if eps != base_eps: + print(f" → {n_clusters} 个簇 (基础eps={base_eps}m → 自适应eps={eps:.2f}m)") + else: + print(f" → {n_clusters} 个簇 (eps={eps}m)") + if n_noise > 0: + print(f" ⚠️ {n_noise} 个噪声点(孤立检测)") + + # 调试:显示每个簇的详细信息 + for cluster_id in sorted(set(cluster_labels)): + if cluster_id == -1: + continue + cluster_objs = [objects[i] for i, label in enumerate(cluster_labels) if label == cluster_id] + cluster_centers = [obj['center_3d'] for obj in cluster_objs] + cluster_views = [f"V{obj['view_idx']+1}" for obj in cluster_objs] + + # 计算簇内最大距离 + max_dist = 0 + if len(cluster_centers) > 1: + from scipy.spatial.distance import pdist + distances = pdist(np.array(cluster_centers)) + max_dist = distances.max() if len(distances) > 0 else 0 + + print(f" 簇 {cluster_id}: {len(cluster_objs)} 个检测 (来自视图: {', '.join(cluster_views)}, 最大簇内距离: {max_dist:.2f}m)") + + # 为每个簇创建一个全局物体 + cluster_to_global_id = {} + + for cluster_id in set(cluster_labels): + if cluster_id == -1: + # 噪声点,每个单独成为一个物体 + for i, label in enumerate(cluster_labels): + if label == -1: + obj = objects[i] + unique_objects.append({ + 'global_id': next_global_id, + 'label': obj['label'], + 'views': [(obj['view_idx'], obj['det_idx'])], + 'center_3d': obj['center_3d'], + }) + object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id + next_global_id += 1 + else: + # 正常簇 + cluster_objects = [objects[i] for i, label in enumerate(cluster_labels) if label == cluster_id] + + # 计算簇的中心(加权平均,权重为置信度) + total_conf = sum(o['confidence'] for o in cluster_objects) + weighted_center = sum(o['center_3d'] * o['confidence'] for o in cluster_objects) / total_conf + + # 创建全局物体 + unique_objects.append({ + 'global_id': next_global_id, + 'label': cluster_objects[0]['label'], + 'views': [(o['view_idx'], o['det_idx']) for o in cluster_objects], + 'center_3d': weighted_center, + }) + + # 映射所有检测到这个全局ID + for obj in cluster_objects: + object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id + + print(f" 簇 {cluster_id}: {len(cluster_objects)} 个检测合并") + + next_global_id += 1 + + print(f"\n 📊 总结:") + 
print(f" 总检测数: {sum(len(objs) for objs in objects_by_label.values())}") + print(f" 唯一物体: {len(unique_objects)}") + + # 打印匹配结果(按规范化标签统计) + label_counts = defaultdict(int) + original_labels = defaultdict(set) + for obj in unique_objects: + norm_label = normalize_label(obj['label']) + label_counts[norm_label] += 1 + original_labels[norm_label].add(obj['label']) + + print(f"\n 📊 物体类别统计(规范化后):") + for norm_label, count in sorted(label_counts.items()): + orig_labels = original_labels[norm_label] + if len(orig_labels) > 1: + print(f" {norm_label} (原标签: {', '.join(sorted(orig_labels))}): {count} 个") + else: + print(f" {norm_label}: {count} 个") + + return object_id_map, unique_objects + + +def create_multi_view_segmented_mesh(processed_data, all_view_detections, all_view_masks, + object_id_map, unique_objects, target_dir, use_sam=True): + """创建多视图融合的分割 mesh(使用 utils3d.image_mesh)""" + try: + print("\n🎨 生成多视图分割 mesh...") + + # 按物体类别(label)分配颜色,使用规范化标签避免组合标签问题 + # 获取所有不同的规范化类别 + unique_normalized_labels = sorted(set(normalize_label(obj['label']) for obj in unique_objects)) + label_colors = {} + colors = generate_distinct_colors(len(unique_normalized_labels)) + + # 为规范化标签分配颜色 + for i, norm_label in enumerate(unique_normalized_labels): + label_colors[norm_label] = colors[i] + + # 为每个唯一物体分配基于规范化类别的颜色 + for obj in unique_objects: + norm_label = normalize_label(obj['label']) + obj['color'] = label_colors[norm_label] + obj['normalized_label'] = norm_label # 保存规范化标签 + + # 打印类别-颜色映射(按规范化标签) + print(f" 物体类别颜色映射(规范化标签):") + for norm_label, color in sorted(label_colors.items()): + count = sum(1 for obj in unique_objects if normalize_label(obj['label']) == norm_label) + # 显示所有原始标签 + original_labels = set(obj['label'] for obj in unique_objects if normalize_label(obj['label']) == norm_label) + if len(original_labels) > 1: + print(f" {norm_label} (包含: {', '.join(sorted(original_labels))}) × {count} → RGB{color}") + else: + print(f" {norm_label} × {count} → RGB{color}") + + # 导入 utils3d + import utils3d + + all_meshes = [] + + # 为每个视图生成 mesh + for view_idx in range(len(processed_data)): + view_data = processed_data[view_idx] + image = view_data["image"] + points3d = view_data["points3d"] + mask = view_data.get("mask") + normal = view_data.get("normal") + + detections = all_view_detections[view_idx] + masks = all_view_masks[view_idx] + + if len(detections) == 0: + continue + + # 确保图像在 [0, 255] 范围 + if image.dtype != np.uint8: + if image.max() <= 1.0: + image = (image * 255).astype(np.uint8) + else: + image = image.astype(np.uint8) + + # 创建彩色图像(使用置信度优先策略避免颜色混乱) + colored_image = image.copy() + confidence_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32) # 记录每个像素的置信度 + + # 收集所有检测及其信息(应用质量过滤) + detections_info = [] + filtered_count = 0 + for det_idx, (det, seg_mask) in enumerate(zip(detections, masks)): + # 过滤低置信度检测 + if det['confidence'] < MIN_DETECTION_CONFIDENCE: + filtered_count += 1 + continue + + # 过滤过小的mask + mask_area = seg_mask.sum() + if mask_area < MIN_MASK_AREA: + filtered_count += 1 + continue + + global_id = object_id_map[view_idx].get(det_idx) + if global_id is None: + continue + + unique_obj = next((obj for obj in unique_objects if obj['global_id'] == global_id), None) + if unique_obj is None: + continue + + detections_info.append({ + 'mask': seg_mask, + 'color': unique_obj['color'], + 'confidence': det['confidence'], + 'label': det['label'], + 'area': mask_area + }) + + if filtered_count > 0: + print(f" 视图 {view_idx + 1}: 过滤了 {filtered_count} 个低质量检测") + + # 按置信度排序(从低到高),这样高置信度的会最后写入 + 
detections_info.sort(key=lambda x: x['confidence']) + + # 应用颜色(置信度高的优先) + for info in detections_info: + seg_mask = info['mask'] + color = info['color'] + conf = info['confidence'] + + # 只在当前置信度更高的地方覆盖 + update_mask = seg_mask & (conf > confidence_map) + colored_image[update_mask] = color + confidence_map[update_mask] = conf + + # 使用 utils3d.image_mesh 生成 mesh + height, width = image.shape[:2] + + if normal is None: + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points3d, + colored_image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=mask if mask is not None else np.ones((height, width), dtype=bool), + tri=True + ) + vertex_normals = None + else: + faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh( + points3d, + colored_image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + normal, + mask=mask if mask is not None else np.ones((height, width), dtype=bool), + tri=True + ) + + # 坐标变换 + vertices = vertices * np.array([1, -1, -1], dtype=np.float32) + if vertex_normals is not None: + vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32) + + # 创建 mesh + view_mesh = trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_normals=vertex_normals, + vertex_colors=(vertex_colors * 255).astype(np.uint8), + process=False + ) + + all_meshes.append(view_mesh) + print(f" 视图 {view_idx + 1}: {len(vertices):,} 顶点, {len(faces):,} 面") + + if len(all_meshes) == 0: + print("⚠️ 未生成任何 mesh") + return None + + # 融合所有 mesh + print(" 融合所有视图...") + combined_mesh = trimesh.util.concatenate(all_meshes) + + # 保存 + glb_path = os.path.join(target_dir, 'multi_view_segmented_mesh.glb') + combined_mesh.export(glb_path) + + print(f"✅ 多视图分割 mesh 已保存: {glb_path}") + print(f" 总计: {len(combined_mesh.vertices):,} 顶点, {len(combined_mesh.faces):,} 面") + print(f" {len(unique_objects)} 个唯一物体") + + return glb_path + + except Exception as e: + print(f"❌ 生成多视图 mesh 失败: {e}") + import traceback + traceback.print_exc() + return None + + +def create_segmented_pointcloud(processed_data, detections, masks, target_dir, use_sam=True): + """创建分割点云(单视图,仅用于兼容)""" + if len(detections) == 0: + return None + + try: + print(f"🎨 生成分割点云...") + + # 使用第一个视图 + first_view = processed_data[0] + image = first_view["image"] + points3d = first_view["points3d"] + normal = first_view.get("normal") + mask = first_view.get("mask") + + # 确保图像在 [0, 255] 范围 + if image.dtype != np.uint8: + if image.max() <= 1.0: + image = (image * 255).astype(np.uint8) + else: + image = image.astype(np.uint8) + + # 生成颜色 + distinct_colors = generate_distinct_colors(len(detections)) + + # 创建彩色图像 + colored_image = image.copy() + + for i, (det, seg_mask) in enumerate(zip(detections, masks)): + color = distinct_colors[i] + colored_image[seg_mask] = color + print(f" {det['label']} → RGB{color}") + + # 生成点云(使用 MapAnything 的方法) + height, width = image.shape[:2] + + # 简单方法:直接从 points3d 生成顶点颜色 + vertices = points3d.reshape(-1, 3) + colors = (colored_image.astype(np.float32) / 255.0).reshape(-1, 3) + + if mask is not None: + valid_mask = mask.reshape(-1) + vertices = vertices[valid_mask] + colors = colors[valid_mask] + + # 坐标变换 + vertices = vertices * np.array([1, -1, -1], dtype=np.float32) + + # 创建点云 + pointcloud = trimesh.PointCloud( + vertices=vertices, + colors=(colors * 255).astype(np.uint8) + ) + + # 保存 + seg_glb_path = os.path.join(target_dir, 'segmented_pointcloud.glb') + pointcloud.export(seg_glb_path) + + print(f"✅ 分割点云已保存: 
{seg_glb_path}") + return seg_glb_path + + except Exception as e: + print(f"❌ 生成分割点云失败: {e}") + import traceback + traceback.print_exc() + return None + + +# ============================================================================ +# 核心模型推理 +# ============================================================================ + +@spaces.GPU(duration=120) +def run_model( + target_dir, + apply_mask=True, + mask_edges=True, + filter_black_bg=False, + filter_white_bg=False, + enable_segmentation=False, + text_prompt=DEFAULT_TEXT_PROMPT, + use_sam=True, +): + """ + Run the MapAnything model + GroundingDINO + SAM segmentation + """ + global model, grounding_dino_model, sam_predictor + import torch + + print(f"处理图像: {target_dir}") + + # 设备检查 + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # 初始化 MapAnything 模型 - 从 HuggingFace + if model is None: + print("📥 从 HuggingFace 加载 MapAnything...") + model = initialize_mapanything_model(high_level_config, device) + print("✅ MapAnything 加载成功") + else: + model = model.to(device) + + model.eval() + + # 加载分割模型 + if enable_segmentation: + load_grounding_dino_model(device) + if use_sam: + load_sam_model(device) + + # 加载图像 + print("加载图像...") + image_folder_path = os.path.join(target_dir, "images") + views = load_images(image_folder_path) + + print(f"加载了 {len(views)} 张图像") + if len(views) == 0: + raise ValueError("未找到图像") + + # 运行 MapAnything 推理 + print("运行 3D 重建...") + outputs = model.infer( + views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False + ) + + # 转换预测结果 + predictions = {} + extrinsic_list = [] + intrinsic_list = [] + world_points_list = [] + depth_maps_list = [] + images_list = [] + final_mask_list = [] + confidences = [] + + for pred in outputs: + depthmap_torch = pred["depth_z"][0].squeeze(-1) + intrinsics_torch = pred["intrinsics"][0] + camera_pose_torch = pred["camera_poses"][0] + conf = pred["conf"][0].squeeze(-1) + + pts3d_computed, valid_mask = depthmap_to_world_frame( + depthmap_torch, intrinsics_torch, camera_pose_torch + ) + + if "mask" in pred: + mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool) + else: + mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool) + + mask = mask & valid_mask.cpu().numpy() + image = pred["img_no_norm"][0].cpu().numpy() + + extrinsic_list.append(camera_pose_torch.cpu().numpy()) + intrinsic_list.append(intrinsics_torch.cpu().numpy()) + world_points_list.append(pts3d_computed.cpu().numpy()) + depth_maps_list.append(depthmap_torch.cpu().numpy()) + images_list.append(image) + final_mask_list.append(mask) + confidences.append(conf.cpu().numpy()) + + predictions["extrinsic"] = np.stack(extrinsic_list, axis=0) + predictions["intrinsic"] = np.stack(intrinsic_list, axis=0) + predictions["world_points"] = np.stack(world_points_list, axis=0) + predictions["conf"] = np.stack(confidences, axis=0) + + depth_maps = np.stack(depth_maps_list, axis=0) + if len(depth_maps.shape) == 3: + depth_maps = depth_maps[..., np.newaxis] + predictions["depth"] = depth_maps + + predictions["images"] = np.stack(images_list, axis=0) + predictions["final_mask"] = np.stack(final_mask_list, axis=0) + + # 处理可视化数据 + processed_data = process_predictions_for_visualization( + predictions, views, high_level_config, filter_black_bg, filter_white_bg + ) + + # 多视图分割处理 + segmented_glb = None + if enable_segmentation and grounding_dino_model is not None: + print("\n🎯 开始多视图分割...") + print(f"🔍 使用检测提示: {text_prompt[:100]}...") + + all_view_detections = [] + all_view_masks = [] 
+ + # 对每个视图进行分割 + for view_idx, ref_image in enumerate(images_list): + print(f"\n📸 处理视图 {view_idx + 1}/{len(images_list)}...") + + if ref_image.dtype != np.uint8: + ref_image_np = (ref_image * 255).astype(np.uint8) + else: + ref_image_np = ref_image + + # GroundingDINO 检测 + detections = run_grounding_dino_detection(ref_image_np, text_prompt, device) + + if len(detections) > 0: + # SAM 精确分割 + boxes = [d['bbox'] for d in detections] + masks = run_sam_refinement(ref_image_np, boxes) if use_sam else [] + + # 获取3D点云和encoder(用于特征提取) + points3d = world_points_list[view_idx] + encoder = model.encoder if hasattr(model, 'encoder') else None + + # V5: 为每个检测物体提取多种特征 + for det_idx, (det, mask) in enumerate(zip(detections, masks)): + # 1. 计算3D中心点 + center_3d = compute_object_3d_center(points3d, mask) + det['center_3d'] = center_3d + + # 2. 计算3D边界框 + if center_3d is not None: + masked_points = points3d[mask] + if len(masked_points) > 0: + bbox_min = masked_points.min(axis=0) + bbox_max = masked_points.max(axis=0) + bbox_size = bbox_max - bbox_min + det['bbox_3d'] = { + 'center': center_3d, + 'size': bbox_size, + 'min': bbox_min, + 'max': bbox_max + } + + # 3. 存储2D mask(用于IoU计算) + det['mask_2d'] = mask + + # 4. 提取视觉特征(DINOv2)- 可选 + if ENABLE_VISUAL_FEATURES and encoder is not None: + visual_feat = extract_visual_features(ref_image, mask, encoder) + det['visual_feature'] = visual_feat + else: + det['visual_feature'] = None + + all_view_detections.append(detections) + all_view_masks.append(masks) + else: + all_view_detections.append([]) + all_view_masks.append([]) + + # 跨视图匹配物体 + if any(len(dets) > 0 for dets in all_view_detections): + object_id_map, unique_objects = match_objects_across_views(all_view_detections) + + # 生成多视图分割 mesh + segmented_glb = create_multi_view_segmented_mesh( + processed_data, all_view_detections, all_view_masks, + object_id_map, unique_objects, target_dir, use_sam + ) + + # 清理 + torch.cuda.empty_cache() + + return predictions, processed_data, segmented_glb + + +# ============================================================================ +# 从 gradio_app.py 复制的其他函数 +# ============================================================================ + +def update_view_selectors(processed_data): + """Update view selector dropdowns based on available views""" + if processed_data is None or len(processed_data) == 0: + choices = ["View 1"] + else: + num_views = len(processed_data) + choices = [f"View {i + 1}" for i in range(num_views)] + + return ( + gr.Dropdown(choices=choices, value=choices[0]), + gr.Dropdown(choices=choices, value=choices[0]), + gr.Dropdown(choices=choices, value=choices[0]), + ) + + +def get_view_data_by_index(processed_data, view_index): + """Get view data by index, handling bounds""" + if processed_data is None or len(processed_data) == 0: + return None + + view_keys = list(processed_data.keys()) + if view_index < 0 or view_index >= len(view_keys): + view_index = 0 + + return processed_data[view_keys[view_index]] + + +def update_depth_view(processed_data, view_index): + """Update depth view for a specific view index""" + view_data = get_view_data_by_index(processed_data, view_index) + if view_data is None or view_data["depth"] is None: + return None + + return colorize_depth(view_data["depth"], mask=view_data.get("mask")) + + +def update_normal_view(processed_data, view_index): + """Update normal view for a specific view index""" + view_data = get_view_data_by_index(processed_data, view_index) + if view_data is None or view_data["normal"] is None: + return None + + 
return colorize_normal(view_data["normal"], mask=view_data.get("mask")) + + +def update_measure_view(processed_data, view_index): + """Update measure view for a specific view index with mask overlay""" + view_data = get_view_data_by_index(processed_data, view_index) + if view_data is None: + return None, [] + + image = view_data["image"].copy() + + if image.dtype != np.uint8: + if image.max() <= 1.0: + image = (image * 255).astype(np.uint8) + else: + image = image.astype(np.uint8) + + if view_data["mask"] is not None: + mask = view_data["mask"] + invalid_mask = ~mask + + if invalid_mask.any(): + overlay_color = np.array([255, 220, 220], dtype=np.uint8) + alpha = 0.5 + for c in range(3): + image[:, :, c] = np.where( + invalid_mask, + (1 - alpha) * image[:, :, c] + alpha * overlay_color[c], + image[:, :, c], + ).astype(np.uint8) + + return image, [] + + +def navigate_depth_view(processed_data, current_selector_value, direction): + """Navigate depth view""" + if processed_data is None or len(processed_data) == 0: + return "View 1", None + + try: + current_view = int(current_selector_value.split()[1]) - 1 + except: + current_view = 0 + + num_views = len(processed_data) + new_view = (current_view + direction) % num_views + + new_selector_value = f"View {new_view + 1}" + depth_vis = update_depth_view(processed_data, new_view) + + return new_selector_value, depth_vis + + +def navigate_normal_view(processed_data, current_selector_value, direction): + """Navigate normal view""" + if processed_data is None or len(processed_data) == 0: + return "View 1", None + + try: + current_view = int(current_selector_value.split()[1]) - 1 + except: + current_view = 0 + + num_views = len(processed_data) + new_view = (current_view + direction) % num_views + + new_selector_value = f"View {new_view + 1}" + normal_vis = update_normal_view(processed_data, new_view) + + return new_selector_value, normal_vis + + +def navigate_measure_view(processed_data, current_selector_value, direction): + """Navigate measure view""" + if processed_data is None or len(processed_data) == 0: + return "View 1", None, [] + + try: + current_view = int(current_selector_value.split()[1]) - 1 + except: + current_view = 0 + + num_views = len(processed_data) + new_view = (current_view + direction) % num_views + + new_selector_value = f"View {new_view + 1}" + measure_image, measure_points = update_measure_view(processed_data, new_view) + + return new_selector_value, measure_image, measure_points + + +def populate_visualization_tabs(processed_data): + """Populate the depth, normal, and measure tabs with processed data""" + if processed_data is None or len(processed_data) == 0: + return None, None, None, [] + + depth_vis = update_depth_view(processed_data, 0) + normal_vis = update_normal_view(processed_data, 0) + measure_img, _ = update_measure_view(processed_data, 0) + + return depth_vis, normal_vis, measure_img, [] + + +def handle_uploads(input_video, input_images, s_time_interval=1.0): + """Handle uploaded video/images""" + start_time = time.time() + gc.collect() + torch.cuda.empty_cache() + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + target_dir = f"input_images_{timestamp}" + target_dir_images = os.path.join(target_dir, "images") + + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + os.makedirs(target_dir) + os.makedirs(target_dir_images) + + image_paths = [] + + # Handle images + if input_images is not None: + for file_data in input_images: + if isinstance(file_data, dict) and "name" in file_data: + file_path = 
file_data["name"] + else: + file_path = file_data + + file_ext = os.path.splitext(file_path)[1].lower() + if file_ext in [".heic", ".heif"]: + try: + with Image.open(file_path) as img: + if img.mode not in ("RGB", "L"): + img = img.convert("RGB") + base_name = os.path.splitext(os.path.basename(file_path))[0] + dst_path = os.path.join(target_dir_images, f"{base_name}.jpg") + img.save(dst_path, "JPEG", quality=95) + image_paths.append(dst_path) + except Exception as e: + print(f"Error converting HEIC: {e}") + dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) + shutil.copy(file_path, dst_path) + image_paths.append(dst_path) + else: + dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) + shutil.copy(file_path, dst_path) + image_paths.append(dst_path) + + # Handle video + if input_video is not None: + if isinstance(input_video, dict) and "name" in input_video: + video_path = input_video["name"] + else: + video_path = input_video + + vs = cv2.VideoCapture(video_path) + fps = vs.get(cv2.CAP_PROP_FPS) + frame_interval = int(fps * s_time_interval) + + count = 0 + video_frame_num = 0 + while True: + gotit, frame = vs.read() + if not gotit: + break + count += 1 + if count % frame_interval == 0: + image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png") + cv2.imwrite(image_path, frame) + image_paths.append(image_path) + video_frame_num += 1 + + image_paths = sorted(image_paths) + + end_time = time.time() + print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds") + return target_dir, image_paths + + +def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0): + """Update gallery on upload""" + if not input_video and not input_images: + return None, None, None, None, None + target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval) + return ( + None, + None, + target_dir, + image_paths, + "上传完成,点击「重建」开始 3D 处理", + ) + + +@spaces.GPU(duration=120) +def gradio_demo( + target_dir, + frame_filter="All", + show_cam=True, + filter_black_bg=False, + filter_white_bg=False, + conf_thres=3.0, + apply_mask=True, + show_mesh=True, + enable_segmentation=False, + text_prompt=DEFAULT_TEXT_PROMPT, + use_sam=True, +): + """Perform reconstruction""" + if not os.path.isdir(target_dir) or target_dir == "None": + return None, None, "请先上传文件", None, None, None, None, None, None, None, None + + start_time = time.time() + gc.collect() + torch.cuda.empty_cache() + + target_dir_images = os.path.join(target_dir, "images") + all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else [] + all_files_display = [f"{i}: {filename}" for i, filename in enumerate(all_files)] + frame_filter_choices = ["All"] + all_files_display + + print("运行 MapAnything 模型...") + with torch.no_grad(): + predictions, processed_data, segmented_glb = run_model( + target_dir, apply_mask, True, filter_black_bg, filter_white_bg, + enable_segmentation, text_prompt, use_sam + ) + + # 保存预测结果 + prediction_save_path = os.path.join(target_dir, "predictions.npz") + np.savez(prediction_save_path, **predictions) + + if frame_filter is None: + frame_filter = "All" + + # 生成原始 GLB + glbfile = os.path.join( + target_dir, + f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}.glb", + ) + + glbscene = predictions_to_glb( + predictions, + filter_by_frames=frame_filter, + show_cam=show_cam, + mask_black_bg=filter_black_bg, + mask_white_bg=filter_white_bg, + 
as_mesh=show_mesh, + conf_percentile=conf_thres, + ) + glbscene.export(file_obj=glbfile) + + # 清理 + del predictions + gc.collect() + torch.cuda.empty_cache() + + end_time = time.time() + print(f"总耗时: {end_time - start_time:.2f}秒") + log_msg = f"✅ 重建成功 ({len(all_files)} 帧)" + + # 填充可视化标签 + depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(processed_data) + + # 更新视图选择器 + depth_selector, normal_selector, measure_selector = update_view_selectors(processed_data) + + return ( + glbfile, + segmented_glb, + log_msg, + gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True), + processed_data, + depth_vis, + normal_vis, + measure_img, + "", + depth_selector, + normal_selector, + measure_selector, + ) + + +def colorize_depth(depth_map, mask=None): + """Convert depth map to colorized visualization""" + if depth_map is None: + return None + + depth_normalized = depth_map.copy() + valid_mask = depth_normalized > 0 + + if mask is not None: + valid_mask = valid_mask & mask + + if valid_mask.sum() > 0: + valid_depths = depth_normalized[valid_mask] + p5 = np.percentile(valid_depths, 5) + p95 = np.percentile(valid_depths, 95) + depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5) + + import matplotlib.pyplot as plt + colormap = plt.cm.turbo_r + colored = colormap(depth_normalized) + colored = (colored[:, :, :3] * 255).astype(np.uint8) + colored[~valid_mask] = [255, 255, 255] + + return colored + + +def colorize_normal(normal_map, mask=None): + """Convert normal map to colorized visualization""" + if normal_map is None: + return None + + normal_vis = normal_map.copy() + + if mask is not None: + invalid_mask = ~mask + normal_vis[invalid_mask] = [0, 0, 0] + + normal_vis = (normal_vis + 1.0) / 2.0 + normal_vis = (normal_vis * 255).astype(np.uint8) + + return normal_vis + + +def process_predictions_for_visualization( + predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False +): + """Extract depth, normal, and 3D points from predictions for visualization""" + processed_data = {} + + for view_idx, view in enumerate(views): + image = rgb(view["img"], norm_type=high_level_config["data_norm_type"]) + pred_pts3d = predictions["world_points"][view_idx] + + view_data = { + "image": image[0], + "points3d": pred_pts3d, + "depth": None, + "normal": None, + "mask": None, + } + + mask = predictions["final_mask"][view_idx].copy() + + if filter_black_bg: + view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0] + black_bg_mask = view_colors.sum(axis=2) >= 16 + mask = mask & black_bg_mask + + if filter_white_bg: + view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0] + white_bg_mask = ~( + (view_colors[:, :, 0] > 240) + & (view_colors[:, :, 1] > 240) + & (view_colors[:, :, 2] > 240) + ) + mask = mask & white_bg_mask + + view_data["mask"] = mask + view_data["depth"] = predictions["depth"][view_idx].squeeze() + + normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"]) + view_data["normal"] = normals + + processed_data[view_idx] = view_data + + return processed_data + + +def reset_measure(processed_data): + """Reset measure points""" + if processed_data is None or len(processed_data) == 0: + return None, [], "" + first_view = list(processed_data.values())[0] + return first_view["image"], [], "" + + +def measure(processed_data, measure_points, current_view_selector, event: gr.SelectData): + """Handle measurement on images""" + try: + if processed_data is None or len(processed_data) == 0: + 
return None, [], "没有可用数据" + + try: + current_view_index = int(current_view_selector.split()[1]) - 1 + except: + current_view_index = 0 + + if current_view_index < 0 or current_view_index >= len(processed_data): + current_view_index = 0 + + view_keys = list(processed_data.keys()) + current_view = processed_data[view_keys[current_view_index]] + + if current_view is None: + return None, [], "没有视图数据" + + point2d = event.index[0], event.index[1] + + if ( + current_view["mask"] is not None + and 0 <= point2d[1] < current_view["mask"].shape[0] + and 0 <= point2d[0] < current_view["mask"].shape[1] + ): + if not current_view["mask"][point2d[1], point2d[0]]: + masked_image, _ = update_measure_view(processed_data, current_view_index) + return ( + masked_image, + measure_points, + '无法在遮罩区域测量(显示为灰色)', + ) + + measure_points.append(point2d) + + image, _ = update_measure_view(processed_data, current_view_index) + if image is None: + return None, [], "没有可用图像" + + image = image.copy() + points3d = current_view["points3d"] + + if image.dtype != np.uint8: + if image.max() <= 1.0: + image = (image * 255).astype(np.uint8) + else: + image = image.astype(np.uint8) + + for p in measure_points: + if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]: + image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2) + + depth_text = "" + for i, p in enumerate(measure_points): + if ( + current_view["depth"] is not None + and 0 <= p[1] < current_view["depth"].shape[0] + and 0 <= p[0] < current_view["depth"].shape[1] + ): + d = current_view["depth"][p[1], p[0]] + depth_text += f"- **P{i + 1} 深度: {d:.2f}m.**\n" + else: + if ( + points3d is not None + and 0 <= p[1] < points3d.shape[0] + and 0 <= p[0] < points3d.shape[1] + ): + z = points3d[p[1], p[0], 2] + depth_text += f"- **P{i + 1} Z坐标: {z:.2f}m.**\n" + + if len(measure_points) == 2: + point1, point2 = measure_points + if ( + 0 <= point1[0] < image.shape[1] + and 0 <= point1[1] < image.shape[0] + and 0 <= point2[0] < image.shape[1] + and 0 <= point2[1] < image.shape[0] + ): + image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2) + + distance_text = "- **距离: 无法计算**" + if ( + points3d is not None + and 0 <= point1[1] < points3d.shape[0] + and 0 <= point1[0] < points3d.shape[1] + and 0 <= point2[1] < points3d.shape[0] + and 0 <= point2[0] < points3d.shape[1] + ): + try: + p1_3d = points3d[point1[1], point1[0]] + p2_3d = points3d[point2[1], point2[0]] + distance = np.linalg.norm(p1_3d - p2_3d) + distance_text = f"- **距离: {distance:.2f}m**" + except Exception as e: + distance_text = f"- **距离计算错误: {e}**" + + measure_points = [] + text = depth_text + distance_text + return [image, measure_points, text] + else: + return [image, measure_points, depth_text] + + except Exception as e: + print(f"测量错误: {e}") + return None, [], f"测量错误: {e}" + + +def clear_fields(): + """Clear 3D viewer""" + return None, None + + +def update_log(): + """Display log message""" + return "加载和重建中..." 
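process_predictions_for_visualization (above) derives the normal maps shown in the 法线图 tab from the predicted point maps via points_to_normals, which is imported from MapAnything's utilities and not shown here. As a rough, hedged illustration of what such a helper computes, the sketch below estimates per-pixel normals from an H×W×3 point map with central differences and a cross product; the real implementation may differ in smoothing, border handling and sign convention.

    import numpy as np

    def normals_from_point_map(points, mask=None):
        """points: (H, W, 3) 3D points per pixel. Returns (H, W, 3) unit normals."""
        dx = np.zeros_like(points)
        dy = np.zeros_like(points)
        dx[:, 1:-1] = points[:, 2:] - points[:, :-2]   # horizontal neighbour difference
        dy[1:-1, :] = points[2:, :] - points[:-2, :]   # vertical neighbour difference
        n = np.cross(dx, dy)                           # normal ~ tangent_x x tangent_y
        norm = np.linalg.norm(n, axis=-1, keepdims=True)
        n = n / np.clip(norm, 1e-8, None)              # normalize, avoid division by zero
        if mask is not None:
            n[~mask] = 0.0                             # zero out invalid pixels
        return n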
+ + +def update_visualization( + target_dir, + frame_filter, + show_cam, + is_example, + conf_thres=None, + filter_black_bg=False, + filter_white_bg=False, + show_mesh=True, +): + """Update visualization""" + if is_example == "True": + return gr.update(), "没有可用的重建。请先点击重建按钮。" + + if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): + return gr.update(), "没有可用的重建。请先点击重建按钮。" + + predictions_path = os.path.join(target_dir, "predictions.npz") + if not os.path.exists(predictions_path): + return gr.update(), f"没有可用的重建。请先运行「重建」。" + + loaded = np.load(predictions_path, allow_pickle=True) + predictions = {key: loaded[key] for key in loaded.keys()} + + glbfile = os.path.join( + target_dir, + f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb", + ) + + glbscene = predictions_to_glb( + predictions, + filter_by_frames=frame_filter, + show_cam=show_cam, + mask_black_bg=filter_black_bg, + mask_white_bg=filter_white_bg, + as_mesh=show_mesh, + conf_percentile=conf_thres, + ) + glbscene.export(file_obj=glbfile) + + return glbfile, "可视化已更新。" + + +def update_all_views_on_filter_change( + target_dir, + filter_black_bg, + filter_white_bg, + processed_data, + depth_view_selector, + normal_view_selector, + measure_view_selector, +): + """Update all views on filter change""" + if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): + return processed_data, None, None, None, [] + + predictions_path = os.path.join(target_dir, "predictions.npz") + if not os.path.exists(predictions_path): + return processed_data, None, None, None, [] + + try: + loaded = np.load(predictions_path, allow_pickle=True) + predictions = {key: loaded[key] for key in loaded.keys()} + + image_folder_path = os.path.join(target_dir, "images") + views = load_images(image_folder_path) + + new_processed_data = process_predictions_for_visualization( + predictions, views, high_level_config, filter_black_bg, filter_white_bg + ) + + try: + depth_view_idx = int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0 + except: + depth_view_idx = 0 + + try: + normal_view_idx = int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0 + except: + normal_view_idx = 0 + + try: + measure_view_idx = int(measure_view_selector.split()[1]) - 1 if measure_view_selector else 0 + except: + measure_view_idx = 0 + + depth_vis = update_depth_view(new_processed_data, depth_view_idx) + normal_vis = update_normal_view(new_processed_data, normal_view_idx) + measure_img, _ = update_measure_view(new_processed_data, measure_view_idx) + + return new_processed_data, depth_vis, normal_vis, measure_img, [] + + except Exception as e: + print(f"更新视图失败: {e}") + return processed_data, None, None, None, [] + + +# ============================================================================ +# 示例场景 +# ============================================================================ + +def get_scene_info(examples_dir): + """Get information about scenes in the examples directory""" + import glob + + scenes = [] + if not os.path.exists(examples_dir): + return scenes + + for scene_folder in sorted(os.listdir(examples_dir)): + scene_path = os.path.join(examples_dir, scene_folder) + if os.path.isdir(scene_path): + image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"] + image_files = [] + for ext in image_extensions: + image_files.extend(glob.glob(os.path.join(scene_path, ext))) + 
image_files.extend(glob.glob(os.path.join(scene_path, ext.upper()))) + + if image_files: + image_files = sorted(image_files) + first_image = image_files[0] + num_images = len(image_files) + + scenes.append( + { + "name": scene_folder, + "path": scene_path, + "thumbnail": first_image, + "num_images": num_images, + "image_files": image_files, + } + ) + + return scenes + + +def load_example_scene(scene_name, examples_dir="examples"): + """Load a scene from examples directory""" + scenes = get_scene_info(examples_dir) + + selected_scene = None + for scene in scenes: + if scene["name"] == scene_name: + selected_scene = scene + break + + if selected_scene is None: + return None, None, None, None, "场景未找到" + + target_dir, image_paths = handle_uploads(None, selected_scene["image_files"]) + + return ( + None, + None, + target_dir, + image_paths, + f"已加载场景 '{scene_name}' ({selected_scene['num_images']} 张图像)。点击「重建」开始 3D 处理。", + ) + + +# ============================================================================ +# Gradio UI +# ============================================================================ + +theme = get_gradio_theme() + +# 自定义CSS防止UI抖动 +CUSTOM_CSS = GRADIO_CSS + """ +/* 防止组件撑开布局 */ +.gradio-container { + max-width: 100% !important; +} + +/* 固定Gallery高度 */ +.gallery-container { + max-height: 350px !important; + overflow-y: auto !important; +} + +/* 固定File组件高度 */ +.file-preview { + max-height: 200px !important; + overflow-y: auto !important; +} + +/* 固定Video组件高度 */ +.video-container { + max-height: 300px !important; +} + +/* 防止Textbox无限扩展 */ +.textbox-container { + max-height: 100px !important; +} + +/* 保持Tabs内容区域稳定 */ +.tab-content { + min-height: 550px !important; +} +""" + +with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V8 - 3D重建与物体分割") as demo: + is_example = gr.Textbox(label="is_example", visible=False, value="None") + processed_data_state = gr.State(value=None) + measure_points_state = gr.State(value=[]) + + # 顶部标题 + gr.HTML(""" +
+    <div style="text-align: center;">
+        <h1>MapAnything V8 - 3D重建与物体分割</h1>
+        <p>基于DBSCAN聚类的智能物体识别 | 多视图融合 | 自适应参数调整</p>
+    </div>
+ """) + + target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None") + + with gr.Row(equal_height=False): + # 左侧:输入区域 + with gr.Column(scale=1, min_width=300): + gr.Markdown("### 📤 输入") + + with gr.Tabs(): + with gr.Tab("📷 图片"): + input_images = gr.File( + file_count="multiple", + label="上传多张图片(推荐3-10张)", + interactive=True, + height=200 + ) + + with gr.Tab("🎥 视频"): + input_video = gr.Video( + label="上传视频", + interactive=True, + height=300 + ) + s_time_interval = gr.Slider( + minimum=0.1, maximum=5.0, value=1.0, step=0.1, + label="帧采样间隔(秒)", interactive=True + ) + + image_gallery = gr.Gallery( + label="图片预览", columns=3, height=350, + show_download_button=True, object_fit="contain", preview=True + ) + + with gr.Row(): + submit_btn = gr.Button("🚀 开始重建", variant="primary", scale=2) + clear_btn = gr.ClearButton( + [input_video, input_images, target_dir_output, image_gallery], + value="🗑️ 清空", scale=1 + ) + + # 右侧:输出区域 + with gr.Column(scale=2, min_width=600): + gr.Markdown("### 🎯 输出") + + with gr.Tabs(): + with gr.Tab("🏗️ 原始3D"): + reconstruction_output = gr.Model3D( + height=550, zoom_speed=0.5, pan_speed=0.5, + clear_color=[0.0, 0.0, 0.0, 0.0] + ) + + with gr.Tab("🎨 分割3D"): + segmented_output = gr.Model3D( + height=550, zoom_speed=0.5, pan_speed=0.5, + clear_color=[0.0, 0.0, 0.0, 0.0] + ) + + with gr.Tab("📊 深度图"): + with gr.Row(elem_classes=["navigation-row"]): + prev_depth_btn = gr.Button("◀", size="sm", scale=1) + depth_view_selector = gr.Dropdown( + choices=["View 1"], value="View 1", + label="视图", scale=3, interactive=True + ) + next_depth_btn = gr.Button("▶", size="sm", scale=1) + depth_map = gr.Image( + type="numpy", label="", format="png", interactive=False, + height=500 + ) + + with gr.Tab("🧭 法线图"): + with gr.Row(elem_classes=["navigation-row"]): + prev_normal_btn = gr.Button("◀", size="sm", scale=1) + normal_view_selector = gr.Dropdown( + choices=["View 1"], value="View 1", + label="视图", scale=3, interactive=True + ) + next_normal_btn = gr.Button("▶", size="sm", scale=1) + normal_map = gr.Image( + type="numpy", label="", format="png", interactive=False, + height=500 + ) + + with gr.Tab("📏 测量"): + gr.Markdown("**点击图片两次进行距离测量**") + with gr.Row(elem_classes=["navigation-row"]): + prev_measure_btn = gr.Button("◀", size="sm", scale=1) + measure_view_selector = gr.Dropdown( + choices=["View 1"], value="View 1", + label="视图", scale=3, interactive=True + ) + next_measure_btn = gr.Button("▶", size="sm", scale=1) + measure_image = gr.Image( + type="numpy", show_label=False, + format="webp", interactive=False, sources=[], + height=500 + ) + measure_text = gr.Markdown("") + + log_output = gr.Textbox( + value="📌 请上传图片或视频,然后点击「开始重建」", + label="状态信息", + interactive=False, + lines=1, + max_lines=1 + ) + + # 高级选项(可折叠) + with gr.Accordion("⚙️ 高级选项", open=False): + with gr.Row(equal_height=False): + with gr.Column(scale=1, min_width=300): + gr.Markdown("#### 可视化参数") + frame_filter = gr.Dropdown( + choices=["All"], value="All", label="显示帧" + ) + conf_thres = gr.Slider( + minimum=0, maximum=100, value=0, step=0.1, + label="置信度阈值(百分位)" + ) + show_cam = gr.Checkbox(label="显示相机", value=True) + show_mesh = gr.Checkbox(label="显示网格", value=True) + filter_black_bg = gr.Checkbox(label="过滤黑色背景", value=False) + filter_white_bg = gr.Checkbox(label="过滤白色背景", value=False) + + with gr.Column(scale=1, min_width=300): + gr.Markdown("#### 重建参数") + apply_mask_checkbox = gr.Checkbox( + label="应用深度掩码", value=True + ) + + gr.Markdown("#### 分割参数") + enable_segmentation = gr.Checkbox( + label="启用语义分割", 
value=False + ) + use_sam_checkbox = gr.Checkbox( + label="使用SAM精确分割", value=True + ) + + text_prompt = gr.Textbox( + value=DEFAULT_TEXT_PROMPT, + label="检测物体(用 . 分隔)", + placeholder="例如: chair . table . sofa", + lines=2, + max_lines=2 + ) + + with gr.Row(): + detect_all_btn = gr.Button("🔍 检测所有", size="sm") + restore_default_btn = gr.Button("↻ 默认", size="sm") + + # 示例场景(可折叠) + with gr.Accordion("🖼️ 示例场景", open=False): + scenes = get_scene_info("examples") + if scenes: + for i in range(0, len(scenes), 4): + with gr.Row(equal_height=True): + for j in range(4): + scene_idx = i + j + if scene_idx < len(scenes): + scene = scenes[scene_idx] + with gr.Column(scale=1, min_width=150): + scene_img = gr.Image( + value=scene["thumbnail"], + height=150, + interactive=False, + show_label=False, + sources=[], + container=False + ) + gr.Markdown( + f"**{scene['name']}** ({scene['num_images']}张)", + elem_classes=["text-center"] + ) + scene_img.select( + fn=lambda name=scene["name"]: load_example_scene(name), + outputs=[ + reconstruction_output, segmented_output, + target_dir_output, image_gallery, log_output + ] + ) + + # === 事件绑定 === + + # 分割选项按钮 + detect_all_btn.click( + fn=lambda: COMMON_OBJECTS_PROMPT, + outputs=[text_prompt] + ) + restore_default_btn.click( + fn=lambda: DEFAULT_TEXT_PROMPT, + outputs=[text_prompt] + ) + + # 上传文件自动更新 + input_video.change( + fn=update_gallery_on_upload, + inputs=[input_video, input_images, s_time_interval], + outputs=[reconstruction_output, segmented_output, target_dir_output, image_gallery, log_output] + ) + input_images.change( + fn=update_gallery_on_upload, + inputs=[input_video, input_images, s_time_interval], + outputs=[reconstruction_output, segmented_output, target_dir_output, image_gallery, log_output] + ) + + # 重建按钮 + submit_btn.click( + fn=clear_fields, + outputs=[reconstruction_output, segmented_output] + ).then( + fn=update_log, + outputs=[log_output] + ).then( + fn=gradio_demo, + inputs=[ + target_dir_output, frame_filter, show_cam, + filter_black_bg, filter_white_bg, conf_thres, + apply_mask_checkbox, show_mesh, + enable_segmentation, text_prompt, use_sam_checkbox + ], + outputs=[ + reconstruction_output, segmented_output, log_output, frame_filter, + processed_data_state, depth_map, normal_map, measure_image, + measure_text, depth_view_selector, normal_view_selector, measure_view_selector + ] + ).then( + fn=lambda: "False", + outputs=[is_example] + ) + + # 清空按钮 + clear_btn.add([reconstruction_output, segmented_output, log_output]) + + # 可视化参数实时更新 + for component in [frame_filter, show_cam, conf_thres, show_mesh]: + component.change( + fn=update_visualization, + inputs=[ + target_dir_output, frame_filter, show_cam, is_example, + conf_thres, filter_black_bg, filter_white_bg, show_mesh + ], + outputs=[reconstruction_output, log_output] + ) + + # 背景过滤器更新所有视图 + for bg_filter in [filter_black_bg, filter_white_bg]: + bg_filter.change( + fn=update_all_views_on_filter_change, + inputs=[ + target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, + depth_view_selector, normal_view_selector, measure_view_selector + ], + outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state] + ) + + # 深度图导航 + prev_depth_btn.click( + fn=lambda pd, cs: navigate_depth_view(pd, cs, -1), + inputs=[processed_data_state, depth_view_selector], + outputs=[depth_view_selector, depth_map] + ) + next_depth_btn.click( + fn=lambda pd, cs: navigate_depth_view(pd, cs, 1), + inputs=[processed_data_state, depth_view_selector], + 
outputs=[depth_view_selector, depth_map] + ) + depth_view_selector.change( + fn=lambda pd, sv: update_depth_view(pd, int(sv.split()[1]) - 1) if sv else None, + inputs=[processed_data_state, depth_view_selector], + outputs=[depth_map] + ) + + # 法线图导航 + prev_normal_btn.click( + fn=lambda pd, cs: navigate_normal_view(pd, cs, -1), + inputs=[processed_data_state, normal_view_selector], + outputs=[normal_view_selector, normal_map] + ) + next_normal_btn.click( + fn=lambda pd, cs: navigate_normal_view(pd, cs, 1), + inputs=[processed_data_state, normal_view_selector], + outputs=[normal_view_selector, normal_map] + ) + normal_view_selector.change( + fn=lambda pd, sv: update_normal_view(pd, int(sv.split()[1]) - 1) if sv else None, + inputs=[processed_data_state, normal_view_selector], + outputs=[normal_map] + ) + + # 测量功能 + measure_image.select( + fn=measure, + inputs=[processed_data_state, measure_points_state, measure_view_selector], + outputs=[measure_image, measure_points_state, measure_text] + ) + prev_measure_btn.click( + fn=lambda pd, cs: navigate_measure_view(pd, cs, -1), + inputs=[processed_data_state, measure_view_selector], + outputs=[measure_view_selector, measure_image, measure_points_state] + ) + next_measure_btn.click( + fn=lambda pd, cs: navigate_measure_view(pd, cs, 1), + inputs=[processed_data_state, measure_view_selector], + outputs=[measure_view_selector, measure_image, measure_points_state] + ) + measure_view_selector.change( + fn=lambda pd, sv: update_measure_view(pd, int(sv.split()[1]) - 1) if sv else (None, []), + inputs=[processed_data_state, measure_view_selector], + outputs=[measure_image, measure_points_state] + ) + +# 启动信息 +print("\n" + "="*60) +print("🚀 MapAnything V8 - 3D重建与物体分割") +print("="*60) +print("📊 核心技术: 自适应DBSCAN聚类 + 多视图融合") +print(f"🔧 质量控制: 置信度≥{MIN_DETECTION_CONFIDENCE} | 面积≥{MIN_MASK_AREA}px") +print(f"🎯 聚类半径: 沙发{DBSCAN_EPS_CONFIG['sofa']}m | 桌子{DBSCAN_EPS_CONFIG['table']}m | 窗户{DBSCAN_EPS_CONFIG['window']}m | 默认{DBSCAN_EPS_CONFIG['default']}m") +print("="*60 + "\n") + +demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False) +
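The conf_thres slider is labelled as a percentile and is forwarded to predictions_to_glb as conf_percentile. predictions_to_glb itself is part of the shared visualization utilities and is not shown in this diff; the snippet below is only an assumption of what percentile-based confidence filtering typically looks like, with illustrative names, where points whose confidence falls below the chosen percentile are dropped before the GLB is assembled.

    import numpy as np

    def filter_by_conf_percentile(points, colors, conf, conf_percentile=3.0):
        """Keep points whose confidence is at or above the given percentile.
        points: (N, 3), colors: (N, 3), conf: (N,). Illustrative sketch only."""
        if conf_percentile is None or conf_percentile <= 0:
            return points, colors
        thresh = np.percentile(conf, conf_percentile)  # e.g. drop the lowest 3% by confidence
        keep = conf >= thresh
        return points[keep], colors[keep]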