diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2b68f5a1c2ce0c929c1f01e22e46fc267ea9a3e5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,57 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Virtual Environment
+venv/
+ENV/
+env/
+.venv
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+.DS_Store
+
+# HuggingFace Space temporary files
+input_images_*/
+*.glb
+*.npz
+flagged/
+
+# Local model cache (switched to HuggingFace)
+models/
+
+# Logs
+*.log
+logs/
+
+# Test files
+.pytest_cache/
+.coverage
+htmlcov/
+
+# System files
+Thumbs.db
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0f6e3922db8931beb4f76a7fbfe22677a86a8db
--- /dev/null
+++ b/app.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+HuggingFace Space entry point
+Imports and runs gradio_app_v8 directly
+"""
+
+import sys
+from pathlib import Path
+
+# Add the scripts directory to the Python path
+scripts_dir = Path(__file__).parent / "scripts"
+sys.path.insert(0, str(scripts_dir))
+
+# Import and run the main application
+if __name__ == "__main__":
+    # Importing gradio_app_v8 starts the demo automatically
+    import gradio_app_v8
+
diff --git a/mapanything/__init__.py b/mapanything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/__pycache__/__init__.cpython-312.pyc b/mapanything/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b4dfaaff3ef0e1b4c1a7b19336ecd940aaa5311
Binary files /dev/null and b/mapanything/__pycache__/__init__.cpython-312.pyc differ
diff --git a/mapanything/datasets/__init__.py b/mapanything/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f03559a7104487f88ee1c9adef8aafbca34a75dd
--- /dev/null
+++ b/mapanything/datasets/__init__.py
@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+"""
+MapAnything Datasets
+"""
+
+import torch
+
+from mapanything.datasets.wai.ase import ASEWAI  # noqa
+from mapanything.datasets.wai.blendedmvs import BlendedMVSWAI  # noqa
+from mapanything.datasets.wai.dl3dv import DL3DVWAI  # noqa
+from mapanything.datasets.wai.dynamicreplica import DynamicReplicaWAI  # noqa
+from mapanything.datasets.wai.eth3d import ETH3DWAI  # noqa
+from mapanything.datasets.wai.megadepth import MegaDepthWAI  # noqa
+from mapanything.datasets.wai.mpsd import MPSDWAI  # noqa
+from mapanything.datasets.wai.mvs_synth import MVSSynthWAI  # noqa
+from mapanything.datasets.wai.paralleldomain4d import ParallelDomain4DWAI  # noqa
+from mapanything.datasets.wai.sailvos3d import SAILVOS3DWAI  # noqa
+from mapanything.datasets.wai.scannetpp import ScanNetPPWAI  # noqa
+from mapanything.datasets.wai.spring import SpringWAI  # noqa
+from mapanything.datasets.wai.tav2_wb import TartanAirV2WBWAI  # noqa
+from mapanything.datasets.wai.unrealstereo4k import UnrealStereo4KWAI  # noqa
+from mapanything.utils.train_tools import get_rank, get_world_size
+
+
+def get_test_data_loader(
+    dataset, batch_size, num_workers=8, shuffle=False, drop_last=False, pin_mem=True
+):
+    "Get simple PyTorch dataloader corresponding to the testing dataset"
+    # PyTorch dataset
+    if isinstance(dataset, str):
+        dataset = eval(dataset)
+
+    world_size = get_world_size()
+    rank = get_rank()
+
+    if torch.distributed.is_initialized():
+        sampler = torch.utils.data.DistributedSampler(
+            dataset,
+            num_replicas=world_size,
+            rank=rank,
+            shuffle=shuffle,
+            drop_last=drop_last,
+        )
+    elif shuffle:
+        sampler = torch.utils.data.RandomSampler(dataset)
+    else:
+        sampler = torch.utils.data.SequentialSampler(dataset)
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        sampler=sampler,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=pin_mem,
+        drop_last=drop_last,
+    )
+
+    return data_loader
+
+
+def get_test_many_ar_data_loader(
+    dataset, batch_size, num_workers=8, drop_last=False, pin_mem=True
+):
+    "Get PyTorch dataloader corresponding to the testing dataset that supports many aspect ratios"
+    # PyTorch dataset
+    if isinstance(dataset, str):
+        dataset = eval(dataset)
+
+    world_size = get_world_size()
+    rank = get_rank()
+
+    # Get BatchedMultiFeatureRandomSampler
+    sampler = dataset.make_sampler(
+        batch_size,
+        shuffle=True,
+        world_size=world_size,
+        rank=rank,
+        drop_last=drop_last,
+        use_dynamic_sampler=False,
+    )
+
+    # Init the data loader
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        sampler=sampler,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=pin_mem,
+        drop_last=drop_last,
+    )
+
+    return data_loader
+
+
+class DynamicBatchDatasetWrapper:
+    """
+    Wrapper dataset that handles DynamicBatchedMultiFeatureRandomSampler output.
+
+    The dynamic sampler returns batches (lists of tuples) instead of individual samples.
+    This wrapper ensures that the underlying dataset's __getitem__ method gets called
+    with individual tuples as expected.
+    """
+
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __getitem__(self, batch_indices):
+        """
+        Handle batch of indices from DynamicBatchedMultiFeatureRandomSampler.
+
+        Args:
+            batch_indices: List of tuples like [(sample_idx, feat_idx_1, feat_idx_2, ...), ...]
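+                For example, [(3, 0, 1), (7, 0, 1)] requests samples 3 and 7 with the
+                same shared feature indices (0, 1).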
+ + Returns: + List of samples from the underlying dataset + """ + if isinstance(batch_indices, (list, tuple)) and len(batch_indices) > 0: + # If it's a batch (list of tuples), process each item + if isinstance(batch_indices[0], (list, tuple)): + return [self.dataset[idx] for idx in batch_indices] + else: + # Single tuple, call dataset directly + return self.dataset[batch_indices] + else: + # Fallback for single index + return self.dataset[batch_indices] + + def __len__(self): + return len(self.dataset) + + def __getattr__(self, name): + # Delegate all other attributes to the wrapped dataset + return getattr(self.dataset, name) + + +def get_train_data_loader( + dataset, + max_num_of_imgs_per_gpu, + num_workers=8, + shuffle=True, + drop_last=True, + pin_mem=True, +): + "Dynamic PyTorch dataloader corresponding to the training dataset" + # PyTorch dataset + if isinstance(dataset, str): + dataset = eval(dataset) + + world_size = get_world_size() + rank = get_rank() + + # Get DynamicBatchedMultiFeatureRandomSampler + batch_sampler = dataset.make_sampler( + shuffle=shuffle, + world_size=world_size, + rank=rank, + drop_last=drop_last, + max_num_of_images_per_gpu=max_num_of_imgs_per_gpu, + use_dynamic_sampler=True, + ) + + # Wrap the dataset to handle batch format from dynamic sampler + wrapped_dataset = DynamicBatchDatasetWrapper(dataset) + + # Init the dynamic data loader + data_loader = torch.utils.data.DataLoader( + wrapped_dataset, + batch_sampler=batch_sampler, + num_workers=num_workers, + pin_memory=pin_mem, + ) + + return data_loader diff --git a/mapanything/datasets/base/__init__.py b/mapanything/datasets/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/datasets/base/base_dataset.py b/mapanything/datasets/base/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5b54e302d3d7beef0e5acc62846c3c6785fc542e --- /dev/null +++ b/mapanything/datasets/base/base_dataset.py @@ -0,0 +1,697 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Base class for MapAnything datasets. +""" + +from typing import List, Tuple, Union + +import numpy as np +import PIL +import torch +import torchvision.transforms as tvf +from scipy.spatial.transform import Rotation + +from mapanything.datasets.base.easy_dataset import EasyDataset +from mapanything.utils.cropping import ( + bbox_from_intrinsics_in_out, + camera_matrix_of_crop, + crop_image_and_other_optional_info, + rescale_image_and_other_optional_info, +) +from mapanything.utils.geometry import ( + depthmap_to_camera_coordinates, + get_absolute_pointmaps_and_rays_info, +) +from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT + + +class BaseDataset(EasyDataset): + """ + Define all basic options. 
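+    Subclasses implement `_get_views(idx, num_views_to_sample, resolution)` to return
+    the list of per-view dicts, as in the usage sketch below.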
+
+    Usage:
+        class MyDataset(BaseDataset):
+            def _get_views(self, idx, num_views_to_sample, resolution):
+                views = []
+                views.append(dict(img=..., ...))
+                return views
+    """
+
+    def __init__(
+        self,
+        num_views: int,
+        variable_num_views: bool = False,
+        split: str = None,
+        covisibility_thres: float = None,
+        resolution: Union[int, Tuple[int, int], List[Tuple[int, int]]] = None,
+        principal_point_centered: bool = False,
+        transform: str = None,
+        data_norm_type: str = None,
+        aug_crop: int = 0,
+        seed: int = None,
+        max_num_retries: int = 5,
+    ):
+        """
+        PyTorch dataset for multi-view images sampled from scenes, where the images form a single connected component.
+
+        Args:
+            num_views (int): Number of views.
+            variable_num_views (bool): If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2.
+                On by default for N-view train dataloader (hydra config).
+            split (str): 'train', 'val', 'test', etc.
+            covisibility_thres (float): Covisibility (%) threshold to determine if another image is a neighbor or not
+            resolution (int or tuple or list of tuples): Resolution of the images
+            principal_point_centered (bool): If True, the principal point is centered in the image.
+            transform (str): Transform to apply to the images. Options:
+                - 'colorjitter+grayscale+gaublur':
+                    tvf.Compose([
+                        tvf.RandomApply([tvf.ColorJitter(0.3, 0.4, 0.2, 0.1)], p=0.75),
+                        tvf.RandomGrayscale(p=0.05),
+                        tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05),
+                    ]) after ImgNorm
+                - 'colorjitter': tvf.ColorJitter(0.5, 0.5, 0.5, 0.1) after ImgNorm
+                - 'imgnorm': ImgNorm only
+            data_norm_type (str): Image normalization type.
+                For options, see UniCeption image normalization dict.
+            aug_crop (int): Augment crop. If int greater than 0, indicates the number of pixels to increase in target resolution.
+            seed (int): Seed for the random number generator.
+            max_num_retries (int): Maximum number of retries for loading a different sample from the dataset, if provided idx fails.
+        """
+        self.num_views = num_views
+        self.variable_num_views = variable_num_views
+        self.num_views_min = 2
+        self.split = split
+        self.covisibility_thres = covisibility_thres
+        self._set_resolutions(resolution)
+        self.principal_point_centered = principal_point_centered
+
+        # Update the number of views if necessary and make it a list if variable_num_views is True
+        if self.variable_num_views and self.num_views > self.num_views_min:
+            self.num_views = list(range(self.num_views_min, self.num_views + 1))
+
+        # Initialize the image normalization type
+        if data_norm_type in IMAGE_NORMALIZATION_DICT.keys():
+            self.data_norm_type = data_norm_type
+            image_norm = IMAGE_NORMALIZATION_DICT[data_norm_type]
+            ImgNorm = tvf.Compose(
+                [
+                    tvf.ToTensor(),
+                    tvf.Normalize(mean=image_norm.mean, std=image_norm.std),
+                ]
+            )
+        elif data_norm_type == "identity":
+            self.data_norm_type = data_norm_type
+            ImgNorm = tvf.Compose([tvf.ToTensor()])
+        else:
+            raise ValueError(
+                f"Unknown data_norm_type: {data_norm_type}.
Available options: identity or {list(IMAGE_NORMALIZATION_DICT.keys())}" + ) + + # Initialize torchvision transforms + if transform == "imgnorm": + self.transform = ImgNorm + elif transform == "colorjitter": + self.transform = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) + elif transform == "colorjitter+grayscale+gaublur": + self.transform = tvf.Compose( + [ + tvf.RandomApply([tvf.ColorJitter(0.3, 0.4, 0.2, 0.1)], p=0.75), + tvf.RandomGrayscale(p=0.05), + tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05), + ImgNorm, + ] + ) + else: + raise ValueError( + 'Unknown transform. Available options: "imgnorm", "colorjitter", "colorjitter+grayscale+gaublur"' + ) + + # Initialize the augmentation parameters + self.aug_crop = aug_crop + + # Initialize the seed for the random number generator + self.seed = seed + self._seed_offset = 0 + + # Initialize the maximum number of retries for loading a different sample from the dataset, if the first idx fails + self.max_num_retries = max_num_retries + + # Initialize the dataset type flags + self.is_metric_scale = False # by default a dataset is not metric scale, subclasses can overwrite this + self.is_synthetic = False # by default a dataset is not synthetic, subclasses can overwrite this + + def _load_data(self): + self.scenes = [] + self.num_of_scenes = len(self.scenes) + + def __len__(self): + "Length of the dataset is determined by the number of scenes in the dataset split" + return self.num_of_scenes + + def get_stats(self): + "Get the number of scenes in the dataset split" + return f"{self.num_of_scenes} scenes" + + def __repr__(self): + resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]" + return ( + f"""{type(self).__name__}({self.get_stats()}, + {self.num_views=} + {self.split=}, + {self.seed=}, + resolutions={resolutions_str}, + {self.transform=})""".replace("self.", "") + .replace("\n", "") + .replace(" ", "") + ) + + def _get_views(self, idx, num_views_to_sample, resolution): + raise NotImplementedError() + + def _set_seed_offset(self, idx): + """ + Set the seed offset. This is directly added to self.seed when setting the random seed. + """ + self._seed_offset = idx + + def _set_resolutions(self, resolutions): + assert resolutions is not None, "undefined resolution" + + if isinstance(resolutions, int): + resolutions = [resolutions] + elif isinstance(resolutions, tuple): + resolutions = [resolutions] + elif isinstance(resolutions, list): + assert all(isinstance(res, tuple) for res in resolutions), ( + f"Bad type for {resolutions=}, should be int or tuple of ints or list of tuples of ints" + ) + else: + raise ValueError( + f"Bad type for {resolutions=}, should be int or tuple of ints or list of tuples of ints" + ) + + self._resolutions = [] + for resolution in resolutions: + if isinstance(resolution, int): + width = height = resolution + else: + width, height = resolution + assert isinstance(width, int), ( + f"Bad type for {width=} {type(width)=}, should be int" + ) + assert isinstance(height, int), ( + f"Bad type for {height=} {type(height)=}, should be int" + ) + self._resolutions.append((width, height)) + + def _crop_resize_if_necessary( + self, + image, + resolution, + depthmap, + intrinsics, + additional_quantities=None, + ): + """ + Process an image by downsampling and cropping as needed to match the target resolution. + + This method performs the following operations: + 1. Converts the image to PIL.Image if necessary + 2. 
Crops the image centered on the principal point if requested + 3. Downsamples the image using high-quality Lanczos filtering + 4. Performs final cropping to match the target resolution + + Args: + image (numpy.ndarray or PIL.Image.Image): Input image to be processed + resolution (tuple): Target resolution as (width, height) + depthmap (numpy.ndarray): Depth map corresponding to the image + intrinsics (numpy.ndarray): Camera intrinsics matrix (3x3) + additional_quantities (dict, optional): Additional image-related data to be processed + alongside the main image with nearest interpolation. Defaults to None. + + Returns: + tuple: Processed image, depthmap, and updated intrinsics matrix. + If additional_quantities is provided, it returns those as well. + """ + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # Cropping centered on the principal point if necessary + if self.principal_point_centered: + W, H = image.size + cx, cy = intrinsics[:2, 2].round().astype(int) + if cx < 0 or cx >= W or cy < 0 or cy >= H: + # Skip centered cropping if principal point is outside image bounds + pass + else: + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + left, top = cx - min_margin_x, cy - min_margin_y + right, bottom = cx + min_margin_x, cy + min_margin_y + crop_bbox = (left, top, right, bottom) + # Only perform the centered crop if the crop_bbox is larger than the target resolution + crop_width = right - left + crop_height = bottom - top + if crop_width > resolution[0] and crop_height > resolution[1]: + image, depthmap, intrinsics, additional_quantities = ( + crop_image_and_other_optional_info( + image=image, + crop_bbox=crop_bbox, + depthmap=depthmap, + camera_intrinsics=intrinsics, + additional_quantities=additional_quantities, + ) + ) + + # Get the target resolution for re-scaling + target_rescale_resolution = np.array(resolution) + if self.aug_crop > 1: + target_rescale_resolution += self._rng.integers(0, self.aug_crop) + + # High-quality Lanczos down-scaling if necessary + image, depthmap, intrinsics, additional_quantities = ( + rescale_image_and_other_optional_info( + image=image, + output_resolution=target_rescale_resolution, + depthmap=depthmap, + camera_intrinsics=intrinsics, + additional_quantities_to_be_resized_with_nearest=additional_quantities, + ) + ) + + # Actual cropping (if necessary) + new_intrinsics = camera_matrix_of_crop( + input_camera_matrix=intrinsics, + input_resolution=image.size, + output_resolution=resolution, + offset_factor=0.5, + ) + crop_bbox = bbox_from_intrinsics_in_out( + input_camera_matrix=intrinsics, + output_camera_matrix=new_intrinsics, + output_resolution=resolution, + ) + image, depthmap, new_intrinsics, additional_quantities = ( + crop_image_and_other_optional_info( + image=image, + crop_bbox=crop_bbox, + depthmap=depthmap, + camera_intrinsics=intrinsics, + additional_quantities=additional_quantities, + ) + ) + + # Return the output + if additional_quantities is not None: + return image, depthmap, new_intrinsics, additional_quantities + else: + return image, depthmap, new_intrinsics + + def _random_walk_sampling( + self, + scene_pairwise_covisibility, + num_of_samples, + max_retries=4, + use_bidirectional_covis=True, + ): + """ + Randomly samples S indices from an N x N covisibility matrix by forming adjacency edges such that the resulting subgraph (given by the indices) is connected. + If the current node has no new unvisited neighbors, backtracking occurs. 
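+        For example, a walk starting at node 3 might visit 3 -> 8 -> 1, find that node 1
+        has no unvisited covisible neighbors, backtrack to node 8, and expand from there.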
+ Retries with different starting indices if the desired number of samples is not reached, excluding previously visited components. + + Args: + scene_pairwise_covisibility : np.ndarray (mmap) + N x N covisibility matrix for the scene, where N is the number of views in the scene. + num_of_samples : int + The desired number of nodes to sample (num_of_samples < N). + max_retries : int + The maximum number of retries with different starting indices. + use_bidirectional_covis : bool + Whether to compute bidirectional covisibility by averaging row and column values. + If False, uses only row access (faster for large memory-mapped arrays). + Defaults to True. + + Returns: + np.ndarray + An array of sampled indices forming a connected subgraph. + """ + excluded_nodes = set() + best_walk = [] # To keep track of the best walk found + for _ in range(max_retries): + visited = set() + walk = [] # List to store the random walk sampling order + stack = [] # Stack for backtracking + + # Choose a random starting index that is not in the excluded set + all_nodes = set(range(len(scene_pairwise_covisibility))) + available_nodes = list(all_nodes - excluded_nodes) + if not available_nodes: + break # No more nodes to try + start = self._rng.choice(available_nodes) + walk.append(start) + visited.add(start) + stack.append(start) + + # Continue until we have sampled S indices or all expandable nodes are exhausted + while len(walk) < num_of_samples and stack: + current = stack[-1] + # Get the pairwise covisibility for the current node + if use_bidirectional_covis: + # Use bidirectional covisibility (slower for large memory-mapped arrays) + pairwise_covisibility = ( + scene_pairwise_covisibility[current, :] + + scene_pairwise_covisibility[:, current].T + ) / 2 + else: + # Use only row access (faster for large memory-mapped arrays) + pairwise_covisibility = scene_pairwise_covisibility[current, :] + # Normalize the covisibility using self covisibility + pairwise_covisibility = pairwise_covisibility / ( + pairwise_covisibility[current] + 1e-8 + ) + # Assign overlap score of zero to self-pairs + pairwise_covisibility[current] = 0 + # Threshold the covisibility to get adjacency list for the current node + adjacency_list_for_current = ( + pairwise_covisibility > self.covisibility_thres + ).astype(int) + adjacency_list_for_current = np.flatnonzero(adjacency_list_for_current) + # Get all unvisited neighbors + candidates = [ + idx for idx in adjacency_list_for_current if idx not in visited + ] # Remove visited nodes + if candidates: + # Randomly select one of the unvisited overlapping neighbors + next_node = self._rng.choice(candidates) + walk.append(next_node) + visited.add(next_node) + stack.append(next_node) + else: + # If no unvisited neighbor is available, backtrack + stack.pop() + + # Update the best walk if the current walk is larger + if len(walk) > len(best_walk): + best_walk = walk + + # If we have enough samples, return the result + if len(walk) >= num_of_samples: + return np.array(walk) + + # Add all visited nodes to the excluded set + excluded_nodes.update(visited) + + # If all retries are exhausted and we still don't have enough samples, return the best walk found + return np.array(best_walk) + + def _sample_view_indices( + self, + num_views_to_sample, + num_views_in_scene, + scene_pairwise_covisibility, + use_bidirectional_covis=True, + ): + """ + Sample view indices from a scene based on the adjacency list and the number of views to sample. + + Args: + num_views_to_sample (int): Number of views to sample. 
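+                If this exceeds the number of views in the scene, views are repeated via
+                sampling with replacement.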
+            num_views_in_scene (int): Total number of views available in the scene.
+            scene_pairwise_covisibility (np.ndarray): N x N covisibility matrix for the scene, where N is the number of views in the scene.
+            use_bidirectional_covis (bool): Whether to compute bidirectional covisibility by averaging row and column values.
+                If False, uses only row access (faster for large memory-mapped arrays).
+
+        Returns:
+            numpy.ndarray: Array of sampled view indices.
+        """
+        if num_views_to_sample == num_views_in_scene:
+            # Select all views in the scene
+            view_indices = self._rng.permutation(num_views_in_scene)
+        elif num_views_to_sample > num_views_in_scene:
+            # Select all views in the scene and repeat them to get the desired number of views
+            view_indices = self._rng.choice(
+                num_views_in_scene, size=num_views_to_sample, replace=True
+            )
+        else:
+            # Select a subset of single component connected views in the scene using random walk sampling
+            view_indices = self._random_walk_sampling(
+                scene_pairwise_covisibility,
+                num_views_to_sample,
+                use_bidirectional_covis=use_bidirectional_covis,
+            )
+            # If the required number of views can't be obtained even with 4 retries, repeat existing indices to get the desired number of views
+            if len(view_indices) < num_views_to_sample:
+                view_indices = self._rng.choice(
+                    view_indices, size=num_views_to_sample, replace=True
+                )
+
+        return view_indices
+
+    def _getitem_fn(self, idx):
+        if isinstance(idx, tuple):
+            # The idx is a tuple if specifying the aspect-ratio and/or the number of views
+            if isinstance(self.num_views, int):
+                idx, ar_idx = idx
+            else:
+                idx, ar_idx, num_views_to_sample_idx = idx
+        else:
+            assert len(self._resolutions) == 1
+            assert isinstance(self.num_views, int)
+            ar_idx = 0
+
+        # Setup the rng
+        if self.seed:  # reseed for each _getitem_fn
+            # Leads to deterministic sampling where repeating self.seed and self._seed_offset yields the same multi-view set again
+            # Scenes will be repeated if size of dataset is artificially increased using "N @" or "N *"
+            # When scenes are repeated, self._seed_offset is increased to ensure new multi-view sets
+            # This is useful for evaluation if the number of dataset scenes is < N, yet we want unique multi-view sets each iter
+            self._rng = np.random.default_rng(seed=self.seed + self._seed_offset + idx)
+        elif not hasattr(self, "_rng"):
+            seed = torch.initial_seed()  # this is different for each dataloader process
+            self._rng = np.random.default_rng(seed=seed)
+
+        # Get the views for the given index and check that the number of views is correct
+        resolution = self._resolutions[ar_idx]
+        if isinstance(self.num_views, int):
+            num_views_to_sample = self.num_views
+        else:
+            num_views_to_sample = self.num_views[num_views_to_sample_idx]
+        views = self._get_views(idx, num_views_to_sample, resolution)
+        if isinstance(self.num_views, int):
+            assert len(views) == self.num_views
+        else:
+            assert len(views) in self.num_views
+
+        for v, view in enumerate(views):
+            # Store the index and other metadata
+            view["idx"] = (idx, ar_idx, v)
+            view["is_metric_scale"] = self.is_metric_scale
+            view["is_synthetic"] = self.is_synthetic
+
+            # Check the depth, intrinsics, and pose data (also other data if present)
+            assert "camera_intrinsics" in view
+            assert "camera_pose" in view
+            assert np.isfinite(view["camera_pose"]).all(), (
+                f"NaN or infinite values in camera pose for view {view_name(view)}"
+            )
+            assert np.isfinite(view["depthmap"]).all(), (
+                f"NaN or infinite values in depthmap for view {view_name(view)}"
+            )
+            assert "valid_mask" not in
view
+            assert "pts3d" not in view, (
+                f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
+            )
+            if "prior_depth_z" in view:
+                assert np.isfinite(view["prior_depth_z"]).all(), (
+                    f"NaN or infinite values in prior_depth_z for view {view_name(view)}"
+                )
+            if "non_ambiguous_mask" in view:
+                assert np.isfinite(view["non_ambiguous_mask"]).all(), (
+                    f"NaN or infinite values in non_ambiguous_mask for view {view_name(view)}"
+                )
+
+            # Encode the image
+            width, height = view["img"].size
+            view["true_shape"] = np.int32((height, width))
+            view["img"] = self.transform(view["img"])
+            view["data_norm_type"] = self.data_norm_type
+
+            # Compute the pointmaps, raymap and depth along ray
+            (
+                pts3d,
+                valid_mask,
+                ray_origins_world,
+                ray_directions_world,
+                depth_along_ray,
+                ray_directions_cam,
+                pts3d_cam,
+            ) = get_absolute_pointmaps_and_rays_info(**view)
+            view["pts3d"] = pts3d
+            view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
+            view["depth_along_ray"] = depth_along_ray
+            view["ray_directions_cam"] = ray_directions_cam
+            view["pts3d_cam"] = pts3d_cam
+
+            # Compute the prior depth along ray if present
+            if "prior_depth_z" in view:
+                prior_pts3d, _ = depthmap_to_camera_coordinates(
+                    view["prior_depth_z"], view["camera_intrinsics"]
+                )
+                view["prior_depth_along_ray"] = np.linalg.norm(prior_pts3d, axis=-1)
+                view["prior_depth_along_ray"] = view["prior_depth_along_ray"][..., None]
+                del view["prior_depth_z"]
+
+            # Convert ambiguous mask dtype to match valid mask dtype
+            if "non_ambiguous_mask" in view:
+                view["non_ambiguous_mask"] = view["non_ambiguous_mask"].astype(
+                    view["valid_mask"].dtype
+                )
+            else:
+                ambiguous_mask = view["depthmap"] < 0
+                view["non_ambiguous_mask"] = ~ambiguous_mask
+                view["non_ambiguous_mask"] = view["non_ambiguous_mask"].astype(
+                    view["valid_mask"].dtype
+                )
+
+            # Check all datatypes
+            for key, val in view.items():
+                res, err_msg = is_good_type(val)
+                assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
+
+            # Check shapes
+            assert view["depthmap"].shape == view["img"].shape[1:]
+            assert view["depthmap"].shape == view["pts3d"].shape[:2]
+            assert view["depthmap"].shape == view["valid_mask"].shape
+            assert view["depthmap"].shape == view["depth_along_ray"].shape[:2]
+            assert view["depthmap"].shape == view["ray_directions_cam"].shape[:2]
+            assert view["depthmap"].shape == view["pts3d_cam"].shape[:2]
+            if "prior_depth_along_ray" in view:
+                assert view["depthmap"].shape == view["prior_depth_along_ray"].shape[:2]
+            if "non_ambiguous_mask" in view:
+                assert view["depthmap"].shape == view["non_ambiguous_mask"].shape
+
+            # Expand the last dimension of the depthmap
+            view["depthmap"] = view["depthmap"][..., None]
+
+            # Append the RNG state to the views; this allows checking whether the RNG is in the same state each time
+            view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
+
+            # Compute and store the quaternions and translation for the camera poses
+            # Notation is (x, y, z, w) for quaternions
+            # This also ensures that the camera poses have a positive determinant (right-handed coordinate system)
+            view["camera_pose_quats"] = (
+                Rotation.from_matrix(view["camera_pose"][:3, :3])
+                .as_quat()
+                .astype(view["camera_pose"].dtype)
+            )
+            view["camera_pose_trans"] = view["camera_pose"][:3, 3].astype(
+                view["camera_pose"].dtype
+            )
+
+            # Check the pointmaps, rays, depth along ray, and camera pose quaternions and translation to ensure they are finite
+            assert np.isfinite(view["pts3d"]).all(),
( + f"NaN in pts3d for view {view_name(view)}" + ) + assert np.isfinite(view["valid_mask"]).all(), ( + f"NaN in valid_mask for view {view_name(view)}" + ) + assert np.isfinite(view["depth_along_ray"]).all(), ( + f"NaN in depth_along_ray for view {view_name(view)}" + ) + assert np.isfinite(view["ray_directions_cam"]).all(), ( + f"NaN in ray_directions_cam for view {view_name(view)}" + ) + assert np.isfinite(view["pts3d_cam"]).all(), ( + f"NaN in pts3d_cam for view {view_name(view)}" + ) + assert np.isfinite(view["camera_pose_quats"]).all(), ( + f"NaN in camera_pose_quats for view {view_name(view)}" + ) + assert np.isfinite(view["camera_pose_trans"]).all(), ( + f"NaN in camera_pose_trans for view {view_name(view)}" + ) + if "prior_depth_along_ray" in view: + assert np.isfinite(view["prior_depth_along_ray"]).all(), ( + f"NaN in prior_depth_along_ray for view {view_name(view)}" + ) + + return views + + def __getitem__(self, idx): + if self.max_num_retries == 0: + return self._getitem_fn(idx) + + num_retries = 0 + while num_retries <= self.max_num_retries: + try: + return self._getitem_fn(idx) + except Exception as e: + scene_idx = idx[0] if isinstance(idx, tuple) else idx + print( + f"Error in {type(self).__name__}.__getitem__ for scene_idx={scene_idx}: {e}" + ) + + if num_retries >= self.max_num_retries: + print( + f"Max retries ({self.max_num_retries}) reached, raising the exception" + ) + raise e + + # Retry with a different scene index + num_retries += 1 + if isinstance(idx, tuple): + # The scene index is the first element of the tuple + idx_list = list(idx) + idx_list[0] = np.random.randint(0, len(self)) + idx = tuple(idx_list) + else: + # The scene index is idx + idx = np.random.randint(0, len(self)) + scene_idx = idx[0] if isinstance(idx, tuple) else idx + print( + f"Retrying with scene_idx={scene_idx} ({num_retries} of {self.max_num_retries})" + ) + + +def is_good_type(v): + """ + Check if a value has an acceptable data type for processing in the dataset. + + Args: + v: The value to check. + + Returns: + tuple: A tuple containing: + - bool: True if the type is acceptable, False otherwise. + - str or None: Error message if the type is not acceptable, None otherwise. + """ + if isinstance(v, (str, int, tuple)): + return True, None + if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8): + return False, f"bad {v.dtype=}" + return True, None + + +def view_name(view, batch_index=None): + """ + Generate a string identifier for a view based on its dataset, label, and instance. + + Args: + view (dict): Dictionary containing view information with 'dataset', 'label', and 'instance' keys. + batch_index (int, optional): Index to select from batched data. Defaults to None. + + Returns: + str: A formatted string in the form "dataset/label/instance". + """ + + def sel(x): + return x[batch_index] if batch_index not in (None, slice(None)) else x + + db = sel(view["dataset"]) + label = sel(view["label"]) + instance = sel(view["instance"]) + return f"{db}/{label}/{instance}" diff --git a/mapanything/datasets/base/batched_sampler.py b/mapanything/datasets/base/batched_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..15c77f8ec39ba1382048a9bffa94e8b2ed919ff5 --- /dev/null +++ b/mapanything/datasets/base/batched_sampler.py @@ -0,0 +1,431 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +""" +Utilities for random sampling under a single or multiple constraints + +References: DUSt3R +""" + +import numpy as np +import torch + + +def round_by(total, multiple, up=False): + """ + Round a number to the nearest multiple of another number. + + Args: + total (int): The number to round + multiple (int): The multiple to round to + up (bool, optional): Whether to round up. Defaults to False. + + Returns: + int: The rounded number + """ + if up: + total = total + multiple - 1 + return (total // multiple) * multiple + + +class BatchedRandomSampler: + """ + Random sampling under a constraint: each sample in the batch has the same feature, + which is chosen randomly from a known pool of 'features' for each batch. + + For instance, the 'feature' could be the image aspect-ratio. + + The index returned is a tuple (sample_idx, feat_idx). + This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. + """ + + def __init__( + self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True + ): + """ + Args: + dataset: Dataset to sample from + batch_size: Number of samples per batch + pool_size: Integer representing the size of feature pool + world_size: Number of distributed processes + rank: Rank of the current process + drop_last: Whether to drop the last incomplete batch + """ + self.batch_size = batch_size + self.pool_size = pool_size + + self.len_dataset = N = len(dataset) + self.total_size = round_by(N, batch_size * world_size) if drop_last else N + assert world_size == 1 or drop_last, ( + "must drop the last batch in distributed mode" + ) + + # Distributed sampler + self.world_size = world_size + self.rank = rank + self.epoch = None + + def __len__(self): + """ + Get the length of the sampler. + + Returns: + int: The number of samples in the sampler for the current process + """ + return self.total_size // self.world_size + + def set_epoch(self, epoch): + """ + Set the epoch for this sampler. + + This should be called before each epoch to ensure proper shuffling of the data. + + Args: + epoch (int): The current epoch number + """ + self.epoch = epoch + + def __iter__(self): + """ + Iterator over the indices. + + This method generates random indices for each batch, ensuring that all samples + within a batch have the same feature index for the given feature pool. 
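+
+        For example, with batch_size=4 and pool_size=3 (say, three aspect ratios), one
+        batch could yield (17, 2), (5, 2), (42, 2), (8, 2): shuffled sample indices
+        sharing a single randomly drawn feat_idx.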
+ + Yields: + tuple: A tuple containing (sample_idx, feat_idx) + """ + # Prepare RNG + if self.epoch is None: + assert self.world_size == 1 and self.rank == 0, ( + "use set_epoch() if distributed mode is used" + ) + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.epoch + 777 + rng = np.random.default_rng(seed=seed) + + # Random indices (will restart from 0 if not drop_last) + sample_idxs = np.arange(self.total_size) + rng.shuffle(sample_idxs) + + # Random feat_idxs (same across each batch) + n_batches = (self.total_size + self.batch_size - 1) // self.batch_size + feat_idxs = rng.integers(self.pool_size, size=n_batches) + feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size)) + feat_idxs = feat_idxs.ravel()[: self.total_size] + + # Put them together + idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2) + + # Distributed sampler: we select a subset of batches + # Make sure the slice for each node is aligned with batch_size + size_per_proc = self.batch_size * ( + (self.total_size + self.world_size * self.batch_size - 1) + // (self.world_size * self.batch_size) + ) + idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc] + + yield from (tuple(idx) for idx in idxs) + + +class BatchedMultiFeatureRandomSampler: + """ + Random sampling under multiple constraints: each sample in the batch has the same features, + which are chosen randomly from known pools of 'features' for each batch. + + For instance, the 'features' could be the image aspect-ratio and scene type. + + The index returned is a tuple (sample_idx, feat_idx_1, feat_idx_2, ...). + This sampler ensures that each series of `batch_size` indices has the same feature indices. + """ + + def __init__( + self, dataset, batch_size, pool_sizes, world_size=1, rank=0, drop_last=True + ): + """ + Args: + dataset: Dataset to sample from + batch_size: Number of samples per batch + pool_sizes: List of integers representing the size of each feature pool + world_size: Number of distributed processes + rank: Rank of the current process + drop_last: Whether to drop the last incomplete batch + """ + self.batch_size = batch_size + self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes] + + self.len_dataset = N = len(dataset) + self.total_size = round_by(N, batch_size * world_size) if drop_last else N + assert world_size == 1 or drop_last, ( + "must drop the last batch in distributed mode" + ) + + # Distributed sampler + self.world_size = world_size + self.rank = rank + self.epoch = None + + def __len__(self): + """ + Get the length of the sampler. + + Returns: + int: The number of samples in the sampler for the current process + """ + return self.total_size // self.world_size + + def set_epoch(self, epoch): + """ + Set the epoch for this sampler. + + This should be called before each epoch to ensure proper shuffling of the data. + + Args: + epoch (int): The current epoch number + """ + self.epoch = epoch + + def __iter__(self): + """ + Iterator over the indices. + + This method generates random indices for each batch, ensuring that all samples + within a batch have the same feature indices for multiple features. + + Yields: + tuple: A tuple containing (sample_idx, feat_idx_1, feat_idx_2, ...) 
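+                For example, with pool_sizes=[3, 5] a yielded index could be (42, 1, 3);
+                all indices within one batch share the feature indices (1, 3).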
+ """ + # Prepare RNG + if self.epoch is None: + assert self.world_size == 1 and self.rank == 0, ( + "use set_epoch() if distributed mode is used" + ) + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.epoch + 777 + rng = np.random.default_rng(seed=seed) + + # Random indices (will restart from 0 if not drop_last) + sample_idxs = np.arange(self.total_size) + rng.shuffle(sample_idxs) + + # Random feat_idxs (same across each batch) + n_batches = (self.total_size + self.batch_size - 1) // self.batch_size + + # Generate feature indices for each feature pool + all_feat_idxs = [] + for pool_size in self.pool_sizes: + feat_idxs = rng.integers(pool_size, size=n_batches) + feat_idxs = np.broadcast_to( + feat_idxs[:, None], (n_batches, self.batch_size) + ) + feat_idxs = feat_idxs.ravel()[: self.total_size] + all_feat_idxs.append(feat_idxs) + + # Put them together + idxs = np.column_stack( + [sample_idxs] + all_feat_idxs + ) # shape = (total_size, 1 + len(pool_sizes)) + + # Distributed sampler: we select a subset of batches + # Make sure the slice for each node is aligned with batch_size + size_per_proc = self.batch_size * ( + (self.total_size + self.world_size * self.batch_size - 1) + // (self.world_size * self.batch_size) + ) + idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc] + + yield from (tuple(idx) for idx in idxs) + + +class DynamicBatchedMultiFeatureRandomSampler: + """ + Random sampling under multiple constraints with dynamic batch size: + each sample in the batch has the same features, which are chosen randomly + from known pools of 'features' for each batch. + + The batch size is dynamically determined based on a specified feature index, + using a direct mapping from feature values to batch sizes. + + For instance, if one of the features is the number of images in a multi-view set, + you can specify different batch sizes for different numbers of images to optimize + GPU memory usage. This is achieved by using the feature_to_batch_size_map parameter + to directly specify what batch size to use for each feature value. + + The returned index is a list of tuples [(sample_idx, feat_idx_1, feat_idx_2, ...), ...]. + """ + + def __init__( + self, + dataset, + pool_sizes, + scaling_feature_idx=0, + feature_to_batch_size_map=None, + world_size=1, + rank=0, + drop_last=True, + ): + """ + Args: + dataset: Dataset to sample from + pool_sizes: List of integers representing the size of each feature pool + scaling_feature_idx: Index of the feature to use for determining batch size (0-based index into pool_sizes) + feature_to_batch_size_map: Optional function or dict that maps feature values directly to batch sizes. + For example, if the feature represents number of views, this maps number of views + to appropriate batch size that can fit in GPU memory. + If None, uses a default batch size of 1 for all feature values. 
world_size: Number of distributed processes
+            rank: Rank of the current process
+            drop_last: Whether to drop the last incomplete batch
+        """
+        self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes]
+        self.scaling_feature_idx = scaling_feature_idx
+
+        # Ensure scaling_feature_idx is valid
+        if scaling_feature_idx < 0 or scaling_feature_idx >= len(self.pool_sizes):
+            raise ValueError(
+                f"scaling_feature_idx must be between 0 and {len(self.pool_sizes) - 1}"
+            )
+
+        # Set up mapping from feature values to batch sizes
+        self.feature_to_batch_size_map = feature_to_batch_size_map
+        if self.feature_to_batch_size_map is None:
+            # Default: batch size of 1 for all feature values
+            self.feature_to_batch_size_map = {
+                i: 1 for i in range(self.pool_sizes[scaling_feature_idx])
+            }
+
+        self.len_dataset = N = len(dataset)
+
+        # The exact batch sizes are not known upfront, so total_size is just the dataset length
+        # Actual batch boundaries are determined during iteration
+        self.total_size = N
+
+        # Distributed sampler
+        self.world_size = world_size
+        self.rank = rank
+        self.epoch = None
+        self.drop_last = drop_last
+
+    def __len__(self):
+        """
+        Get the approximate length of the sampler.
+
+        Since batch size varies, this is an estimate based on the largest batch size
+        in the mapping, which provides a lower bound on the number of batches.
+
+        Returns:
+            int: The estimated minimum number of batches yielded for the current process
+        """
+        # Find the largest batch size in the mapping
+        if callable(self.feature_to_batch_size_map):
+            # If it's a function, sample some values to find the maximum
+            batch_sizes = [
+                self.feature_to_batch_size_map(i)
+                for i in range(self.pool_sizes[self.scaling_feature_idx])
+            ]
+            max_batch_size = max(batch_sizes)
+        else:
+            # If it's a dict or similar, find the maximum directly
+            max_batch_size = max(self.feature_to_batch_size_map.values())
+
+        # Ensure minimum batch size of 1
+        max_batch_size = max(1, max_batch_size)
+
+        # Estimate total batches using the largest batch size
+        # This gives a lower bound on the number of batches
+        total_batches = self.total_size // max_batch_size
+        if not self.drop_last and self.total_size % max_batch_size > 0:
+            total_batches += 1
+
+        # Distribute among processes
+        return total_batches // self.world_size
+
+    def set_epoch(self, epoch):
+        """
+        Set the epoch for this sampler.
+
+        This should be called before each epoch to ensure proper shuffling of the data.
+
+        Args:
+            epoch (int): The current epoch number
+        """
+        self.epoch = epoch
+
+    def __iter__(self):
+        """
+        Iterator over the indices with dynamic batch sizes.
+
+        This method generates random indices for each batch, ensuring that all samples
+        within a batch have the same feature indices for multiple features.
+        The batch size is determined directly from the feature_to_batch_size_map.
+
+        The iterator enforces the length returned by __len__() by stopping after
+        exactly that many batches have been yielded for this process.
+
+        Yields:
+            list of tuples: A batch of tuples, each containing (sample_idx, feat_idx_1, feat_idx_2, ...)
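+                For example, if the mapped batch size for drawn feature indices (0, 2) is 3,
+                one yielded batch could be [(10, 0, 2), (4, 0, 2), (33, 0, 2)].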
+ """ + # Prepare RNG + if self.epoch is None: + assert self.world_size == 1 and self.rank == 0, ( + "use set_epoch() if distributed mode is used" + ) + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.epoch + 777 + rng = np.random.default_rng(seed=seed) + + # Random indices for the entire dataset + sample_idxs = np.arange(self.total_size) + rng.shuffle(sample_idxs) + + # Get the target number of batches for this process (enforce strict length) + target_batches_for_process = len(self) + batches_yielded_for_process = 0 + + # Process indices in batches with dynamic sizing + idx = 0 + batch_idx = 0 # Track batch index for even distribution + while idx < len(sample_idxs) and ( + batches_yielded_for_process < target_batches_for_process + ): + # Randomly select feature indices for this batch + feat_idxs = [rng.integers(pool_size) for pool_size in self.pool_sizes] + + # Get the scaling feature value + scaling_feat = feat_idxs[self.scaling_feature_idx] + + # Get the batch size directly from the mapping + if callable(self.feature_to_batch_size_map): + batch_size = self.feature_to_batch_size_map(scaling_feat) + else: + batch_size = self.feature_to_batch_size_map.get(scaling_feat, 1) + + # Ensure minimum batch size of 1 + batch_size = max(1, batch_size) + + # Ensure we don't go beyond available samples + remaining = len(sample_idxs) - idx + if remaining < batch_size: + if self.drop_last: + break + batch_size = remaining + + # Create batch with consistent feature indices + batch = [] + for i in range(batch_size): + if idx + i < len(sample_idxs): + sample_idx = sample_idxs[idx + i] + batch.append(tuple([sample_idx] + feat_idxs)) + + # Distribute batches among processes in round-robin fashion + if len(batch) > 0 and (batch_idx % self.world_size == self.rank): + yield batch + batches_yielded_for_process += 1 + + batch_idx += 1 # Increment batch index + idx += batch_size diff --git a/mapanything/datasets/base/easy_dataset.py b/mapanything/datasets/base/easy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..55a9268e53f8a6aa456dbfe73d8ed213e639bc3a --- /dev/null +++ b/mapanything/datasets/base/easy_dataset.py @@ -0,0 +1,478 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Base dataset class that enables easy resizing and combining + +References: DUSt3R +""" + +import numpy as np + +from mapanything.datasets.base.batched_sampler import ( + BatchedMultiFeatureRandomSampler, + DynamicBatchedMultiFeatureRandomSampler, +) + + +class EasyDataset: + """ + Dataset that can be easily resized and combined. + + Examples: + --------- + 2 * dataset ==> Duplicate each element 2x + + 10 @ dataset ==> Set the size to 10 (random sampling, duplicates if necessary) + + Dataset1 + Dataset2 ==> Concatenate datasets + """ + + def __add__(self, other): + """ + Concatenate this dataset with another dataset. + + Args: + other (EasyDataset): Another dataset to concatenate with this one + + Returns: + CatDataset: A new dataset that is the concatenation of this dataset and the other + """ + return CatDataset([self, other]) + + def __rmul__(self, factor): + """ + Multiply the dataset by a factor, duplicating each element. 
+ + Args: + factor (int): Number of times to duplicate each element + + Returns: + MulDataset: A new dataset with each element duplicated 'factor' times + """ + return MulDataset(factor, self) + + def __rmatmul__(self, factor): + """ + Resize the dataset to a specific size using random sampling. + + Args: + factor (int): The new size of the dataset + + Returns: + ResizedDataset: A new dataset with the specified size + """ + return ResizedDataset(factor, self) + + def set_epoch(self, epoch): + """ + Set the current epoch for all constituent datasets. + + Args: + epoch (int): The current epoch number + """ + pass # nothing to do by default + + def make_sampler( + self, + batch_size=None, + shuffle=True, + world_size=1, + rank=0, + drop_last=True, + max_num_of_images_per_gpu=None, + use_dynamic_sampler=True, + ): + """ + Create a sampler for this dataset. + + Args: + batch_size (int, optional): Number of samples per batch (used for non-dynamic sampler). Defaults to None. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True. + world_size (int, optional): Number of distributed processes. Defaults to 1. + rank (int, optional): Rank of the current process. Defaults to 0. + drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True. + max_num_of_images_per_gpu (int, optional): Maximum number of images per GPU for dynamic batching. Defaults to None. + use_dynamic_sampler (bool, optional): Whether to use the dynamic sampler. Defaults to True. + + Returns: + DynamicBatchedMultiFeatureRandomSampler or BatchedMultiFeatureRandomSampler: A sampler for this dataset + + Raises: + NotImplementedError: If shuffle is False + ValueError: If num_views has an invalid type or required parameters are missing + """ + if not (shuffle): + raise NotImplementedError() # cannot deal yet + + if isinstance(self.num_views, int): + num_of_aspect_ratios = len(self._resolutions) + feature_pool_sizes = [num_of_aspect_ratios] + scaling_feature_idx = 0 # Use aspect ratio as scaling feature + elif isinstance(self.num_views, list): + num_of_aspect_ratios = len(self._resolutions) + num_of_num_views = len(self.num_views) + feature_pool_sizes = [num_of_aspect_ratios, num_of_num_views] + scaling_feature_idx = 1 # Use num_views as scaling feature + else: + raise ValueError( + f"Bad type for {self.num_views=}, should be int or list of ints" + ) + + if use_dynamic_sampler: + if max_num_of_images_per_gpu is None: + raise ValueError( + "max_num_of_images_per_gpu must be provided when using dynamic sampler" + ) + + # Create feature-to-batch-size mapping + if isinstance(self.num_views, list): + # Map num_views_idx to batch size: max(1, max_num_of_images_per_gpu // (num_views_idx + dataset.num_views_min)) + feature_to_batch_size_map = {} + for num_views_idx, num_views in enumerate(self.num_views): + batch_size_for_multi_view_sets = max( + 1, max_num_of_images_per_gpu // num_views + ) + feature_to_batch_size_map[num_views_idx] = ( + batch_size_for_multi_view_sets + ) + else: + # For fixed num_views, use a simple mapping + feature_to_batch_size_map = { + 0: max(1, max_num_of_images_per_gpu // self.num_views) + } + + return DynamicBatchedMultiFeatureRandomSampler( + self, + pool_sizes=feature_pool_sizes, + scaling_feature_idx=scaling_feature_idx, + feature_to_batch_size_map=feature_to_batch_size_map, + world_size=world_size, + rank=rank, + drop_last=drop_last, + ) + else: + if batch_size is None: + raise ValueError( + "batch_size must be provided when not using dynamic sampler" + ) + + return 
BatchedMultiFeatureRandomSampler( + self, + batch_size, + feature_pool_sizes, + world_size=world_size, + rank=rank, + drop_last=drop_last, + ) + + +class MulDataset(EasyDataset): + """Artificially augmenting the size of a dataset.""" + + multiplicator: int + + def __init__(self, multiplicator, dataset): + """ + Initialize a dataset that artificially augments the size of another dataset. + + Args: + multiplicator (int): Factor by which to multiply the dataset size + dataset (EasyDataset): The dataset to augment + """ + assert isinstance(multiplicator, int) and multiplicator > 0 + self.multiplicator = multiplicator + self.dataset = dataset + + def __len__(self): + """ + Get the length of the dataset. + + Returns: + int: The number of samples in the dataset + """ + return self.multiplicator * len(self.dataset) + + def __repr__(self): + """ + Get a string representation of the dataset. + + Returns: + str: String representation showing the multiplication factor and the original dataset + """ + return f"{self.multiplicator}*{repr(self.dataset)}" + + def __getitem__(self, idx): + """ + Get an item from the dataset. + + Args: + idx: Index or tuple of indices to retrieve + + Returns: + The item at the specified index from the original dataset + """ + if isinstance(idx, tuple): + other = idx[1:] + idx = idx[0] + new_idx = (idx // self.multiplicator, *other) + return self.dataset[new_idx] + else: + return self.dataset[idx // self.multiplicator] + + @property + def _resolutions(self): + """ + Get the resolutions of the dataset. + + Returns: + The resolutions from the original dataset + """ + return self.dataset._resolutions + + @property + def num_views(self): + """ + Get the number of views used for the dataset. + + Returns: + int or list: The number of views parameter from the original dataset + """ + return self.dataset.num_views + + +class ResizedDataset(EasyDataset): + """Artificially changing the size of a dataset.""" + + new_size: int + + def __init__(self, new_size, dataset): + """ + Initialize a dataset with an artificially changed size. + + Args: + new_size (int): The new size of the dataset + dataset (EasyDataset): The original dataset + """ + assert isinstance(new_size, int) and new_size > 0 + self.new_size = new_size + self.dataset = dataset + + def __len__(self): + """ + Get the length of the dataset. + + Returns: + int: The new size of the dataset + """ + return self.new_size + + def __repr__(self): + """ + Get a string representation of the dataset. + + Returns: + str: String representation showing the new size and the original dataset + """ + size_str = str(self.new_size) + for i in range((len(size_str) - 1) // 3): + sep = -4 * i - 3 + size_str = size_str[:sep] + "_" + size_str[sep:] + return f"{size_str} @ {repr(self.dataset)}" + + def set_epoch(self, epoch): + """ + Set the current epoch and generate a new random mapping of indices. + + This method must be called before using __getitem__. 
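+        The shuffle is seeded deterministically with `epoch + 777`, so every process
+        derives the same index mapping for a given epoch.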
+
+        Args:
+            epoch (int): The current epoch number
+        """
+        # This random shuffle only depends on the epoch
+        rng = np.random.default_rng(seed=epoch + 777)
+
+        # Shuffle all indices
+        perm = rng.permutation(len(self.dataset))
+
+        # Calculate how many repetitions we need
+        num_repetitions = 1 + (len(self) - 1) // len(self.dataset)
+
+        # Repeat the shuffled permutation until the target size is met
+        shuffled_idxs = np.concatenate([perm] * num_repetitions)
+        self._idxs_mapping = shuffled_idxs[: self.new_size]
+
+        # Generate the seed offset for each repetition
+        # This is needed to ensure we see unique samples when we repeat a scene
+        seed_offset_per_repetition = [
+            np.full(len(self.dataset), i) for i in range(num_repetitions)
+        ]
+        seed_offset_idxs = np.concatenate(seed_offset_per_repetition)
+        self._idxs_seed_offset = seed_offset_idxs[: self.new_size]
+
+        assert len(self._idxs_mapping) == self.new_size
+        assert len(self._idxs_seed_offset) == self.new_size
+
+    def __getitem__(self, idx):
+        """
+        Get an item from the dataset.
+
+        Args:
+            idx: Index or tuple of indices to retrieve
+
+        Returns:
+            The item at the mapped index from the original dataset
+
+        Raises:
+            AssertionError: If set_epoch has not been called
+        """
+        assert hasattr(self, "_idxs_mapping"), (
+            "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()"
+        )
+        if isinstance(idx, tuple):
+            other = idx[1:]
+            idx = idx[0]
+            self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
+            new_idx = (self._idxs_mapping[idx], *other)
+            return self.dataset[new_idx]
+        else:
+            self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
+            return self.dataset[self._idxs_mapping[idx]]
+
+    @property
+    def _resolutions(self):
+        """
+        Get the resolutions of the dataset.
+
+        Returns:
+            The resolutions from the original dataset
+        """
+        return self.dataset._resolutions
+
+    @property
+    def num_views(self):
+        """
+        Get the number of views used for the dataset.
+
+        Returns:
+            int or list: The number of views parameter from the original dataset
+        """
+        return self.dataset.num_views
+
+
+class CatDataset(EasyDataset):
+    """Concatenation of several datasets"""
+
+    def __init__(self, datasets):
+        """
+        Initialize a dataset that is a concatenation of several datasets.
+
+        Args:
+            datasets (list): List of EasyDataset instances to concatenate
+        """
+        for dataset in datasets:
+            assert isinstance(dataset, EasyDataset)
+        self.datasets = datasets
+        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])
+
+    def __len__(self):
+        """
+        Get the length of the concatenated dataset.
+
+        Returns:
+            int: Total number of samples across all datasets
+        """
+        return self._cum_sizes[-1]
+
+    def __repr__(self):
+        """
+        Get a string representation of the concatenated dataset.
+
+        Returns:
+            str: String representation showing all concatenated datasets joined by '+'
+        """
+        # Strip the needlessly long transform repr from each dataset's description
+        return " + ".join(
+            repr(dataset).replace(
+                ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))",
+                "",
+            )
+            for dataset in self.datasets
+        )
+
+    def set_epoch(self, epoch):
+        """
+        Set the current epoch for all constituent datasets.
+
+        Args:
+            epoch (int): The current epoch number
+        """
+        for dataset in self.datasets:
+            dataset.set_epoch(epoch)
+
+    def __getitem__(self, idx):
+        """
+        Get an item from the concatenated dataset.
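+        For example, with constituent datasets of lengths 10 and 20, idx=15 resolves
+        to element 5 of the second dataset via the cumulative sizes.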
+ + Args: + idx: Index or tuple of indices to retrieve + + Returns: + The item at the specified index from the appropriate constituent dataset + + Raises: + IndexError: If the index is out of range + """ + other = None + if isinstance(idx, tuple): + other = idx[1:] + idx = idx[0] + + if not (0 <= idx < len(self)): + raise IndexError() + + db_idx = np.searchsorted(self._cum_sizes, idx, "right") + dataset = self.datasets[db_idx] + new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0) + + if other is not None: + new_idx = (new_idx, *other) + return dataset[new_idx] + + @property + def _resolutions(self): + """ + Get the resolutions of the dataset. + + Returns: + The resolutions from the first dataset (all datasets must have the same resolutions) + + Raises: + AssertionError: If datasets have different resolutions + """ + resolutions = self.datasets[0]._resolutions + for dataset in self.datasets[1:]: + assert tuple(dataset._resolutions) == tuple(resolutions), ( + "All datasets must have the same resolutions" + ) + return resolutions + + @property + def num_views(self): + """ + Get the number of views used for the dataset. + + Returns: + int or list: The number of views parameter from the first dataset + + Raises: + AssertionError: If datasets have different num_views + """ + num_views = self.datasets[0].num_views + for dataset in self.datasets[1:]: + assert dataset.num_views == num_views, ( + "All datasets must have the same num_views and variable_num_views parameters" + ) + return num_views diff --git a/mapanything/datasets/utils/__init__.py b/mapanything/datasets/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/datasets/utils/data_splits.py b/mapanything/datasets/utils/data_splits.py new file mode 100644 index 0000000000000000000000000000000000000000..86983f4ff49396f686e60daac192e829d000de23 --- /dev/null +++ b/mapanything/datasets/utils/data_splits.py @@ -0,0 +1,1734 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Modules containing dataset split information +""" + + +class BlendedMVSSplits: + """ + This class contains the information about the BlendedMVS dataset splits. 
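+    Sequence names are 24-character hex strings. The split-generation logic in
+    __init__ below reads the first 8 hex characters as seqh and the remaining 16
+    as seql; e.g. "5a2a95f032a1c655cfe3de62" decomposes into seqh = 0x5a2a95f0
+    and seql = 0x32a1c655cfe3de62 (worked example for illustration). A sequence
+    is assigned to the val split when seql % 10 == 0 and to train otherwise.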
+ """ + + def __init__(self): + """ + The splits are generated using the following logic: + # Get all seqls and seqhs using self.blendedmvs_info.all_sequences + all_sequences = self.blendedmvs_info.all_sequences + all_seqls = [int(seq[8:], 16) for seq in all_sequences] + all_seqhs = [int(seq[:8], 16) for seq in all_sequences] + # Split the seqls (& corresponding seqhs) using the DUSt3R train/val split logic + if split is None: + selection = slice(None) + elif split in ["train", "overfit"]: + # select 90% of all scenes + selection = [(seql % 10) > 0 for seql in all_seqls] + elif split == "val": + # select 10% of all scenes + selection = [(seql % 10) == 0 for seql in all_seqls] + else: + raise ValueError(f"Unknown split {split}, must be None, train, val or overfit") + # Filter sequences based on the selection + selected_seqls = [seql for seql, sel in zip(all_seqls, selection) if sel] + selected_seqhs = [seqh for seqh, sel in zip(all_seqhs, selection) if sel] + # Put them back into sequence names f"{seqh:08x}{seql:016x}" + sequence_names = [f"{seqh:08x}{seql:016x}" for seqh, seql in zip(selected_seqhs, selected_seqls)] + # Remove invalid sequence names which don't exist in self.blendedmvs_info.sequences + valid_sequences = set(self.blendedmvs_info.sequences) + valid_sequence_names = [name for name in sequence_names if name in valid_sequences] + """ + # All the 502 sequences in the dataset (totals to 115k images) + self.all_scenes = [ + "000000000000000000000000", + "00000000000000000000000a", + "00000000000000000000000b", + "00000000000000000000000c", + "00000000000000000000000d", + "00000000000000000000000e", + "00000000000000000000000f", + "000000000000000000000001", + "00000000000000000000001a", + "00000000000000000000001b", + "00000000000000000000001d", + "000000000000000000000002", + "000000000000000000000003", + "000000000000000000000004", + "000000000000000000000005", + "5a2a95f032a1c655cfe3de62", + "5a2af22b32a1c655cfe46013", + "5a2ba6de32a1c655cfe51b79", + "5a3b9731e24cd76dad1a5f1b", + "5a3ca9cb270f0e3f14d0eddb", + "5a3cb4e4270f0e3f14d12f43", + "5a03e732454a8a7ec672776c", + "5a3f4aba5889373fbbc5d3b5", + "5a4a38dad38c8a075495b5d2", + "5a5a1e48d62c7a12d5d00e47", + "5a6b1c418d100c2f8fdc4411", + "5a6feeb54a7fbc3f874f9db7", + "5a7cb1d6fe5c0d6fb53e64fb", + "5a7d3db14989e929563eb153", + "5a8aa0fab18050187cbe060e", + "5a9e5df65baeef72b4a021cd", + "5a48ba95c7dab83a7d7b44ed", + "5a48c4e9c7dab83a7d7b5cc7", + "5a48d4b2c7dab83a7d7b9851", + "5a69c47d0d5d0a7f3b2e9752", + "5a77b46b318efe6c6736e68a", + "5a355c271b63f53d5970f362", + "5a489fb1c7dab83a7d7b1070", + "5a533e8034d7582116e34209", + "5a562fc7425d0f5186314725", + "5a572fd9fc597b0478a81d14", + "5a588a8193ac3d233f77fbca", + "5a618c72784780334bc1972d", + "5a752d42acc41e2423f17674", + "5a969eea91dfc339a9a3ad2c", + "5a8315f624b8e938486e0bd8", + "5a57542f333d180827dfc132", + "5a0271884e62597cdee0d0eb", + "5a6400933d809f1d8200af15", + "5a6464143d809f1d8208c43c", + "5a563183425d0f5186314855", + "5aa0f9d7a9efce63548c69a1", + "5aa0f478a9efce63548c1cb4", + "5aa7db90bfdd572271e95246", + "5aa235f64a17b335eeaf9609", + "5aa515e613d42d091d29d300", + "5aa1196ea9efce63548ed649", + "5aaadd4cbc13235570d178a7", + "5ab6af12ac4291329b1072ab", + "5ab7e00aac4291329b15864d", + "5ab8b8e029f5351f7f2ccf59", + "5ab74bf2ac4291329b11e879", + "5ab85f1dac4291329b17cb50", + "5ab8713ba3799a1d138bd69a", + "5abc2506b53b042ead637d86", + "5acc7459a7853c4b5ebbef59", + "5acf8ca0f3d8a750097e4b15", + "5adc6bd52430a05ecb2ffb85", + "5ae2e9c5fe405c5076abc6b2", + "5af02e904c8216544b4ab5a2", 
+ "5af28cea59bc705737003253", + "5af545d0559359053d25dcf5", + "5afacb69ab00705d0cefdd5b", + "5b2c67b5e0878c381608b8d8", + "5b3b2b9e8d46a939f933fdc0", + "5b3b353d8d46a939f93524b9", + "5b6e716d67b396324c2d77cb", + "5b6eff8b67b396324c5b2672", + "5b7a3890fc8fcf6781e2593a", + "5b21e18c58e2823a67a10dd8", + "5b60fa0c764f146feef84df0", + "5b69cc0cb44b61786eb959bf", + "5b78e57afc8fcf6781d0c3ba", + "5b192eb2170cf166458ff886", + "5b558a928bbfb62204e77ba2", + "5b864d850d072a699b32f4ae", + "5b908d3dc6ab78485f3d24a9", + "5b950c71608de421b1e7318f", + "5b4933abf2b5f44e95de482a", + "5b08286b2775267d5b0634ba", + "5b37189a35304b6f75e7583e", + "5b271079e0878c3816dacca4", + "5b22269758e2823a67a3bd03", + "5b62647143840965efc0dbde", + "5ba19a8a360c7c30c1c169df", + "5ba75d79d76ffa2c86cf2f05", + "5bb7a08aea1cfa39f1a947ab", + "5bb8a49aea1cfa39f1aa7f75", + "5bbb6eb2ea1cfa39f1af7e0c", + "5bc5f0e896b66a2cd8f9bd36", + "5bccd6beca24970bce448134", + "5bce7ac9ca24970bce4934b6", + "5bcf979a6d5f586b95c258cd", + "5bd43b4ba6b28b1ee86b92dd", + "5be3a5fb8cfdd56947f6b67c", + "5be3ae47f44e235bdbbc9771", + "5be4ab93870d330ff2dce134", + "5be47bf9b18881428d8fbc1d", + "5be883a4f98cee15019d5b83", + "5bea87f4abd34c35e1860ab5", + "5beb6e66abd34c35e18e66b9", + "5bf3a82cd439231948877aed", + "5bf7d63575c26f32dbf7413b", + "5bf17c0fd439231948355385", + "5bf26cbbd43923194854b270", + "5bf03590d4392319481971dc", + "5bf18642c50e6f7f8bdbd492", + "5bf21799d43923194842c001", + "5bfc9d5aec61ca1dd69132a2", + "5bfd0f32ec61ca1dd69dc77b", + "5bfe5ae0fe0ea555e6a969ca", + "5bff3c5cfe0ea555e6bcbf3a", + "5c0d13b795da9479e12e2ee9", + "5c1af2e2bee9a723c963d019", + "5c1b1500bee9a723c96c3e78", + "5c1dbf200843bc542d8ef8c4", + "5c1f33f1d33e1f2e4aa6dda4", + "5c2b3ed5e611832e8aed46bf", + "5c20ca3a0843bc542d94e3e2", + "5c062d84a96e33018ff6f0a6", + "5c189f2326173c3a09ed7ef3", + "5c1892f726173c3a09ea9aeb", + "5c34300a73a8df509add216d", + "5c34529873a8df509ae57b58", + "000000000000000000000006", + "000000000000000000000007", + "000000000000000000000008", + "000000000000000000000009", + "000000000000000000000010", + "000000000000000000000011", + "000000000000000000000012", + "000000000000000000000015", + "000000000000000000000016", + "000000000000000000000017", + "000000000000000000000018", + "000000000000000000000019", + "56d73ba74bd29b8c35abade2", + "56f34064e296120e10484dc4", + "57a4a7bb6b9272286e26dc18", + "57f8d9bbe73f6760f10e916a", + "58a0a2f33d0b4542479a11b1", + "58a0dd1a3d0b4542479a28f3", + "58a1a7914a4d262a170b1101", + "58a1bc804a4d262a170b2f01", + "58a1d9d14a4d262a170b58fe", + "58a01dea38486e3c98475871", + "58a1f5d74a4d262a170b65fc", + "58a2a09e156b87103d3d668c", + "58a2d9c3156b87103d3da90f", + "58a3ccb0156b87103d3e4332", + "58a3f2f8156b87103d3e5838", + "58a3f6c0156b87103d3e5971", + "58a3fc95156b87103d3e5d9b", + "58a07ce53d0b45424799fdde", + "58a07f233d0b45424799ffe7", + "58a44df2156b87103d3ee239", + "58a164f73d0b4542479a7a8e", + "58a0365e38486e3c984783eb", + "58a439cf156b87103d3ec885", + "58a464aa156b87103d3eec04", + "58a4452f156b87103d3ed55b", + "58a160983d0b4542479a7347", + "58a186444a4d262a170ae3ae", + "58a285424a4d262a170baf3e", + "58a41819156b87103d3e92a5", + "58a44463156b87103d3ed45e", + "58a47552156b87103d3f00a4", + "58c4bb4f4a69c55606122be4", + "58c6451e4a69c556061894f1", + "58ca7014affdfd07c70a95ce", + "58cf4771d0f5fb221defe6da", + "58d36897f387231e6c929903", + "58eaf1513353456af3a1682a", + "58f7f7299f5b5647873cb110", + "58f73e7c9f5b56478738929f", + "59a8f851597729752c31e7e0", + "59a452bf9b460239aa5d1c72", + "59a9619a825418241fb88191", + 
"59acd2f4b891807f439c8992", + "59bf97fe7e7b31545da34439", + "59c1c3e2fd6e3d4ead9f1013", + "59d2657f82ca7774b1ec081d", + "59da1fb88a126011d0394ae9", + "59e75a2ca9e91f2c5526005d", + "59e864b2a9e91f2c5529325f", + "59ecfd02e225f6492d20fcc9", + "59f37f74b45be2233001ba18", + "59f70ab1e5c5d366af29bf3e", + "59f87d0bfa6280566fb38c9a", + "59f363a8b45be22330016cad", + "564a27b26d07883f460d8ab0", + "565fb1dead14d4154dae2b94", + "567a0fb0a825d2fb79ac9a20", + "569b92eb826bcba945ca002b", + "576fefa017ce5a16397e87fd", + "584a7333fe3cb463906c9fe6", + "584aa8e9fe3cb463906cc7d0", + "584ad76bfe3cb463906ce6dc", + "584af003fe3cb463906d0e9b", + "584b9a747072670e72bfc49d", + "584b671f7072670e72bfaaf8", + "584b81747072670e72bfbbfd", + "584ba35f7072670e72bfca4d", + "584ba5977072670e72bfcc2d", + "584bc53c7072670e72bfe85f", + "584bc3997072670e72bfe58d", + "584bc4407072670e72bfe665", + "584bd5587072670e72bffe39", + "584bdadf7072670e72c0005c", + "584be5ed7072670e72c007b3", + "584c9ad27072670e72c060c5", + "584c9cc67072670e72c063a1", + "584c58b77072670e72c03990", + "584cea557072670e72c07fb4", + "584d19d47072670e72c0c6c0", + "584dfe467072670e72c1665a", + "584e875c7072670e72c1ec94", + "584e05667072670e72c17167", + "584f94e87072670e72c2d3f7", + "584fdffd7072670e72c32dc7", + "584fe07f7072670e72c32e59", + "585a2a71b338a62ad50138dc", + "585a206ab338a62ad501298f", + "585a217cb338a62ad5012b38", + "585b34afb338a62ad501e836", + "585bb25fc49c8507c3ce7812", + "585bbe55c49c8507c3ce81cd", + "585d6c8a2a57cc11d4920a1e", + "585e54c72a57cc11d492f71a", + "585e34302a57cc11d492be30", + "585ee0632a57cc11d4933608", + "585f9661712e2761468dabca", + "585ffe9a712e2761468df643", + "586a37ec9d1b5e34c28184fc", + "586a515a9d1b5e34c281b431", + "586a94939d1b5e34c2823b5d", + "586abc689d1b5e34c2826360", + "586b0e219d1b5e34c2828862", + "586b3db89d1b5e34c282cd52", + "586b4c459d1b5e34c282e66d", + "586b7d7d9d1b5e34c283359e", + "586b8f149d1b5e34c283497c", + "586b8f629d1b5e34c28349d6", + "586c4c4d9d1b5e34c28391a1", + "586c5b5b9d1b5e34c2839a5b", + "586c9fdf9d1b5e34c283b657", + "586c48329d1b5e34c2838e80", + "586caab99d1b5e34c283c213", + "586cd0779d1b5e34c28403a7", + "586d6d249d1b5e34c284b80e", + "586d8a029d1b5e34c284c948", + "586d55af9d1b5e34c284a999", + "586d07869d1b5e34c2842e5b", + "586d27489d1b5e34c28453af", + "586df9849d1b5e34c28506de", + "586e279c9d1b5e34c2852180", + "587bc5ec2366dd5d06e262c1", + "587c1abf2366dd5d06e28901", + "587c03f12366dd5d06e27722", + "587c19da2366dd5d06e2877b", + "587c31b92366dd5d06e2a9dc", + "587c87d02366dd5d06e2f989", + "587c97a52366dd5d06e30a96", + "587c45192366dd5d06e2c0eb", + "587cec702366dd5d06e37862", + "587cef0a2366dd5d06e379e3", + "587db5872366dd5d06e3e0af", + "587e2b1d2366dd5d06e41af0", + "587e2ea62366dd5d06e41f2e", + "587e5cb52366dd5d06e4486e", + "587eb1822366dd5d06e45f29", + "587f365d2366dd5d06e4906e", + "588a9c5fec4d5a1c088ec350", + "588a34cfec4d5a1c088ea8d1", + "588ab5bdec4d5a1c088ed60f", + "588aff9d90414422fbe7885a", + "588b20d290414422fbe79f40", + "588c08d590414422fbe8200b", + "588c203d90414422fbe8319e", + "588c989a90414422fbe86d96", + "588ca09d90414422fbe871a1", + "588cce2190414422fbe88520", + "588cd5ef90414422fbe8875c", + "588cf0ad90414422fbe8a20f", + "588e0d8c90414422fbe8f8b2", + "588e01c490414422fbe8ee2a", + "588e35e690414422fbe90a53", + "588f017e90414422fbe9b74b", + "588f095190414422fbe9c1ee", + "589aca717dc3d323d55671c4", + "589af2c97dc3d323d55691e8", + "589b49ea7dc3d323d556d9b4", + "589b04287dc3d323d556a185", + "589bf6a57dc3d323d55743ab", + "589c3c497dc3d323d5578468", + "589c3c577dc3d323d5578480", + 
"589c300f7dc3d323d5577926", + "589c24527dc3d323d5577126", + "589c35457dc3d323d5577d8d", + "589ca6a6b896147a1b73aff7", + "589d1e1fb896147a1b73ee5b", + "589d5c58b896147a1b742256", + "589d95538fa2cf375df3317b", + "589df0ffb504a864ad63521a", + "589ea316b504a864ad639a2b", + "589ec97cb504a864ad63adc3", + "589f214338486e3c9846f123", + "589fdfe738486e3c984736cf", + "590c2d70336bb52a190be886", + "590f91851225725be9e25d4e", + "591a467a6109e14d4f09b776", + "591cf3033162411cf9047f37", + "591ea44850991c70dc99a207", + "599aa591d5b41f366fed0d58", + "5643df56138263b51db1b5f3", + "5644bdac138263b51db9f669", + "5692a4c2adafac1f14201821", + "5850d4f97072670e72c425d6", + "5854c405804be105852330fe", + "5855a4fc804be1058523bd75", + "5856ac15804be105852419d8", + "5856ae8b804be10585241bae", + "5856b460804be10585242059", + "5857aa5ab338a62ad5ff4dbe", + "5857acf8b338a62ad5ff5107", + "5858db6cb338a62ad500103b", + "5858dbcab338a62ad5001081", + "5859d84fb338a62ad500e5cf", + "5861d8ea712e2761468f3cb3", + "5863edf8712e27614690cce0", + "5864a935712e2761469111b4", + "5864b076712e27614691197e", + "5864da88712e276146913d8b", + "5865f4a8712e27614691e39b", + "5867a434833dfe3f7b88edaf", + "5868cd15833dfe3f7b89bfa3", + "5880b3692366dd5d06e5d534", + "5880e3422366dd5d06e5ff8e", + "5880f0ef2366dd5d06e6166e", + "5881d2bfb6844814c136a119", + "5881f11d8ce2c2754d0714c3", + "5881fee18ce2c2754d0723f8", + "5882cda2b116682b4adebd25", + "5882d58fb116682b4adec7db", + "5884c256932ba84fbed70bf5", + "5884cc13932ba84fbed71ec4", + "5885bc5296fa095e0671a7f0", + "5886d14cb791366d617a362c", + "5888becfc02346100f4b0b21", + "5888e408c02346100f4b1a29", + "5889da66ec4d5a1c088e5187", + "5889e344ec4d5a1c088e59be", + "5889e754ec4d5a1c088e60ba", + "5890c16b90414422fbeb0262", + "5891d8ae9a8c0314c5cd30ab", + "5891d0479a8c0314c5cd2abd", + "5891ecf19a8c0314c5cd490a", + "5892c0cd9a8c0314c5cdc977", + "5894ab309a8c0314c5cee57d", + "5895a6a89a8c0314c5cfca7c", + "5895b8c29a8c0314c5cfd051", + "5895d38f9a8c0314c5cfe50c", + "5895f2329a8c0314c5d00117", + "5896bb989a8c0314c5d086b6", + "5896ebf39a8c0314c5d0a8c4", + "5898b1bac9dccc22987b7f74", + "5898b6ffc9dccc22987b8a03", + "5898b31cc9dccc22987b82ec", + "5898bbaac9dccc22987b8eba", + "5899cfa6b76d7a3780a4cb64", + "5899e5dcb76d7a3780a4ecc1", + "5947b62af1b45630bd0c2a02", + "57102be2877e1421026358af", + "57153d4031bb9900425bde85", + "57177cd7fb8d93461afc4527", + "58497cdf97b73e0b090c4273", + "58500b007072670e72c35588", + "58510bf97072670e72c46ddf", + "58522bd56789802282f2ecb3", + "58524a2e0e7012308944bcf3", + "58524a080e7012308944bcbf", + "58524c1d0e7012308944bfda", + "58524f170e7012308944c200", + "58529a4e0e70123089454c6f", + "58551bdf804be1058523556d", + "58568c9a804be10585240b03", + "58574b35804be105852455fd", + "58577c60b338a62ad5ff1564", + "58592d69b338a62ad5007a74", + "58598db2b338a62ad500bc38", + "58625f42712e2761468fb44c", + "58651bcc712e2761469166dc", + "58660e79712e27614691fe3d", + "58669aad712e27614692834c", + "58669c02712e27614692851a", + "58676c36833dfe3f7b88b7f2", + "58678b2d833dfe3f7b88e244", + "58790c82ce911104a3467c88", + "58800b0b2366dd5d06e5312d", + "58805eac2366dd5d06e56460", + "58806e422366dd5d06e57bb6", + "58831d060db9bf59bf8ab98b", + "58851ebb932ba84fbed7abad", + "58871dc3b791366d617a55ff", + "58873cabb791366d617a65a7", + "58873d44b791366d617a65dd", + "58888b3dc02346100f4af665", + "58897f62c02346100f4b8ee6", + "58933bac9a8c0314c5ce3508", + "58938e6d9a8c0314c5ce726f", + "58951cb49a8c0314c5cf4d5e", + "58970fd09a8c0314c5d0e383", + "58977ef09a8c0314c5d17b26", + "59056e6760bb961de55f3501", + 
"59071f2e5a6dbd3af4130f98", + "59102c811225725be9e64149", + "59338e76772c3e6384afbb15", + "59350ca084b7f26bf5ce6eb8", + "59397e493a87372f2c9e882b", + "59521e0b9096412211c2aa9d", + "59817e4a1bd4b175e7038d19", + "567884f58d2828b95e3c8eba", + "585559d9804be10585238ddf", + "585834cdb338a62ad5ffab4d", + "586082d8712e2761468e2877", + "586133c2712e2761468ecfe3", + "586281d2712e2761468fcaa2", + "586316e5712e276146903c4d", + "586326ad712e276146904571", + "586375c9712e276146907429", + "586389c9712e276146908da6", + "586496fa712e2761469108e7", + "586669c6712e27614692597a", + "586913a49d1b5e34c2808b02", + "586922da9d1b5e34c2809ff3", + "588185d8dfb7a15588a114a3", + "588305ed0db9bf59bf8a8c80", + "588315c60db9bf59bf8aa928", + "588332ee0db9bf59bf8ae9c3", + "588457b8932ba84fbed69942", + "588519d5932ba84fbed7a04a", + "588824d1b791366d617adeef", + "588857f6c02346100f4ac09f", + "589145ef90414422fbeb2e08", + "589433fa9a8c0314c5ce9656", + "589765d39a8c0314c5d16b12", + "5851165f7072670e72c4860d", + "5859341ab338a62ad500848d", + "5862388b712e2761468f84aa", + "5863915b712e276146909135", + "5866445b712e27614692383e", + "5866500d712e2761469240fd", + "5867785a833dfe3f7b88c764", + "5867969c833dfe3f7b88e8bc", + "5868040c833dfe3f7b8934f7", + "5880675a2366dd5d06e570ca", + "5882372c8ce2c2754d076af0", + "5883535e932ba84fbed5ad07", + "5888358cb791366d617af69d", + "5890330d90414422fbeaa0cb", + "5897076e9a8c0314c5d0d31b", + "5940564ec2d9527ab869f7e2", + "5947719bf1b45630bd096665", + "5948194ff1b45630bd0f47e3", + "5950206a41b158666ac50506", + "5983012d1bd4b175e70c985a", + "58586810b338a62ad5ffc20c", + "58592046b338a62ad5006b33", + "58592854b338a62ad500750a", + "58596531b338a62ad500aace", + "58818685dfb7a15588a11626", + "58829563f42b1d3ee3ec835f", + "58894345c02346100f4b51ca", + "585289980e7012308945276a", + "585369770e7012308945c709", + "585373640e7012308945cab9", + "588230658ce2c2754d076728", + "589388059a8c0314c5ce718b", + "595979485ec6a95e86a58c8d", + "5841206219d291325678ca90", + "58563650804be1058523da55", + "58564084804be1058523e116", + "58636467712e27614690661f", + "58647495712e27614690f36d", + "58654563712e276146918643", + "58664251712e276146923738", + "588084032366dd5d06e59e82", + "588159582366dd5d06e66877", + "5890279190414422fbea9734", + "5890523090414422fbeab3f0", + "5890641690414422fbeabbe7", + "585203546789802282f2aaf5", + ] + + # Final sequences to be used after filtering (some of the sequences have incorrect/low quality depth) + # Generally water bodies like lakes have incorrect depth + # Filtered out sequences: + # "5692a4c2adafac1f14201821" # Incorrect Depth + # "5864a935712e2761469111b4" # Noisy Depth and artifacts near horizon + # "59f87d0bfa6280566fb38c9a" # Object-centric, noise with background and sometimes in front of object + # "58a44463156b87103d3ed45e" # Very noisy depth in background + # "5c2b3ed5e611832e8aed46bf" # Depth occluded by artifacts + # "5bf03590d4392319481971dc" # Depth occluded by artifacts + # "00000000000000000000001a" # Largely incomplete depth + # "00000000000000000000000c" # Imprecise depth for buildings + # "000000000000000000000000" # Incorrect depth for planar terrain + self.scenes = [ + "00000000000000000000000a", + "00000000000000000000000b", + "00000000000000000000000d", + "00000000000000000000000e", + "00000000000000000000000f", + "000000000000000000000001", + "00000000000000000000001b", + "00000000000000000000001d", + "000000000000000000000002", + "000000000000000000000003", + "000000000000000000000004", + "000000000000000000000005", + "5a2a95f032a1c655cfe3de62", + 
"5a2af22b32a1c655cfe46013", + "5a2ba6de32a1c655cfe51b79", + "5a3b9731e24cd76dad1a5f1b", + "5a3ca9cb270f0e3f14d0eddb", + "5a3cb4e4270f0e3f14d12f43", + "5a03e732454a8a7ec672776c", + "5a3f4aba5889373fbbc5d3b5", + "5a4a38dad38c8a075495b5d2", + "5a5a1e48d62c7a12d5d00e47", + "5a6b1c418d100c2f8fdc4411", + "5a6feeb54a7fbc3f874f9db7", + "5a7cb1d6fe5c0d6fb53e64fb", + "5a7d3db14989e929563eb153", + "5a8aa0fab18050187cbe060e", + "5a9e5df65baeef72b4a021cd", + "5a48ba95c7dab83a7d7b44ed", + "5a48c4e9c7dab83a7d7b5cc7", + "5a48d4b2c7dab83a7d7b9851", + "5a69c47d0d5d0a7f3b2e9752", + "5a77b46b318efe6c6736e68a", + "5a355c271b63f53d5970f362", + "5a489fb1c7dab83a7d7b1070", + "5a533e8034d7582116e34209", + "5a562fc7425d0f5186314725", + "5a572fd9fc597b0478a81d14", + "5a588a8193ac3d233f77fbca", + "5a618c72784780334bc1972d", + "5a752d42acc41e2423f17674", + "5a969eea91dfc339a9a3ad2c", + "5a8315f624b8e938486e0bd8", + "5a57542f333d180827dfc132", + "5a0271884e62597cdee0d0eb", + "5a6400933d809f1d8200af15", + "5a6464143d809f1d8208c43c", + "5a563183425d0f5186314855", + "5aa0f9d7a9efce63548c69a1", + "5aa0f478a9efce63548c1cb4", + "5aa7db90bfdd572271e95246", + "5aa235f64a17b335eeaf9609", + "5aa515e613d42d091d29d300", + "5aa1196ea9efce63548ed649", + "5aaadd4cbc13235570d178a7", + "5ab6af12ac4291329b1072ab", + "5ab7e00aac4291329b15864d", + "5ab8b8e029f5351f7f2ccf59", + "5ab74bf2ac4291329b11e879", + "5ab85f1dac4291329b17cb50", + "5ab8713ba3799a1d138bd69a", + "5abc2506b53b042ead637d86", + "5acc7459a7853c4b5ebbef59", + "5acf8ca0f3d8a750097e4b15", + "5adc6bd52430a05ecb2ffb85", + "5ae2e9c5fe405c5076abc6b2", + "5af02e904c8216544b4ab5a2", + "5af28cea59bc705737003253", + "5af545d0559359053d25dcf5", + "5afacb69ab00705d0cefdd5b", + "5b2c67b5e0878c381608b8d8", + "5b3b2b9e8d46a939f933fdc0", + "5b3b353d8d46a939f93524b9", + "5b6e716d67b396324c2d77cb", + "5b6eff8b67b396324c5b2672", + "5b7a3890fc8fcf6781e2593a", + "5b21e18c58e2823a67a10dd8", + "5b60fa0c764f146feef84df0", + "5b69cc0cb44b61786eb959bf", + "5b78e57afc8fcf6781d0c3ba", + "5b192eb2170cf166458ff886", + "5b558a928bbfb62204e77ba2", + "5b864d850d072a699b32f4ae", + "5b908d3dc6ab78485f3d24a9", + "5b950c71608de421b1e7318f", + "5b4933abf2b5f44e95de482a", + "5b08286b2775267d5b0634ba", + "5b37189a35304b6f75e7583e", + "5b271079e0878c3816dacca4", + "5b22269758e2823a67a3bd03", + "5b62647143840965efc0dbde", + "5ba19a8a360c7c30c1c169df", + "5ba75d79d76ffa2c86cf2f05", + "5bb7a08aea1cfa39f1a947ab", + "5bb8a49aea1cfa39f1aa7f75", + "5bbb6eb2ea1cfa39f1af7e0c", + "5bc5f0e896b66a2cd8f9bd36", + "5bccd6beca24970bce448134", + "5bce7ac9ca24970bce4934b6", + "5bcf979a6d5f586b95c258cd", + "5bd43b4ba6b28b1ee86b92dd", + "5be3a5fb8cfdd56947f6b67c", + "5be3ae47f44e235bdbbc9771", + "5be4ab93870d330ff2dce134", + "5be47bf9b18881428d8fbc1d", + "5be883a4f98cee15019d5b83", + "5bea87f4abd34c35e1860ab5", + "5beb6e66abd34c35e18e66b9", + "5bf3a82cd439231948877aed", + "5bf7d63575c26f32dbf7413b", + "5bf17c0fd439231948355385", + "5bf26cbbd43923194854b270", + "5bf18642c50e6f7f8bdbd492", + "5bf21799d43923194842c001", + "5bfc9d5aec61ca1dd69132a2", + "5bfd0f32ec61ca1dd69dc77b", + "5bfe5ae0fe0ea555e6a969ca", + "5bff3c5cfe0ea555e6bcbf3a", + "5c0d13b795da9479e12e2ee9", + "5c1af2e2bee9a723c963d019", + "5c1b1500bee9a723c96c3e78", + "5c1dbf200843bc542d8ef8c4", + "5c1f33f1d33e1f2e4aa6dda4", + "5c20ca3a0843bc542d94e3e2", + "5c062d84a96e33018ff6f0a6", + "5c189f2326173c3a09ed7ef3", + "5c1892f726173c3a09ea9aeb", + "5c34300a73a8df509add216d", + "5c34529873a8df509ae57b58", + "000000000000000000000006", + "000000000000000000000007", + 
"000000000000000000000008", + "000000000000000000000009", + "000000000000000000000010", + "000000000000000000000011", + "000000000000000000000012", + "000000000000000000000015", + "000000000000000000000016", + "000000000000000000000017", + "000000000000000000000018", + "000000000000000000000019", + "56d73ba74bd29b8c35abade2", + "56f34064e296120e10484dc4", + "57a4a7bb6b9272286e26dc18", + "57f8d9bbe73f6760f10e916a", + "58a0a2f33d0b4542479a11b1", + "58a0dd1a3d0b4542479a28f3", + "58a1a7914a4d262a170b1101", + "58a1bc804a4d262a170b2f01", + "58a1d9d14a4d262a170b58fe", + "58a01dea38486e3c98475871", + "58a1f5d74a4d262a170b65fc", + "58a2a09e156b87103d3d668c", + "58a2d9c3156b87103d3da90f", + "58a3ccb0156b87103d3e4332", + "58a3f2f8156b87103d3e5838", + "58a3f6c0156b87103d3e5971", + "58a3fc95156b87103d3e5d9b", + "58a07ce53d0b45424799fdde", + "58a07f233d0b45424799ffe7", + "58a44df2156b87103d3ee239", + "58a164f73d0b4542479a7a8e", + "58a0365e38486e3c984783eb", + "58a439cf156b87103d3ec885", + "58a464aa156b87103d3eec04", + "58a4452f156b87103d3ed55b", + "58a160983d0b4542479a7347", + "58a186444a4d262a170ae3ae", + "58a285424a4d262a170baf3e", + "58a41819156b87103d3e92a5", + "58a47552156b87103d3f00a4", + "58c4bb4f4a69c55606122be4", + "58c6451e4a69c556061894f1", + "58ca7014affdfd07c70a95ce", + "58cf4771d0f5fb221defe6da", + "58d36897f387231e6c929903", + "58eaf1513353456af3a1682a", + "58f7f7299f5b5647873cb110", + "58f73e7c9f5b56478738929f", + "59a8f851597729752c31e7e0", + "59a452bf9b460239aa5d1c72", + "59a9619a825418241fb88191", + "59acd2f4b891807f439c8992", + "59bf97fe7e7b31545da34439", + "59c1c3e2fd6e3d4ead9f1013", + "59d2657f82ca7774b1ec081d", + "59da1fb88a126011d0394ae9", + "59e75a2ca9e91f2c5526005d", + "59e864b2a9e91f2c5529325f", + "59ecfd02e225f6492d20fcc9", + "59f37f74b45be2233001ba18", + "59f70ab1e5c5d366af29bf3e", + "59f363a8b45be22330016cad", + "564a27b26d07883f460d8ab0", + "565fb1dead14d4154dae2b94", + "567a0fb0a825d2fb79ac9a20", + "569b92eb826bcba945ca002b", + "576fefa017ce5a16397e87fd", + "584a7333fe3cb463906c9fe6", + "584aa8e9fe3cb463906cc7d0", + "584ad76bfe3cb463906ce6dc", + "584af003fe3cb463906d0e9b", + "584b9a747072670e72bfc49d", + "584b671f7072670e72bfaaf8", + "584b81747072670e72bfbbfd", + "584ba35f7072670e72bfca4d", + "584ba5977072670e72bfcc2d", + "584bc53c7072670e72bfe85f", + "584bc3997072670e72bfe58d", + "584bc4407072670e72bfe665", + "584bd5587072670e72bffe39", + "584bdadf7072670e72c0005c", + "584be5ed7072670e72c007b3", + "584c9ad27072670e72c060c5", + "584c9cc67072670e72c063a1", + "584c58b77072670e72c03990", + "584cea557072670e72c07fb4", + "584d19d47072670e72c0c6c0", + "584dfe467072670e72c1665a", + "584e875c7072670e72c1ec94", + "584e05667072670e72c17167", + "584f94e87072670e72c2d3f7", + "584fdffd7072670e72c32dc7", + "584fe07f7072670e72c32e59", + "585a2a71b338a62ad50138dc", + "585a206ab338a62ad501298f", + "585a217cb338a62ad5012b38", + "585b34afb338a62ad501e836", + "585bb25fc49c8507c3ce7812", + "585bbe55c49c8507c3ce81cd", + "585d6c8a2a57cc11d4920a1e", + "585e54c72a57cc11d492f71a", + "585e34302a57cc11d492be30", + "585ee0632a57cc11d4933608", + "585f9661712e2761468dabca", + "585ffe9a712e2761468df643", + "586a37ec9d1b5e34c28184fc", + "586a515a9d1b5e34c281b431", + "586a94939d1b5e34c2823b5d", + "586abc689d1b5e34c2826360", + "586b0e219d1b5e34c2828862", + "586b3db89d1b5e34c282cd52", + "586b4c459d1b5e34c282e66d", + "586b7d7d9d1b5e34c283359e", + "586b8f149d1b5e34c283497c", + "586b8f629d1b5e34c28349d6", + "586c4c4d9d1b5e34c28391a1", + "586c5b5b9d1b5e34c2839a5b", + "586c9fdf9d1b5e34c283b657", + 
"586c48329d1b5e34c2838e80", + "586caab99d1b5e34c283c213", + "586cd0779d1b5e34c28403a7", + "586d6d249d1b5e34c284b80e", + "586d8a029d1b5e34c284c948", + "586d55af9d1b5e34c284a999", + "586d07869d1b5e34c2842e5b", + "586d27489d1b5e34c28453af", + "586df9849d1b5e34c28506de", + "586e279c9d1b5e34c2852180", + "587bc5ec2366dd5d06e262c1", + "587c1abf2366dd5d06e28901", + "587c03f12366dd5d06e27722", + "587c19da2366dd5d06e2877b", + "587c31b92366dd5d06e2a9dc", + "587c87d02366dd5d06e2f989", + "587c97a52366dd5d06e30a96", + "587c45192366dd5d06e2c0eb", + "587cec702366dd5d06e37862", + "587cef0a2366dd5d06e379e3", + "587db5872366dd5d06e3e0af", + "587e2b1d2366dd5d06e41af0", + "587e2ea62366dd5d06e41f2e", + "587e5cb52366dd5d06e4486e", + "587eb1822366dd5d06e45f29", + "587f365d2366dd5d06e4906e", + "588a9c5fec4d5a1c088ec350", + "588a34cfec4d5a1c088ea8d1", + "588ab5bdec4d5a1c088ed60f", + "588aff9d90414422fbe7885a", + "588b20d290414422fbe79f40", + "588c08d590414422fbe8200b", + "588c203d90414422fbe8319e", + "588c989a90414422fbe86d96", + "588ca09d90414422fbe871a1", + "588cce2190414422fbe88520", + "588cd5ef90414422fbe8875c", + "588cf0ad90414422fbe8a20f", + "588e0d8c90414422fbe8f8b2", + "588e01c490414422fbe8ee2a", + "588e35e690414422fbe90a53", + "588f017e90414422fbe9b74b", + "588f095190414422fbe9c1ee", + "589aca717dc3d323d55671c4", + "589af2c97dc3d323d55691e8", + "589b49ea7dc3d323d556d9b4", + "589b04287dc3d323d556a185", + "589bf6a57dc3d323d55743ab", + "589c3c497dc3d323d5578468", + "589c3c577dc3d323d5578480", + "589c300f7dc3d323d5577926", + "589c24527dc3d323d5577126", + "589c35457dc3d323d5577d8d", + "589ca6a6b896147a1b73aff7", + "589d1e1fb896147a1b73ee5b", + "589d5c58b896147a1b742256", + "589d95538fa2cf375df3317b", + "589df0ffb504a864ad63521a", + "589ea316b504a864ad639a2b", + "589ec97cb504a864ad63adc3", + "589f214338486e3c9846f123", + "589fdfe738486e3c984736cf", + "590c2d70336bb52a190be886", + "590f91851225725be9e25d4e", + "591a467a6109e14d4f09b776", + "591cf3033162411cf9047f37", + "591ea44850991c70dc99a207", + "599aa591d5b41f366fed0d58", + "5643df56138263b51db1b5f3", + "5644bdac138263b51db9f669", + "5850d4f97072670e72c425d6", + "5854c405804be105852330fe", + "5855a4fc804be1058523bd75", + "5856ac15804be105852419d8", + "5856ae8b804be10585241bae", + "5856b460804be10585242059", + "5857aa5ab338a62ad5ff4dbe", + "5857acf8b338a62ad5ff5107", + "5858db6cb338a62ad500103b", + "5858dbcab338a62ad5001081", + "5859d84fb338a62ad500e5cf", + "5861d8ea712e2761468f3cb3", + "5863edf8712e27614690cce0", + "5864b076712e27614691197e", + "5864da88712e276146913d8b", + "5865f4a8712e27614691e39b", + "5867a434833dfe3f7b88edaf", + "5868cd15833dfe3f7b89bfa3", + "5880b3692366dd5d06e5d534", + "5880e3422366dd5d06e5ff8e", + "5880f0ef2366dd5d06e6166e", + "5881d2bfb6844814c136a119", + "5881f11d8ce2c2754d0714c3", + "5881fee18ce2c2754d0723f8", + "5882cda2b116682b4adebd25", + "5882d58fb116682b4adec7db", + "5884c256932ba84fbed70bf5", + "5884cc13932ba84fbed71ec4", + "5885bc5296fa095e0671a7f0", + "5886d14cb791366d617a362c", + "5888becfc02346100f4b0b21", + "5888e408c02346100f4b1a29", + "5889da66ec4d5a1c088e5187", + "5889e344ec4d5a1c088e59be", + "5889e754ec4d5a1c088e60ba", + "5890c16b90414422fbeb0262", + "5891d8ae9a8c0314c5cd30ab", + "5891d0479a8c0314c5cd2abd", + "5891ecf19a8c0314c5cd490a", + "5892c0cd9a8c0314c5cdc977", + "5894ab309a8c0314c5cee57d", + "5895a6a89a8c0314c5cfca7c", + "5895b8c29a8c0314c5cfd051", + "5895d38f9a8c0314c5cfe50c", + "5895f2329a8c0314c5d00117", + "5896bb989a8c0314c5d086b6", + "5896ebf39a8c0314c5d0a8c4", + "5898b1bac9dccc22987b7f74", + 
"5898b6ffc9dccc22987b8a03", + "5898b31cc9dccc22987b82ec", + "5898bbaac9dccc22987b8eba", + "5899cfa6b76d7a3780a4cb64", + "5899e5dcb76d7a3780a4ecc1", + "5947b62af1b45630bd0c2a02", + "57102be2877e1421026358af", + "57153d4031bb9900425bde85", + "57177cd7fb8d93461afc4527", + "58497cdf97b73e0b090c4273", + "58500b007072670e72c35588", + "58510bf97072670e72c46ddf", + "58522bd56789802282f2ecb3", + "58524a2e0e7012308944bcf3", + "58524a080e7012308944bcbf", + "58524c1d0e7012308944bfda", + "58524f170e7012308944c200", + "58529a4e0e70123089454c6f", + "58551bdf804be1058523556d", + "58568c9a804be10585240b03", + "58574b35804be105852455fd", + "58577c60b338a62ad5ff1564", + "58592d69b338a62ad5007a74", + "58598db2b338a62ad500bc38", + "58625f42712e2761468fb44c", + "58651bcc712e2761469166dc", + "58660e79712e27614691fe3d", + "58669aad712e27614692834c", + "58669c02712e27614692851a", + "58676c36833dfe3f7b88b7f2", + "58678b2d833dfe3f7b88e244", + "58790c82ce911104a3467c88", + "58800b0b2366dd5d06e5312d", + "58805eac2366dd5d06e56460", + "58806e422366dd5d06e57bb6", + "58831d060db9bf59bf8ab98b", + "58851ebb932ba84fbed7abad", + "58871dc3b791366d617a55ff", + "58873cabb791366d617a65a7", + "58873d44b791366d617a65dd", + "58888b3dc02346100f4af665", + "58897f62c02346100f4b8ee6", + "58933bac9a8c0314c5ce3508", + "58938e6d9a8c0314c5ce726f", + "58951cb49a8c0314c5cf4d5e", + "58970fd09a8c0314c5d0e383", + "58977ef09a8c0314c5d17b26", + "59056e6760bb961de55f3501", + "59071f2e5a6dbd3af4130f98", + "59102c811225725be9e64149", + "59338e76772c3e6384afbb15", + "59350ca084b7f26bf5ce6eb8", + "59397e493a87372f2c9e882b", + "59521e0b9096412211c2aa9d", + "59817e4a1bd4b175e7038d19", + "567884f58d2828b95e3c8eba", + "585559d9804be10585238ddf", + "585834cdb338a62ad5ffab4d", + "586082d8712e2761468e2877", + "586133c2712e2761468ecfe3", + "586281d2712e2761468fcaa2", + "586316e5712e276146903c4d", + "586326ad712e276146904571", + "586375c9712e276146907429", + "586389c9712e276146908da6", + "586496fa712e2761469108e7", + "586669c6712e27614692597a", + "586913a49d1b5e34c2808b02", + "586922da9d1b5e34c2809ff3", + "588185d8dfb7a15588a114a3", + "588305ed0db9bf59bf8a8c80", + "588315c60db9bf59bf8aa928", + "588332ee0db9bf59bf8ae9c3", + "588457b8932ba84fbed69942", + "588519d5932ba84fbed7a04a", + "588824d1b791366d617adeef", + "588857f6c02346100f4ac09f", + "589145ef90414422fbeb2e08", + "589433fa9a8c0314c5ce9656", + "589765d39a8c0314c5d16b12", + "5851165f7072670e72c4860d", + "5859341ab338a62ad500848d", + "5862388b712e2761468f84aa", + "5863915b712e276146909135", + "5866445b712e27614692383e", + "5866500d712e2761469240fd", + "5867785a833dfe3f7b88c764", + "5867969c833dfe3f7b88e8bc", + "5868040c833dfe3f7b8934f7", + "5880675a2366dd5d06e570ca", + "5882372c8ce2c2754d076af0", + "5883535e932ba84fbed5ad07", + "5888358cb791366d617af69d", + "5890330d90414422fbeaa0cb", + "5897076e9a8c0314c5d0d31b", + "5940564ec2d9527ab869f7e2", + "5947719bf1b45630bd096665", + "5948194ff1b45630bd0f47e3", + "5950206a41b158666ac50506", + "5983012d1bd4b175e70c985a", + "58586810b338a62ad5ffc20c", + "58592046b338a62ad5006b33", + "58592854b338a62ad500750a", + "58596531b338a62ad500aace", + "58818685dfb7a15588a11626", + "58829563f42b1d3ee3ec835f", + "58894345c02346100f4b51ca", + "585289980e7012308945276a", + "585369770e7012308945c709", + "585373640e7012308945cab9", + "588230658ce2c2754d076728", + "589388059a8c0314c5ce718b", + "595979485ec6a95e86a58c8d", + "5841206219d291325678ca90", + "58563650804be1058523da55", + "58564084804be1058523e116", + "58636467712e27614690661f", + "58647495712e27614690f36d", + 
"58654563712e276146918643", + "58664251712e276146923738", + "588084032366dd5d06e59e82", + "588159582366dd5d06e66877", + "5890279190414422fbea9734", + "5890523090414422fbeab3f0", + "5890641690414422fbeabbe7", + "585203546789802282f2aaf5", + ] + + # Train set sequences after filtering + self.train_split_scenes = [ + "00000000000000000000000b", + "00000000000000000000000d", + "00000000000000000000000e", + "00000000000000000000000f", + "000000000000000000000001", + "00000000000000000000001b", + "00000000000000000000001d", + "000000000000000000000002", + "000000000000000000000003", + "000000000000000000000004", + "000000000000000000000005", + "5a2a95f032a1c655cfe3de62", + "5a2af22b32a1c655cfe46013", + "5a2ba6de32a1c655cfe51b79", + "5a3b9731e24cd76dad1a5f1b", + "5a3ca9cb270f0e3f14d0eddb", + "5a3cb4e4270f0e3f14d12f43", + "5a03e732454a8a7ec672776c", + "5a3f4aba5889373fbbc5d3b5", + "5a5a1e48d62c7a12d5d00e47", + "5a6b1c418d100c2f8fdc4411", + "5a6feeb54a7fbc3f874f9db7", + "5a7cb1d6fe5c0d6fb53e64fb", + "5a7d3db14989e929563eb153", + "5a8aa0fab18050187cbe060e", + "5a9e5df65baeef72b4a021cd", + "5a48ba95c7dab83a7d7b44ed", + "5a48c4e9c7dab83a7d7b5cc7", + "5a48d4b2c7dab83a7d7b9851", + "5a69c47d0d5d0a7f3b2e9752", + "5a77b46b318efe6c6736e68a", + "5a355c271b63f53d5970f362", + "5a533e8034d7582116e34209", + "5a562fc7425d0f5186314725", + "5a618c72784780334bc1972d", + "5a752d42acc41e2423f17674", + "5a969eea91dfc339a9a3ad2c", + "5a8315f624b8e938486e0bd8", + "5a57542f333d180827dfc132", + "5a0271884e62597cdee0d0eb", + "5a6400933d809f1d8200af15", + "5a6464143d809f1d8208c43c", + "5a563183425d0f5186314855", + "5aa0f9d7a9efce63548c69a1", + "5aa7db90bfdd572271e95246", + "5aa235f64a17b335eeaf9609", + "5aa515e613d42d091d29d300", + "5aa1196ea9efce63548ed649", + "5aaadd4cbc13235570d178a7", + "5ab6af12ac4291329b1072ab", + "5ab7e00aac4291329b15864d", + "5ab8b8e029f5351f7f2ccf59", + "5ab74bf2ac4291329b11e879", + "5ab85f1dac4291329b17cb50", + "5ab8713ba3799a1d138bd69a", + "5abc2506b53b042ead637d86", + "5acc7459a7853c4b5ebbef59", + "5acf8ca0f3d8a750097e4b15", + "5adc6bd52430a05ecb2ffb85", + "5af02e904c8216544b4ab5a2", + "5af28cea59bc705737003253", + "5af545d0559359053d25dcf5", + "5afacb69ab00705d0cefdd5b", + "5b3b2b9e8d46a939f933fdc0", + "5b3b353d8d46a939f93524b9", + "5b6e716d67b396324c2d77cb", + "5b6eff8b67b396324c5b2672", + "5b7a3890fc8fcf6781e2593a", + "5b60fa0c764f146feef84df0", + "5b69cc0cb44b61786eb959bf", + "5b78e57afc8fcf6781d0c3ba", + "5b192eb2170cf166458ff886", + "5b558a928bbfb62204e77ba2", + "5b908d3dc6ab78485f3d24a9", + "5b950c71608de421b1e7318f", + "5b08286b2775267d5b0634ba", + "5b271079e0878c3816dacca4", + "5b22269758e2823a67a3bd03", + "5b62647143840965efc0dbde", + "5ba19a8a360c7c30c1c169df", + "5ba75d79d76ffa2c86cf2f05", + "5bb7a08aea1cfa39f1a947ab", + "5bb8a49aea1cfa39f1aa7f75", + "5bbb6eb2ea1cfa39f1af7e0c", + "5bce7ac9ca24970bce4934b6", + "5bcf979a6d5f586b95c258cd", + "5bd43b4ba6b28b1ee86b92dd", + "5be3a5fb8cfdd56947f6b67c", + "5be3ae47f44e235bdbbc9771", + "5be4ab93870d330ff2dce134", + "5be47bf9b18881428d8fbc1d", + "5be883a4f98cee15019d5b83", + "5bea87f4abd34c35e1860ab5", + "5beb6e66abd34c35e18e66b9", + "5bf3a82cd439231948877aed", + "5bf7d63575c26f32dbf7413b", + "5bf17c0fd439231948355385", + "5bf21799d43923194842c001", + "5bfd0f32ec61ca1dd69dc77b", + "5bfe5ae0fe0ea555e6a969ca", + "5c0d13b795da9479e12e2ee9", + "5c1af2e2bee9a723c963d019", + "5c1b1500bee9a723c96c3e78", + "5c1dbf200843bc542d8ef8c4", + "5c20ca3a0843bc542d94e3e2", + "5c062d84a96e33018ff6f0a6", + "5c189f2326173c3a09ed7ef3", + "5c1892f726173c3a09ea9aeb", 
+ "5c34300a73a8df509add216d", + "000000000000000000000006", + "000000000000000000000007", + "000000000000000000000008", + "000000000000000000000009", + "000000000000000000000010", + "000000000000000000000011", + "000000000000000000000012", + "000000000000000000000015", + "000000000000000000000016", + "000000000000000000000017", + "000000000000000000000018", + "000000000000000000000019", + "56d73ba74bd29b8c35abade2", + "56f34064e296120e10484dc4", + "57a4a7bb6b9272286e26dc18", + "57f8d9bbe73f6760f10e916a", + "58a0a2f33d0b4542479a11b1", + "58a0dd1a3d0b4542479a28f3", + "58a1a7914a4d262a170b1101", + "58a1bc804a4d262a170b2f01", + "58a1d9d14a4d262a170b58fe", + "58a01dea38486e3c98475871", + "58a1f5d74a4d262a170b65fc", + "58a2a09e156b87103d3d668c", + "58a2d9c3156b87103d3da90f", + "58a3ccb0156b87103d3e4332", + "58a3f2f8156b87103d3e5838", + "58a3f6c0156b87103d3e5971", + "58a3fc95156b87103d3e5d9b", + "58a07ce53d0b45424799fdde", + "58a07f233d0b45424799ffe7", + "58a44df2156b87103d3ee239", + "58a164f73d0b4542479a7a8e", + "58a0365e38486e3c984783eb", + "58a439cf156b87103d3ec885", + "58a464aa156b87103d3eec04", + "58a4452f156b87103d3ed55b", + "58a160983d0b4542479a7347", + "58a285424a4d262a170baf3e", + "58a41819156b87103d3e92a5", + "58a47552156b87103d3f00a4", + "58c4bb4f4a69c55606122be4", + "58c6451e4a69c556061894f1", + "58ca7014affdfd07c70a95ce", + "58cf4771d0f5fb221defe6da", + "58d36897f387231e6c929903", + "58eaf1513353456af3a1682a", + "58f73e7c9f5b56478738929f", + "59a8f851597729752c31e7e0", + "59a452bf9b460239aa5d1c72", + "59a9619a825418241fb88191", + "59bf97fe7e7b31545da34439", + "59c1c3e2fd6e3d4ead9f1013", + "59d2657f82ca7774b1ec081d", + "59da1fb88a126011d0394ae9", + "59e75a2ca9e91f2c5526005d", + "59e864b2a9e91f2c5529325f", + "59ecfd02e225f6492d20fcc9", + "59f37f74b45be2233001ba18", + "59f70ab1e5c5d366af29bf3e", + "59f363a8b45be22330016cad", + "564a27b26d07883f460d8ab0", + "565fb1dead14d4154dae2b94", + "569b92eb826bcba945ca002b", + "576fefa017ce5a16397e87fd", + "584a7333fe3cb463906c9fe6", + "584aa8e9fe3cb463906cc7d0", + "584af003fe3cb463906d0e9b", + "584b9a747072670e72bfc49d", + "584b671f7072670e72bfaaf8", + "584b81747072670e72bfbbfd", + "584ba35f7072670e72bfca4d", + "584ba5977072670e72bfcc2d", + "584bc53c7072670e72bfe85f", + "584bc3997072670e72bfe58d", + "584bc4407072670e72bfe665", + "584bd5587072670e72bffe39", + "584bdadf7072670e72c0005c", + "584be5ed7072670e72c007b3", + "584c9ad27072670e72c060c5", + "584c9cc67072670e72c063a1", + "584cea557072670e72c07fb4", + "584d19d47072670e72c0c6c0", + "584dfe467072670e72c1665a", + "584e875c7072670e72c1ec94", + "584e05667072670e72c17167", + "584f94e87072670e72c2d3f7", + "584fdffd7072670e72c32dc7", + "584fe07f7072670e72c32e59", + "585a2a71b338a62ad50138dc", + "585a206ab338a62ad501298f", + "585a217cb338a62ad5012b38", + "585b34afb338a62ad501e836", + "585bb25fc49c8507c3ce7812", + "585bbe55c49c8507c3ce81cd", + "585d6c8a2a57cc11d4920a1e", + "585e54c72a57cc11d492f71a", + "585e34302a57cc11d492be30", + "585ee0632a57cc11d4933608", + "585f9661712e2761468dabca", + "585ffe9a712e2761468df643", + "586a37ec9d1b5e34c28184fc", + "586a515a9d1b5e34c281b431", + "586a94939d1b5e34c2823b5d", + "586abc689d1b5e34c2826360", + "586b0e219d1b5e34c2828862", + "586b3db89d1b5e34c282cd52", + "586b4c459d1b5e34c282e66d", + "586b7d7d9d1b5e34c283359e", + "586b8f149d1b5e34c283497c", + "586b8f629d1b5e34c28349d6", + "586c4c4d9d1b5e34c28391a1", + "586c5b5b9d1b5e34c2839a5b", + "586c9fdf9d1b5e34c283b657", + "586caab99d1b5e34c283c213", + "586cd0779d1b5e34c28403a7", + "586d6d249d1b5e34c284b80e", + 
"586d8a029d1b5e34c284c948", + "586d55af9d1b5e34c284a999", + "586d07869d1b5e34c2842e5b", + "586d27489d1b5e34c28453af", + "586e279c9d1b5e34c2852180", + "587bc5ec2366dd5d06e262c1", + "587c1abf2366dd5d06e28901", + "587c03f12366dd5d06e27722", + "587c19da2366dd5d06e2877b", + "587c31b92366dd5d06e2a9dc", + "587c87d02366dd5d06e2f989", + "587c97a52366dd5d06e30a96", + "587c45192366dd5d06e2c0eb", + "587cec702366dd5d06e37862", + "587cef0a2366dd5d06e379e3", + "587db5872366dd5d06e3e0af", + "587e2b1d2366dd5d06e41af0", + "587e2ea62366dd5d06e41f2e", + "587e5cb52366dd5d06e4486e", + "587eb1822366dd5d06e45f29", + "587f365d2366dd5d06e4906e", + "588a9c5fec4d5a1c088ec350", + "588a34cfec4d5a1c088ea8d1", + "588ab5bdec4d5a1c088ed60f", + "588aff9d90414422fbe7885a", + "588b20d290414422fbe79f40", + "588c08d590414422fbe8200b", + "588c203d90414422fbe8319e", + "588c989a90414422fbe86d96", + "588ca09d90414422fbe871a1", + "588cce2190414422fbe88520", + "588cd5ef90414422fbe8875c", + "588cf0ad90414422fbe8a20f", + "588e01c490414422fbe8ee2a", + "588e35e690414422fbe90a53", + "588f017e90414422fbe9b74b", + "588f095190414422fbe9c1ee", + "589aca717dc3d323d55671c4", + "589af2c97dc3d323d55691e8", + "589b49ea7dc3d323d556d9b4", + "589b04287dc3d323d556a185", + "589bf6a57dc3d323d55743ab", + "589c3c497dc3d323d5578468", + "589c3c577dc3d323d5578480", + "589c24527dc3d323d5577126", + "589c35457dc3d323d5577d8d", + "589ca6a6b896147a1b73aff7", + "589d1e1fb896147a1b73ee5b", + "589d5c58b896147a1b742256", + "589d95538fa2cf375df3317b", + "589df0ffb504a864ad63521a", + "589ea316b504a864ad639a2b", + "589ec97cb504a864ad63adc3", + "589f214338486e3c9846f123", + "589fdfe738486e3c984736cf", + "590c2d70336bb52a190be886", + "591a467a6109e14d4f09b776", + "591cf3033162411cf9047f37", + "591ea44850991c70dc99a207", + "599aa591d5b41f366fed0d58", + "5643df56138263b51db1b5f3", + "5644bdac138263b51db9f669", + "5850d4f97072670e72c425d6", + "5854c405804be105852330fe", + "5855a4fc804be1058523bd75", + "5856ac15804be105852419d8", + "5856ae8b804be10585241bae", + "5856b460804be10585242059", + "5857aa5ab338a62ad5ff4dbe", + "5857acf8b338a62ad5ff5107", + "5858db6cb338a62ad500103b", + "5858dbcab338a62ad5001081", + "5859d84fb338a62ad500e5cf", + "5861d8ea712e2761468f3cb3", + "5863edf8712e27614690cce0", + "5864b076712e27614691197e", + "5864da88712e276146913d8b", + "5865f4a8712e27614691e39b", + "5867a434833dfe3f7b88edaf", + "5868cd15833dfe3f7b89bfa3", + "5880b3692366dd5d06e5d534", + "5880e3422366dd5d06e5ff8e", + "5880f0ef2366dd5d06e6166e", + "5881d2bfb6844814c136a119", + "5881f11d8ce2c2754d0714c3", + "5881fee18ce2c2754d0723f8", + "5882cda2b116682b4adebd25", + "5882d58fb116682b4adec7db", + "5884c256932ba84fbed70bf5", + "5884cc13932ba84fbed71ec4", + "5885bc5296fa095e0671a7f0", + "5886d14cb791366d617a362c", + "5888becfc02346100f4b0b21", + "5888e408c02346100f4b1a29", + "5889da66ec4d5a1c088e5187", + "5889e754ec4d5a1c088e60ba", + "5890c16b90414422fbeb0262", + "5891d8ae9a8c0314c5cd30ab", + "5891d0479a8c0314c5cd2abd", + "5891ecf19a8c0314c5cd490a", + "5892c0cd9a8c0314c5cdc977", + "5894ab309a8c0314c5cee57d", + "5895a6a89a8c0314c5cfca7c", + "5895b8c29a8c0314c5cfd051", + "5895d38f9a8c0314c5cfe50c", + "5895f2329a8c0314c5d00117", + "5896bb989a8c0314c5d086b6", + "5896ebf39a8c0314c5d0a8c4", + "5898b1bac9dccc22987b7f74", + "5898b6ffc9dccc22987b8a03", + "5898bbaac9dccc22987b8eba", + "5899cfa6b76d7a3780a4cb64", + "5899e5dcb76d7a3780a4ecc1", + "57102be2877e1421026358af", + "57153d4031bb9900425bde85", + "57177cd7fb8d93461afc4527", + "58497cdf97b73e0b090c4273", + "58500b007072670e72c35588", + 
"58510bf97072670e72c46ddf", + "58522bd56789802282f2ecb3", + "58524a2e0e7012308944bcf3", + "58524a080e7012308944bcbf", + "58524c1d0e7012308944bfda", + "58524f170e7012308944c200", + "58529a4e0e70123089454c6f", + "58551bdf804be1058523556d", + "58568c9a804be10585240b03", + "58574b35804be105852455fd", + "58577c60b338a62ad5ff1564", + "58592d69b338a62ad5007a74", + "58625f42712e2761468fb44c", + "58651bcc712e2761469166dc", + "58660e79712e27614691fe3d", + "58669aad712e27614692834c", + "58676c36833dfe3f7b88b7f2", + "58678b2d833dfe3f7b88e244", + "58800b0b2366dd5d06e5312d", + "58805eac2366dd5d06e56460", + "58806e422366dd5d06e57bb6", + "58831d060db9bf59bf8ab98b", + "58851ebb932ba84fbed7abad", + "58871dc3b791366d617a55ff", + "58873cabb791366d617a65a7", + "58873d44b791366d617a65dd", + "58888b3dc02346100f4af665", + "58933bac9a8c0314c5ce3508", + "58938e6d9a8c0314c5ce726f", + "58951cb49a8c0314c5cf4d5e", + "58970fd09a8c0314c5d0e383", + "58977ef09a8c0314c5d17b26", + "59056e6760bb961de55f3501", + "59071f2e5a6dbd3af4130f98", + "59102c811225725be9e64149", + "59338e76772c3e6384afbb15", + "59350ca084b7f26bf5ce6eb8", + "59397e493a87372f2c9e882b", + "59521e0b9096412211c2aa9d", + "59817e4a1bd4b175e7038d19", + "567884f58d2828b95e3c8eba", + "585559d9804be10585238ddf", + "585834cdb338a62ad5ffab4d", + "586082d8712e2761468e2877", + "586133c2712e2761468ecfe3", + "586281d2712e2761468fcaa2", + "586316e5712e276146903c4d", + "586326ad712e276146904571", + "586375c9712e276146907429", + "586389c9712e276146908da6", + "586496fa712e2761469108e7", + "586669c6712e27614692597a", + "586913a49d1b5e34c2808b02", + "586922da9d1b5e34c2809ff3", + "588185d8dfb7a15588a114a3", + "588315c60db9bf59bf8aa928", + "588332ee0db9bf59bf8ae9c3", + "588519d5932ba84fbed7a04a", + "588824d1b791366d617adeef", + "588857f6c02346100f4ac09f", + "589145ef90414422fbeb2e08", + "589433fa9a8c0314c5ce9656", + "589765d39a8c0314c5d16b12", + "5851165f7072670e72c4860d", + "5859341ab338a62ad500848d", + "5863915b712e276146909135", + "5866445b712e27614692383e", + "5866500d712e2761469240fd", + "5867785a833dfe3f7b88c764", + "5867969c833dfe3f7b88e8bc", + "5868040c833dfe3f7b8934f7", + "5882372c8ce2c2754d076af0", + "5883535e932ba84fbed5ad07", + "5888358cb791366d617af69d", + "5890330d90414422fbeaa0cb", + "5897076e9a8c0314c5d0d31b", + "5940564ec2d9527ab869f7e2", + "5947719bf1b45630bd096665", + "5948194ff1b45630bd0f47e3", + "5950206a41b158666ac50506", + "5983012d1bd4b175e70c985a", + "58586810b338a62ad5ffc20c", + "58592046b338a62ad5006b33", + "58592854b338a62ad500750a", + "58596531b338a62ad500aace", + "58818685dfb7a15588a11626", + "58829563f42b1d3ee3ec835f", + "58894345c02346100f4b51ca", + "585289980e7012308945276a", + "585369770e7012308945c709", + "585373640e7012308945cab9", + "588230658ce2c2754d076728", + "589388059a8c0314c5ce718b", + "595979485ec6a95e86a58c8d", + "5841206219d291325678ca90", + "58563650804be1058523da55", + "58564084804be1058523e116", + "58636467712e27614690661f", + "58647495712e27614690f36d", + "58654563712e276146918643", + "58664251712e276146923738", + "588084032366dd5d06e59e82", + "588159582366dd5d06e66877", + "5890279190414422fbea9734", + "5890641690414422fbeabbe7", + "585203546789802282f2aaf5", + ] + + # Validation set sequences after filtering + self.val_split_scenes = [ + "00000000000000000000000a", + "5a4a38dad38c8a075495b5d2", + "5a489fb1c7dab83a7d7b1070", + "5a572fd9fc597b0478a81d14", + "5a588a8193ac3d233f77fbca", + "5aa0f478a9efce63548c1cb4", + "5ae2e9c5fe405c5076abc6b2", + "5b2c67b5e0878c381608b8d8", + "5b21e18c58e2823a67a10dd8", + 
"5b864d850d072a699b32f4ae", + "5b4933abf2b5f44e95de482a", + "5b37189a35304b6f75e7583e", + "5bc5f0e896b66a2cd8f9bd36", + "5bccd6beca24970bce448134", + "5bf26cbbd43923194854b270", + "5bf18642c50e6f7f8bdbd492", + "5bfc9d5aec61ca1dd69132a2", + "5bff3c5cfe0ea555e6bcbf3a", + "5c1f33f1d33e1f2e4aa6dda4", + "5c34529873a8df509ae57b58", + "58a186444a4d262a170ae3ae", + "58f7f7299f5b5647873cb110", + "59acd2f4b891807f439c8992", + "567a0fb0a825d2fb79ac9a20", + "584ad76bfe3cb463906ce6dc", + "584c58b77072670e72c03990", + "586c48329d1b5e34c2838e80", + "586df9849d1b5e34c28506de", + "588e0d8c90414422fbe8f8b2", + "589c300f7dc3d323d5577926", + "590f91851225725be9e25d4e", + "5889e344ec4d5a1c088e59be", + "5898b31cc9dccc22987b82ec", + "5947b62af1b45630bd0c2a02", + "58598db2b338a62ad500bc38", + "58669c02712e27614692851a", + "58790c82ce911104a3467c88", + "58897f62c02346100f4b8ee6", + "588305ed0db9bf59bf8a8c80", + "588457b8932ba84fbed69942", + "5862388b712e2761468f84aa", + "5880675a2366dd5d06e570ca", + "5890523090414422fbeab3f0", + ] + + +class TartanAirV2Splits: + """ + This class contains the information about the splits of the TartanAir V2 dataset. + """ + + def __init__(self): + """ + Splits of environments with unique geometry selected based on TartanVO & UFM splits. + """ + # Apart from the below 2 splits, all other TAv2 scenes are in the train split + # Val split + self.val_split_scenes = ["EndofTheWorld", "HongKong", "WesternDesertTown"] + + # Test split + self.test_split_scenes = [ + "DesertGasStation", + "OldScandinavia", + "PolarSciFi", + "Sewerage", + "Supermarket", + ] + + +class MegaDepthSplits: + """ + This class contains the information about the splits of the MegaDepth dataset. + """ + + def __init__(self): + """ + Validation split is based on scenes used in DUSt3R. + """ + self.val_split_scenes = ["0015_0", "0015_1", "0022_0"] + + +class SpringSplits: + """ + This class contains the information about the splits of the Spring dataset. + """ + + def __init__(self): + self.val_split_scenes = ["0013", "0023", "0037"] + + +class MPSDSplits: + """ + This class contains the information about the splits of the MPSD dataset. + """ + + def __init__(self): + """ + Train & Validation split numpy files containing folder names are generated during preprocessing of MPSD dataset. + Load the numpy files to get the list of scenes in the train & validation split. + A 95% (Train) & 5% (Validation) split is used. + """ + self.train_split_scenes = "load_numpy_file_with_train_scenes" + self.val_split_scenes = "load_numpy_file_with_val_scenes" + + +class ScanNetPPSplits: + """ + This class contains the information about the splits of the ScanNetPP dataset. + """ + + def __init__(self): + """ + Validation & Test split only contains scenes from ScanNet++V2 to prevent data leak with other methods such as DUSt3R during benchmarking. 
+ + Following logic was used to generate the splits: + # Select 80%, 10%, 10% of the scenes for train, val, test respectively from ScanNet++ V2 (~300 scene subset; excluding V1 scenes) + snpp_v2_test_scenes = np.random.choice( + snpp_v2_processed_scenes, size=int(0.1 * len(snpp_v2_processed_scenes)), replace=False + ) + remaining_scenes = [scene for scene in snpp_v2_processed_scenes if scene not in snpp_v2_test_scenes] + snpp_v2_val_scenes = np.random.choice( + remaining_scenes, size=int(0.1 * len(snpp_v2_processed_scenes)), replace=False + ) + snpp_v2_train_scenes = [ + scene for scene in remaining_scenes if scene not in snpp_v2_val_scenes and scene not in snpp_v2_test_scenes + ] + """ + # Validation Scenes + self.val_split_scenes = [ + "1c7a683c92", + "2a1b555966", + "3a43c7b8d2", + "4aef651da7", + "06bc6d1b24", + "7f22d5ef1b", + "7f77abce34", + "8ea517a2fc", + "29c7afafed", + "41eb967018", + "77b40ce601", + "086f09d6e3", + "307e3262f1", + "639f2c4d5a", + "894dbd41f1", + "898a7dfd0c", + "2779f8f9e2", + "151178afd7", + "182932a4f3", + "635852d56e", + "9906136b57", + "af112b8903", + "b0f057c684", + "b37177e6c8", + "b119249da7", + "be8367fcbe", + "c8fc01c453", + "e1fb8626c8", + "e2caaaf5b5", + "fe3fc057a1", + ] + + # Test Scenes + self.test_split_scenes = [ + "0e900bcc5c", + "0eba3981c9", + "1cbb105c6a", + "3c8d535d49", + "5d902f1593", + "6bd39ac392", + "6c14d5fd01", + "7c31a42404", + "9bfbc75700", + "13b4efaf62", + "062e5a23a6", + "95b9971d01", + "246fe09e98", + "637a27d04b", + "725b8f0cba", + "413085a827", + "696317583f", + "a4c043ac48", + "a9e4791c7e", + "b0b004c40f", + "c3bc5e82c5", + "c31ebd4b22", + "cba701332a", + "cc5ea8026c", + "cec8312f4e", + "e3b3b0d0c7", + "e667e09fe6", + "eaa6c90310", + "f9397af4cb", + "fb893ffaf3", + ] + + +class DL3DV10KSplits: + """ + This class contains the information about the splits of the DL3DV-10K dataset. + We use the official benchmark split as the val split. + """ + + def __init__(self): + """ + Validation split is based on DL3DV-Benchmark. + """ + self.val_split_scenes = [ + "load https://huggingface.co/datasets/DL3DV/DL3DV-Benchmark/raw/main/benchmark-meta.csv \ + & https://raw.githubusercontent.com/DL3DV-10K/Dataset/main/cache/DL3DV-valid.csv" + ] + + +class ETH3DSplits: + """ + This class contains the information about the splits of the ETH3D dataset. + """ + + def __init__(self): + """ + All scenes are in the test split. + """ + self.test_split_scenes = "all" diff --git a/mapanything/datasets/wai/__init__.py b/mapanything/datasets/wai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/datasets/wai/ase.py b/mapanything/datasets/wai/ase.py new file mode 100644 index 0000000000000000000000000000000000000000..b55249783129fab35ceda619e9b55639eec18b87 --- /dev/null +++ b/mapanything/datasets/wai/ase.py @@ -0,0 +1,294 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +ASE Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class ASEWAI(BaseDataset): + """ + ASE dataset containing large diversity of synthetic indoor scenes. 
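+
+    Per-scene layout assumed by the loading code below (inferred from _get_views,
+    relative to ROOT/<scene_name>; not an exhaustive specification):
+        scene_meta.json          - scene metadata, including frame_names used to enumerate views
+        covisibility/v0/*.npy    - pairwise covisibility map loaded as a memory-mapped array
+        per-frame image and depth modalities loaded via load_frame(...)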
+ """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. + """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"ase_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Resize the data to match 
the desired resolution + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=None, + ) + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="ASE", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/ase", type=str) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = ASEWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 518), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = ASEWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 518), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "ASE_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + 
camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/blendedmvs.py b/mapanything/datasets/wai/blendedmvs.py new file mode 100644 index 0000000000000000000000000000000000000000..9f133652fe194bb654fc44dd62633d063cd0092e --- /dev/null +++ b/mapanything/datasets/wai/blendedmvs.py @@ -0,0 +1,313 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +BlendedMVS Dataset using WAI format data. +""" + +import os + +import cv2 +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class BlendedMVSWAI(BaseDataset): + """ + BlendedMVS dataset containing object-centric and birds-eye-view scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
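+
+        Example (illustrative sketch mirroring the training config in the
+        __main__ block of this file; the paths below are placeholders):
+            dataset = BlendedMVSWAI(
+                num_views=2,
+                split="train",
+                covisibility_thres=0.25,
+                ROOT="/path/to/blendedmvs",
+                dataset_metadata_dir="/path/to/mapanything_dataset_metadata",
+                resolution=(518, 392),
+                aug_crop=16,
+                transform="colorjitter+grayscale+gaublur",
+                data_norm_type="dinov2",
+            )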
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = False + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"blendedmvs_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth", "pred_mask/moge2"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, 
+ depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="BlendedMVS", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/blendedmvs", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = BlendedMVSWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 392), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = BlendedMVSWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 392), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "BlendedMVS_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=10, replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + 
height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/dl3dv.py b/mapanything/datasets/wai/dl3dv.py new file mode 100644 index 0000000000000000000000000000000000000000..166a79a4b6df0e9b60b25e92323dd537e37c3297 --- /dev/null +++ b/mapanything/datasets/wai/dl3dv.py @@ -0,0 +1,356 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +DL3DV Dataset using WAI format data. +""" + +import os + +import cv2 +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.cropping import ( + rescale_image_and_other_optional_info, + resize_with_nearest_interpolation_to_match_aspect_ratio, +) +from mapanything.utils.wai.core import load_data, load_frame + + +class DL3DVWAI(BaseDataset): + """ + DL3DV dataset containing over 10k in-the-wild and indoor scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + mvs_confidence_filter_thres: float = 0.25, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. + mvs_confidence_filter_thres: Confidence threshold to filter MVS depth. Defaults to 0.25. 
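+
+        Example (illustrative sketch mirroring the training config in the
+        __main__ block of this file; the paths below are placeholders):
+            dataset = DL3DVWAI(
+                num_views=2,
+                split="train",
+                covisibility_thres=0.25,
+                ROOT="/path/to/dl3dv",
+                dataset_metadata_dir="/path/to/mapanything_dataset_metadata",
+                mvs_confidence_filter_thres=0.25,
+                resolution=(518, 294),
+                aug_crop=16,
+                transform="colorjitter+grayscale+gaublur",
+                data_norm_type="dinov2",
+            )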
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self.mvs_confidence_filter_thres = mvs_confidence_filter_thres + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = False + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"dl3dv_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0_mvsa_based" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=[ + "image", + "pred_depth/mvsanywhere", + "pred_mask/moge2", + "depth_confidence/mvsanywhere", + ], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["pred_depth/mvsanywhere"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the dimensions of the original image + img_h, img_w = image.shape[:2] + + # Resize depth to match image aspect ratio while ensuring that depth resolution doesn't increase + depthmap, target_depth_h, target_depth_w = ( + resize_with_nearest_interpolation_to_match_aspect_ratio( + input_data=depthmap, img_h=img_h, img_w=img_w + ) + ) + + # Now resize the image and update intrinsics to match the resized depth + image, _, intrinsics, _ = rescale_image_and_other_optional_info( + 
image=image, + output_resolution=(target_depth_w, target_depth_h), + depthmap=None, + camera_intrinsics=intrinsics, + ) + image = np.array(image) + + # Get the depth confidence map and mask out the MVS depth + confidence_map = view_data["depth_confidence/mvsanywhere"].numpy() + confidence_mask = ( + confidence_map > self.mvs_confidence_filter_thres + ).astype(int) + confidence_mask = cv2.resize( + confidence_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + depthmap = np.where(confidence_mask, depthmap, 0) + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="DL3DV", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/dl3dv", type=str) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = DL3DVWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + mvs_confidence_filter_thres=0.25, + resolution=(518, 294), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = DL3DVWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # mvs_confidence_filter_thres=0.25, + # resolution=(518, 294), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "DL3DV_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert 
len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/dynamicreplica.py b/mapanything/datasets/wai/dynamicreplica.py new file mode 100644 index 0000000000000000000000000000000000000000..148ff2315f63daaa63f10bb6e5a5d501f5159df2 --- /dev/null +++ b/mapanything/datasets/wai/dynamicreplica.py @@ -0,0 +1,297 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Dynamic Replica Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class DynamicReplicaWAI(BaseDataset): + """ + Dynamic Replica dataset containing synthetic scenes with humans and animals. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
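+
+        Example (illustrative sketch mirroring the training config in the
+        __main__ block of this file; the paths below are placeholders):
+            dataset = DynamicReplicaWAI(
+                num_views=2,
+                split="train",
+                covisibility_thres=0.25,
+                ROOT="/path/to/dynamicreplica",
+                dataset_metadata_dir="/path/to/mapanything_dataset_metadata",
+                resolution=(518, 294),
+                aug_crop=16,
+                transform="colorjitter+grayscale+gaublur",
+                data_norm_type="dinov2",
+            )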
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"dynamicreplica_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = image[:, :, :3] # RGBA to RGB + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Resize the data to match the desired resolution + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=None, + ) + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="DynamicReplica", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = 
argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/dynamicreplica", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = DynamicReplicaWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 294), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = DynamicReplicaWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 294), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "DynamicReplica_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + 
filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/eth3d.py b/mapanything/datasets/wai/eth3d.py new file mode 100644 index 0000000000000000000000000000000000000000..923b98ab6472aa02833d14496edf2ee24146f442 --- /dev/null +++ b/mapanything/datasets/wai/eth3d.py @@ -0,0 +1,277 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +ETH3D Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class ETH3DWAI(BaseDataset): + """ + ETH3D dataset containing high-quality outdoor and indoor scans of the ETH Zurich campus. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. + """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = "test" + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"eth3d_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N 
views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Resize the data to match the desired resolution + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=None, + ) + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="ETH3D", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/eth3d", type=str) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = ETH3DWAI( + num_views=args.num_of_views, + covisibility_thres=0.025, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 336), + seed=777, + transform="imgnorm", + data_norm_type="dinov2", + ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "ETH3D_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + 
else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/megadepth.py b/mapanything/datasets/wai/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcb9a8f98b0439ec0a6aec311e6cf2cc251acb0 --- /dev/null +++ b/mapanything/datasets/wai/megadepth.py @@ -0,0 +1,314 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +MegaDepth Dataset using WAI format data. +""" + +import os + +import cv2 +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class MegaDepthWAI(BaseDataset): + """ + MegaDepth dataset containing outdoor phototourism and in-the-wild scenes. + Also includes Tanks & Temples scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
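+
+        Example (illustrative sketch mirroring the training config in the
+        __main__ block of this file; the paths below are placeholders):
+            dataset = MegaDepthWAI(
+                num_views=2,
+                split="train",
+                covisibility_thres=0.25,
+                ROOT="/path/to/megadepth",
+                dataset_metadata_dir="/path/to/mapanything_dataset_metadata",
+                resolution=(518, 336),
+                aug_crop=16,
+                transform="colorjitter+grayscale+gaublur",
+                data_norm_type="dinov2",
+            )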
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = False + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"megadepth_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth", "pred_mask/moge2"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, 
+ depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="MegaDepth", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/megadepth", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = MegaDepthWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 336), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = MegaDepthWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 336), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "MegaDepth_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + 
height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/mpsd.py b/mapanything/datasets/wai/mpsd.py new file mode 100644 index 0000000000000000000000000000000000000000..ee6f48c87853684fe7ccabc6b8747fbcdd40cc18 --- /dev/null +++ b/mapanything/datasets/wai/mpsd.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +MPSD Dataset using WAI format data. +""" + +import os + +import cv2 +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class MPSDWAI(BaseDataset): + """ + MPSD dataset containing outdoor planet scale metric reconstructions. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
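+
+        Example (illustrative sketch mirroring the training config in the
+        __main__ block of this file; the paths below are placeholders):
+            dataset = MPSDWAI(
+                num_views=2,
+                split="train",
+                covisibility_thres=0.15,
+                ROOT="/path/to/mpsd",
+                dataset_metadata_dir="/path/to/mapanything_dataset_metadata",
+                resolution=(518, 392),
+                aug_crop=16,
+                transform="colorjitter+grayscale+gaublur",
+                data_norm_type="dinov2",
+            )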
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"mpsd_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth", "pred_mask/moge2"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + 
depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="MPSD", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/mpsd", type=str) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = MPSDWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.15, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 392), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = MPSDWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.15, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 392), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "MPSD_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + 
camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/mvs_synth.py b/mapanything/datasets/wai/mvs_synth.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1006e9141541741f4d0f7f7c7f8f54328e33c8 --- /dev/null +++ b/mapanything/datasets/wai/mvs_synth.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +MVS Synth Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class MVSSynthWAI(BaseDataset): + """ + MVS Synth dataset containing large diversity of synthetic in-the-wild scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"mvs_synth_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non ambiguous mask (zero depth pixels are sky or ambiguous) + non_ambiguous_mask = (depthmap > 0).astype(int) + + # Mask out the outlier depth (horizon depth) + percentile_depth = np.percentile(depthmap, 95) + depthmap[depthmap > percentile_depth] = 0 + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = 
additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="MVSSynth", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/mvs_synth", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = MVSSynthWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 294), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = MVSSynthWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 294), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "MVSSynth_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + 
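+                # Log the GT depth under the same pinhole entity so it stays pixel-aligned with the RGB image above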
rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/paralleldomain4d.py b/mapanything/datasets/wai/paralleldomain4d.py new file mode 100644 index 0000000000000000000000000000000000000000..f7900ab702508e63d7e89585bdf4b4b179411a57 --- /dev/null +++ b/mapanything/datasets/wai/paralleldomain4d.py @@ -0,0 +1,309 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Parallel Domain 4D Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class ParallelDomain4DWAI(BaseDataset): + """ + Parallel Domain 4D dataset containing large diversity of synthetic AV scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"paralleldomain4d_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = image[:, :, :3] # RGBA to RGB + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non ambiguous mask (zero depth pixels are sky or ambiguous) + non_ambiguous_mask = (depthmap > 0).astype(int) + + # Mask out the outlier depth (horizon depth) + percentile_depth = np.percentile(depthmap, 95) + depthmap[depthmap > percentile_depth] = 0 + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + 
additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="ParallelDomain4D", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/paralleldomain4d", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = ParallelDomain4DWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 392), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = ParallelDomain4DWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 392), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "ParallelDomain4D_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + 
width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/sailvos3d.py b/mapanything/datasets/wai/sailvos3d.py new file mode 100644 index 0000000000000000000000000000000000000000..e59a2068ee32a44fbdbc9297afebb354f02c2c52 --- /dev/null +++ b/mapanything/datasets/wai/sailvos3d.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +SAIL-VOS 3D Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class SAILVOS3DWAI(BaseDataset): + """ + SAIL-VOS 3D dataset containing large diversity of synthetic in-the-wild cut scenes from GTA. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"sailvos3d_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non ambiguous mask (zero depth pixels are sky or ambiguous) + non_ambiguous_mask = (depthmap > 0).astype(int) + + # Mask out the outlier depth (horizon depth) + percentile_depth = np.percentile(depthmap, 95) + depthmap[depthmap > percentile_depth] = 0 + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = 
additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="SAILVOS3D", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/sailvos3d", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = SAILVOS3DWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 336), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = SAILVOS3DWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 336), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "SAILVOS3D_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + 
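+                # Depth is logged under the pinhole entity; the viewer can then back-project it using the logged intrinsics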
rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/scannetpp.py b/mapanything/datasets/wai/scannetpp.py new file mode 100644 index 0000000000000000000000000000000000000000..1367b461a8456b2d3a73cf2fd4749771e839f5d4 --- /dev/null +++ b/mapanything/datasets/wai/scannetpp.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +ScanNet++V2 Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class ScanNetPPWAI(BaseDataset): + """ + ScanNet++V2 dataset containing large diversity of indoor scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"scannetppv2_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "rendered_depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["rendered_depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Resize the data to match the desired resolution + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=None, + ) + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="ScanNetPP", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + 
parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/scannetppv2", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = ScanNetPPWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 336), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = ScanNetPPWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 336), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + # dataset = ScanNetPPWAI( + # num_views=args.num_of_views, + # split="test", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 336), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "ScanNetPP_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=10, replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + 
rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/spring.py b/mapanything/datasets/wai/spring.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b67f2a611c6f293f9619d989ce7ef3f7806649 --- /dev/null +++ b/mapanything/datasets/wai/spring.py @@ -0,0 +1,316 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Spring Dataset using WAI format data. +""" + +import os + +import cv2 +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class SpringWAI(BaseDataset): + """ + Spring dataset containing high-quality large-scale in-the-wild scenes with unique animated objects. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"spring_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) # Assumes only npy file in directory is covisibility map + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth", "skymask", "pred_mask/moge2"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Get the sky mask and mask out GT depth + sky_mask = view_data["skymask"].numpy().astype(int) + depthmap = np.where(sky_mask, 0, depthmap) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + additional_quantities_to_resize = 
[non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="Spring", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/spring", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = SpringWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 294), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = SpringWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 294), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "Spring_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=10, replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + 
translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/tav2_wb.py b/mapanything/datasets/wai/tav2_wb.py new file mode 100644 index 0000000000000000000000000000000000000000..9446139fb9188544f4dcd3480e43a2c8c8906c94 --- /dev/null +++ b/mapanything/datasets/wai/tav2_wb.py @@ -0,0 +1,328 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +TartanAirV2-WB Dataset using WAI format data. +""" + +import os + +import cv2 +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class TartanAirV2WBWAI(BaseDataset): + """ + TartanAirV2-WB dataset containing vastly-sized in-the-wild synthetic scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"tav2_wb_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth", "pred_mask/moge2"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Mask out the outlier depth caused due to transparent windows in TartanAirV2 + percentile_depth = np.percentile(depthmap, 95) + depthmap[depthmap > percentile_depth] = 0 + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + 
additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="TartanAirV2WB", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/tav2_wb", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = TartanAirV2WBWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 518), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = TartanAirV2WBWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 518), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + # dataset = TartanAirV2WBWAI( + # num_views=args.num_of_views, + # split="test", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 518), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "TartanAirV2WB_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = 
views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/datasets/wai/unrealstereo4k.py b/mapanything/datasets/wai/unrealstereo4k.py new file mode 100644 index 0000000000000000000000000000000000000000..f16f9b226d15ebc4273dc824e0c5885c636ea365 --- /dev/null +++ b/mapanything/datasets/wai/unrealstereo4k.py @@ -0,0 +1,309 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +UnrealStereo4K Dataset using WAI format data. +""" + +import os + +import numpy as np + +from mapanything.datasets.base.base_dataset import BaseDataset +from mapanything.utils.wai.core import load_data, load_frame + + +class UnrealStereo4KWAI(BaseDataset): + """ + UnrealStereo4K dataset containing synthetic in-the-wild scenes. + """ + + def __init__( + self, + *args, + ROOT, + dataset_metadata_dir, + split, + overfit_num_sets=None, + sample_specific_scene: bool = False, + specific_scene_name: str = None, + **kwargs, + ): + """ + Initialize the dataset attributes. + Args: + ROOT: Root directory of the dataset. + dataset_metadata_dir: Path to the dataset metadata directory. + split: Dataset split (train, val, test). + overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. + sample_specific_scene: Whether to sample a specific scene from the dataset. + specific_scene_name: Name of the specific scene to sample. 
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = True + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"unrealstereo4k_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisibility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = image[:, :, :3] # RGBA to RGB + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non ambiguous mask (zero depth pixels are sky or ambiguous) + non_ambiguous_mask = (depthmap > 0).astype(int) + + # Mask out the outlier depth (horizon depth) + percentile_depth = np.percentile(depthmap, 95) + depthmap[depthmap > percentile_depth] = 0 + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + 
additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="UnrealStereo4K", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/unrealstereo4k", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = UnrealStereo4KWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 294), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = UnrealStereo4KWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 294), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "UnrealStereo4K_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + 
camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/mapanything/models/__init__.py b/mapanything/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..498d19c825297a091df9e067cae63c2de68587ef --- /dev/null +++ b/mapanything/models/__init__.py @@ -0,0 +1,190 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Model Factory for MapAnything +""" + +import importlib.util +import logging +import warnings + +import numpy as np +from omegaconf import DictConfig, OmegaConf + +# Core models that are always available +from mapanything.models.mapanything import ( + MapAnything, + MapAnythingAblations, + ModularDUSt3R, +) + +# Suppress DINOv2 warnings +logging.getLogger("dinov2").setLevel(logging.WARNING) +warnings.filterwarnings("ignore", message="xFormers is available", category=UserWarning) +warnings.filterwarnings( + "ignore", message="xFormers is not available", category=UserWarning +) + + +def resolve_special_float(value): + if value == "inf": + return np.inf + elif value == "-inf": + return -np.inf + else: + raise ValueError(f"Unknown special float value: {value}") + + +def init_model( + model_str: str, model_config: DictConfig, torch_hub_force_reload: bool = False +): + """ + Initialize a model using OmegaConf configuration. + + Args: + model_str (str): Name of the model class to create. + model_config (DictConfig): OmegaConf model configuration. + torch_hub_force_reload (bool): Whether to force reload relevant parts of the model from torch hub. 
+ """ + if not OmegaConf.has_resolver("special_float"): + OmegaConf.register_new_resolver("special_float", resolve_special_float) + model_dict = OmegaConf.to_container(model_config, resolve=True) + model = model_factory( + model_str, torch_hub_force_reload=torch_hub_force_reload, **model_dict + ) + + return model + + +# Define model configurations with import paths +MODEL_CONFIGS = { + # Core models + "mapanything": { + "class": MapAnything, + }, + "mapanything_ablations": { + "class": MapAnythingAblations, + }, + "modular_dust3r": { + "class": ModularDUSt3R, + }, + # External models + "anycalib": { + "module": "mapanything.models.external.anycalib", + "class_name": "AnyCalibWrapper", + }, + "dust3r": { + "module": "mapanything.models.external.dust3r", + "class_name": "DUSt3RBAWrapper", + }, + "mast3r": { + "module": "mapanything.models.external.mast3r", + "class_name": "MASt3RSGAWrapper", + }, + "moge": { + "module": "mapanything.models.external.moge", + "class_name": "MoGeWrapper", + }, + "must3r": { + "module": "mapanything.models.external.must3r", + "class_name": "MUSt3RWrapper", + }, + "pi3": { + "module": "mapanything.models.external.pi3", + "class_name": "Pi3Wrapper", + }, + "pow3r": { + "module": "mapanything.models.external.pow3r", + "class_name": "Pow3RWrapper", + }, + "pow3r_ba": { + "module": "mapanything.models.external.pow3r", + "class_name": "Pow3RBAWrapper", + }, + "vggt": { + "module": "mapanything.models.external.vggt", + "class_name": "VGGTWrapper", + }, + # Add other model classes here +} + + +def check_module_exists(module_path): + """ + Check if a module can be imported without actually importing it. + + Args: + module_path (str): The path to the module to check. + + Returns: + bool: True if the module can be imported, False otherwise. + """ + return importlib.util.find_spec(module_path) is not None + + +def model_factory(model_str: str, **kwargs): + """ + Model factory for MapAnything. + + Args: + model_str (str): Name of the model to create. + **kwargs: Additional keyword arguments to pass to the model constructor. + + Returns: + nn.Module: An instance of the specified model. + """ + if model_str not in MODEL_CONFIGS: + raise ValueError( + f"Unknown model: {model_str}. Valid options are: {', '.join(MODEL_CONFIGS.keys())}" + ) + + model_config = MODEL_CONFIGS[model_str] + + # Handle core models directly + if "class" in model_config: + model_class = model_config["class"] + # Handle external models with dynamic imports + elif "module" in model_config: + module_path = model_config["module"] + class_name = model_config["class_name"] + + # Check if the module can be imported + if not check_module_exists(module_path): + raise ImportError( + f"Model '{model_str}' requires module '{module_path}' which is not installed. " + f"Please install the corresponding submodule or package." + ) + + # Dynamically import the module and get the class + try: + module = importlib.import_module(module_path) + model_class = getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ImportError( + f"Failed to import {class_name} from {module_path}: {str(e)}" + ) + else: + raise ValueError(f"Invalid model configuration for {model_str}") + + print(f"Initializing {model_class} with kwargs: {kwargs}") + if model_str != "org_dust3r": + return model_class(**kwargs) + else: + eval_str = kwargs.get("model_eval_str", None) + return eval(eval_str) + + +def get_available_models() -> list: + """ + Get a list of available models in MapAnything. 
+ + Returns: + list: A list of available model names. + """ + return list(MODEL_CONFIGS.keys()) + + +__all__ = ["model_factory", "get_available_models"] diff --git a/mapanything/models/__pycache__/__init__.cpython-312.pyc b/mapanything/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b74f127c03480cea652fc8e052daf0be46f9738 Binary files /dev/null and b/mapanything/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/mapanything/models/external/README.md b/mapanything/models/external/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3fab6d3b5b2ab74546420402e2b41a9eb33def92 --- /dev/null +++ b/mapanything/models/external/README.md @@ -0,0 +1,5 @@ +# External Model Code for Benchmarking & Re-Training + +This directory contains external model code that we use to train and benchmark external models fairly. These libraries are not part of the core MapAnything codebase and are included for only benchmarking purposes. The code in this directory is licensed under the same license as the source code from which it was derived, unless otherwise specified. + +The open-source Apache 2.0 License of MapAnything does not apply to these libraries. diff --git a/mapanything/models/external/__init__.py b/mapanything/models/external/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/models/external/anycalib/__init__.py b/mapanything/models/external/anycalib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d1507bd374fe968d000ea4534f5d82f0b895478d --- /dev/null +++ b/mapanything/models/external/anycalib/__init__.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Inference wrapper for AnyCalib +""" + +import torch +from anycalib import AnyCalib + +from mapanything.utils.geometry import get_rays_in_camera_frame + + +class AnyCalibWrapper(torch.nn.Module): + def __init__( + self, + name, + model_id="anycalib_pinhole", + **kwargs, + ): + super().__init__() + self.name = name + self.model_id = model_id + + # Initialize the model + self.model = AnyCalib(model_id=self.model_id) + + def forward(self, views): + """ + Forward pass wrapper for AnyCalib. + + Assumption: + - The number of input views is 1. + - The output camera model is pinhole (fx, fy, cx, cy). + This can be relaxed by not hardcoding the cam_id. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Length of the list should be 1. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["identity"] + + Returns: + List[dict]: A list containing the final outputs for the single view. Length of the list will be 1. + """ + # Check that the number of input views is 1 + assert len(views) == 1, "AnyCalib only supports 1 input view." 
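
For context, a minimal usage sketch of this wrapper, assuming the shapes and the `identity` normalization described in the docstring above; the batch size, resolution, and constructor arguments are illustrative, not fixed by this file:

```python
import torch

from mapanything.models.external.anycalib import AnyCalibWrapper

# Hypothetical usage sketch (batch size and resolution are illustrative).
wrapper = AnyCalibWrapper(name="anycalib", model_id="anycalib_pinhole").eval()

views = [
    {
        "img": torch.rand(2, 3, 518, 518),  # (B, C, H, W), values in [0, 1]
        "data_norm_type": ["identity"],     # no DINOv2 mean/std applied
    }
]

with torch.no_grad():
    (out,) = wrapper(views)  # single-view output list, per the docstring

print(out["intrinsics"].shape)      # (B, 3, 3) pinhole K matrices
print(out["ray_directions"].shape)  # per-pixel unit ray directions in camera frame
```
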
+ + # Get input shape of the images and batch size per view + _, _, height, width = views[0]["img"].shape + + # Check the data norm type + # AnyCalib expects a normalized image but without the DINOv2 mean and std applied ("identity") + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "identity", ( + "AnyCalib expects a normalized image but without the DINOv2 mean and std applied" + ) + + # Run AnyCalib inference + # Corresponding batched output dictionary: + # { + # "intrinsics": List[(D_i,) tensors] for each camera model "i" at the original input resolution, + # "fov_field": (B, N, 2) tensor with the regressed FoV field by the network. N≈320^2 (resolution close to the one seen during training), + # "tangent_coords": alias for "fov_field", + # "rays": (B, N, 3) tensor with the corresponding (via the exponential map) ray directions in the camera frame (x right, y down, z forward), + # "pred_size": (H, W) tuple with the image size used by the network. It can be used e.g. for resizing the FoV/ray fields to the original image size. + # } + # For "pinhole" camera model, the intrinsics are (fx, fy, cx, cy). + model_outputs = self.model.predict(views[0]["img"], cam_id="pinhole") + + # Convert the list of intrinsics to a tensor + intrinsics = [] + for intrinsics_per_sample in model_outputs["intrinsics"]: + pred_fx, pred_fy, pred_cx, pred_cy = intrinsics_per_sample + intrinsics_per_sample = torch.tensor( + [ + [pred_fx, 0, pred_cx], + [0, pred_fy, pred_cy], + [0, 0, 1], + ], + device=views[0]["img"].device, + ) + intrinsics.append(intrinsics_per_sample) + + # Convert the list of intrinsics to a tensor of size (batch_size_per_view, 3, 3) + intrinsics = torch.stack(intrinsics) + + # Get the ray directions + with torch.autocast("cuda", enabled=False): + _, ray_directions = get_rays_in_camera_frame( + intrinsics, height, width, normalize_to_unit_sphere=True + ) + + # Return the output in MapAnything format + res = [{"ray_directions": ray_directions, "intrinsics": intrinsics}] + + return res diff --git a/mapanything/models/external/dinov2/__init__.py b/mapanything/models/external/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd11e735b9c3a2ae2b1475ca163537e274fae76 --- /dev/null +++ b/mapanything/models/external/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/mapanything/models/external/dinov2/hub/__init__.py b/mapanything/models/external/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40afb43678d1db842a67445d79260c338a1c1ab5 --- /dev/null +++ b/mapanything/models/external/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/mapanything/models/external/dinov2/hub/backbones.py b/mapanything/models/external/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..3445f144304801154558f457ab041fd39be67743 --- /dev/null +++ b/mapanything/models/external/dinov2/hub/backbones.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from mapanything.models.external.dinov2.hub.utils import ( + _DINOV2_BASE_URL, + _make_dinov2_model_name, +) + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name( + arch_name, patch_size, num_register_tokens + ) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitb14( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitl14( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitg14( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
+ """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg( + *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs +): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/mapanything/models/external/dinov2/hub/utils.py b/mapanything/models/external/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..faf964ecf7a76fc0186d073725c50d63c602aef6 --- /dev/null +++ b/mapanything/models/external/dinov2/hub/utils.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name( + arch_name: str, patch_size: int, num_register_tokens: int = 0 +) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list( + itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]) + ) + output = F.pad(x, pads) + return output diff --git a/mapanything/models/external/dinov2/layers/__init__.py b/mapanything/models/external/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eca24b7215c4bfa9e333d69369f0f69565ca1a2a --- /dev/null +++ b/mapanything/models/external/dinov2/layers/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from mapanything.models.external.dinov2.layers.dino_head import DINOHead # noqa +from mapanything.models.external.dinov2.layers.mlp import Mlp # noqa +from mapanything.models.external.dinov2.layers.patch_embed import PatchEmbed # noqa +from mapanything.models.external.dinov2.layers.swiglu_ffn import ( + SwiGLUFFN, # noqa + SwiGLUFFNFused, # noqa +) +from mapanything.models.external.dinov2.layers.block import NestedTensorBlock # noqa +from mapanything.models.external.dinov2.layers.attention import MemEffAttention # noqa diff --git a/mapanything/models/external/dinov2/layers/attention.py b/mapanything/models/external/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b1beccc7a9f00f823957d3f4e6d51040666f0fc9 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/attention.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os + +from torch import nn, Tensor + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/mapanything/models/external/dinov2/layers/block.py b/mapanything/models/external/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..723f6868d3aa9eb84aacf8695071f27b27d49494 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/block.py 
@@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Any, Callable, Dict, List, Tuple + +import torch +from torch import nn, Tensor + +from mapanything.models.external.dinov2.layers.attention import ( + Attention, + MemEffAttention, +) +from mapanything.models.external.dinov2.layers.drop_path import DropPath +from mapanything.models.external.dinov2.layers.layer_scale import LayerScale +from mapanything.models.external.dinov2.layers.mlp import Mlp + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, index_select_cat, scaled_index_add + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + 
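
To make the comment above concrete: for drop ratios above 0.1, the residual branch is evaluated on a random subset of the batch and added back with an upscaled weight, so the update matches full-batch evaluation in expectation. A self-contained toy check of that rescaling, mirroring `drop_add_residual_stochastic_depth` (defined next); the shapes and the stand-in residual function are illustrative:

```python
import torch

# Toy check of batch-subset stochastic depth: evaluate the residual branch on
# a random subset and rescale by b / subset_size so the mean update is preserved.
b, n, d = 8, 4, 16
sample_drop_ratio = 0.5
x = torch.zeros(b, n, d)


def residual_func(t):  # stand-in for the attention / FFN branch
    return torch.ones_like(t)


subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = torch.randperm(b)[:subset_size]
residual = residual_func(x[brange]).flatten(1)
out = torch.index_add(
    x.flatten(1), 0, brange, residual, alpha=b / subset_size
).view_as(x)

print(out.mean().item())  # 1.0: matches applying the residual to the full batch
```
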
+def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + else: + x_plus_residual = scaled_index_add( + x, + brange, + residual.to(dtype=x.dtype), + scaling=scaling_vector, + alpha=residual_scale_factor, + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + 1, -1, x_list[0].shape[-1] + ) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) + return outputs + + +class NestedTensorBlock(Block): + 
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+        """
+        x_list contains a list of tensors to nest together and run
+        """
+        assert isinstance(self.attn, MemEffAttention)
+
+        if self.training and self.sample_drop_ratio > 0.0:
+
+            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                return self.mlp(self.norm2(x))
+
+            x_list = drop_add_residual_stochastic_depth_list(
+                x_list,
+                residual_func=attn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls1.gamma
+                if isinstance(self.ls1, LayerScale)
+                else None,
+            )
+            x_list = drop_add_residual_stochastic_depth_list(
+                x_list,
+                residual_func=ffn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls2.gamma
+                if isinstance(self.ls2, LayerScale)
+                else None,
+            )
+            return x_list
+        else:
+
+            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                return self.ls2(self.mlp(self.norm2(x)))
+
+            attn_bias, x = get_attn_bias_and_cat(x_list)
+            x = x + attn_residual_func(x, attn_bias=attn_bias)
+            x = x + ffn_residual_func(x)
+            return attn_bias.split(x)
+
+    def forward(self, x_or_x_list):
+        if isinstance(x_or_x_list, Tensor):
+            return super().forward(x_or_x_list)
+        elif isinstance(x_or_x_list, list):
+            if not XFORMERS_AVAILABLE:
+                raise AssertionError("xFormers is required for using nested tensors")
+            return self.forward_nested(x_or_x_list)
+        else:
+            raise AssertionError
diff --git a/mapanything/models/external/dinov2/layers/dino_head.py b/mapanything/models/external/dinov2/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..01b9823b0e5db9950e5f2b8dfc9e3439f7da2bf2
--- /dev/null
+++ b/mapanything/models/external/dinov2/layers/dino_head.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
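
As a quick sanity check of the projection head defined below in this file, a hedged shape example; the `in_dim`/`out_dim` values are typical DINO settings, not values fixed by this code:

```python
import torch

from mapanything.models.external.dinov2.layers.dino_head import DINOHead

# Illustrative shape check for DINOHead; the dimensions are typical DINO
# settings (ViT-S embeddings, 65536 prototype logits) and not mandated here.
head = DINOHead(in_dim=384, out_dim=65536, nlayers=3, bottleneck_dim=256)
cls_tokens = torch.randn(4, 384)  # e.g. CLS embeddings from a ViT-S/14
logits = head(cls_tokens)         # MLP -> L2-normalize -> weight-normed linear
print(logits.shape)               # torch.Size([4, 65536])
```
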
+ +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp( + nlayers, + in_dim, + bottleneck_dim, + hidden_dim=hidden_dim, + use_bn=use_bn, + bias=mlp_bias, + ) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp( + nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True +): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/mapanything/models/external/dinov2/layers/drop_path.py b/mapanything/models/external/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..04cb47af065ec3cfdbc8f59854efaa976ea717d5 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/drop_path.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/mapanything/models/external/dinov2/layers/layer_scale.py b/mapanything/models/external/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..b32da3bd74a028795f5ee628d2636a45b756b074 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/layer_scale.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import nn, Tensor + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/mapanything/models/external/dinov2/layers/mlp.py b/mapanything/models/external/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6d19f53d8562d07d559fe6db93d45b8286420720 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import nn, Tensor + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/mapanything/models/external/dinov2/layers/patch_embed.py b/mapanything/models/external/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..493774d038c9ee7f0f63b05f80561fa61321e2b7 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/patch_embed.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +import torch.nn as nn +from torch import Tensor + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, ( + f"Input image height {H} is not a multiple of patch height {patch_H}" + ) + assert W % patch_W == 0, ( + f"Input image width {W} is not a multiple of patch width: {patch_W}" + ) + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/mapanything/models/external/dinov2/layers/swiglu_ffn.py b/mapanything/models/external/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..c91ef5238dd7b0da9c3bca1ffe9a4f67e81ce224 --- /dev/null +++ b/mapanything/models/external/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import os +from typing import Callable, Optional + +import torch.nn.functional as F +from torch import nn, Tensor + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (SwiGLU)") + else: + # warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + # warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/mapanything/models/external/dinov2/models/__init__.py b/mapanything/models/external/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87f48dae91bb29ef0d256487fda2a16652dd2a3d --- /dev/null +++ b/mapanything/models/external/dinov2/models/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
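
The builder in this file resolves architectures by name from the `vision_transformer` module (defined right after it). A condensed sketch of that lookup; the kwargs shown are a subset of what `build_model` forwards, and their values are illustrative:

```python
# Condensed sketch of the name-based architecture lookup used by build_model;
# the kwargs are a subset and their values are illustrative.
import mapanything.models.external.dinov2.models.vision_transformer as vits

vit_kwargs = dict(
    img_size=518,
    patch_size=14,
    init_values=1.0,
    ffn_layer="mlp",
    block_chunks=0,
)
teacher = vits.__dict__["vit_large"](**vit_kwargs)
print(teacher.embed_dim)  # 1024 for ViT-L/14
```
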
+ +import logging + +import mapanything.models.external.dinov2.models.vision_transformer as vits + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model( + cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size + ) diff --git a/mapanything/models/external/dinov2/models/vision_transformer.py b/mapanything/models/external/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..32f7032990f214946e95aaae43d4fd6a86b1c982 --- /dev/null +++ b/mapanything/models/external/dinov2/models/vision_transformer.py @@ -0,0 +1,448 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import math +from functools import partial +from typing import Callable, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.utils.checkpoint import checkpoint + +from mapanything.models.external.dinov2.layers import ( + MemEffAttention, + Mlp, + NestedTensorBlock as Block, + PatchEmbed, + SwiGLUFFNFused, +) +from mapanything.models.external.pi3.layers.attention import FlashAttention + +# logger = logging.getLogger("dinov2") + + +def named_apply( + fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + 
img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim) + ) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) + if num_register_tokens + else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + if ffn_layer == "mlp": + # logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + # logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + # logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + attn_class=FlashAttention, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append( + [nn.Identity()] * i + blocks_list[i : i + chunksize] + ) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm 
= norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( + previous_dtype + ) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where( + masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x + ) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [ + self.prepare_tokens_with_masks(x, masks) + for x, masks in zip(x_list, masks_list) + ] + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n 
is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), ( + f"only {len(output)} / {len(blocks_to_take)} blocks found" + ) + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), ( + f"only {len(output)} / {len(blocks_to_take)} blocks found" + ) + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + 
depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/mapanything/models/external/dinov2/utils/__init__.py b/mapanything/models/external/dinov2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40afb43678d1db842a67445d79260c338a1c1ab5 --- /dev/null +++ b/mapanything/models/external/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/mapanything/models/external/dinov2/utils/cluster.py b/mapanything/models/external/dinov2/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..502a7552612fe3f4146cec43f197c16be6c54365 --- /dev/null +++ b/mapanything/models/external/dinov2/utils/cluster.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type( + cluster_type: Optional[ClusterType] = None, +) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path( + cluster_type: Optional[ClusterType] = None, +) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, + num_gpus_per_node: int, + cluster_type: Optional[ClusterType] = None, + **kwargs, +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = 
get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/mapanything/models/external/dinov2/utils/config.py b/mapanything/models/external/dinov2/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..87cbfd5848a165b82f3756b7ad8fff71ed4d2366 --- /dev/null +++ b/mapanything/models/external/dinov2/utils/config.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import math +import os + +import dinov2.distributed as distributed +from dinov2.configs import dinov2_default_config +from dinov2.logging import setup_logging +from dinov2.utils import utils +from omegaconf import OmegaConf + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt( + cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0 + ) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info( + "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())) + ) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/mapanything/models/external/dinov2/utils/dtype.py b/mapanything/models/external/dinov2/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..5a91f3f084ba6e1c1db19708cddf7e4389a87c01 --- /dev/null +++ b/mapanything/models/external/dinov2/utils/dtype.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
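For reference, the `sqrt_wrt_1024` rule in `apply_scaling_rules_to_cfg` above scales the base learning rate by the square root of the global batch size relative to a reference batch of 1024. A minimal sketch of that arithmetic, with the base LR, per-GPU batch size, and GPU count chosen purely for illustration:

import math

base_lr = 4.0e-3            # illustrative value, not taken from this repo
batch_size_per_gpu = 32     # illustrative per-GPU batch size
world_size = 8              # illustrative number of GPUs
lr = base_lr * math.sqrt(batch_size_per_gpu * world_size / 1024.0)
print(lr)  # 0.002, i.e. base_lr * sqrt(256 / 1024) = base_lr * 0.5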
+ + +from typing import Dict, Union + +import numpy as np +import torch + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), ( + f"Expected an instance of nunpy dtype, got {type(dtype)}" + ) + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/mapanything/models/external/dinov2/utils/param_groups.py b/mapanything/models/external/dinov2/utils/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..df3df118869b3e09214951f0bad4fed32ae79bc7 --- /dev/null +++ b/mapanything/models/external/dinov2/utils/param_groups.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +from collections import defaultdict + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate( + name, + lr_decay_rate=1.0, + num_layers=12, + force_is_backbone=False, + chunked_blocks=False, +): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." 
not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, + lr_decay_rate, + num_layers=n_blocks, + force_is_backbone=n_blocks > 0, + chunked_blocks=chunked_blocks, + ) + d = { + "params": param, + "is_last_layer": False, + "lr_multiplier": decay_rate, + "wd_multiplier": 1.0, + "name": name, + } + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info( + f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""" + ) + + return all_param_groups + + +def fuse_params_groups( + all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer") +): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git a/mapanything/models/external/dinov2/utils/utils.py b/mapanything/models/external/dinov2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6afe42656020daf57db783da8d76363b4b8e72a2 --- /dev/null +++ b/mapanything/models/external/dinov2/utils/utils.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
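For reference, `get_vit_lr_decay_rate` above resolves each parameter name to a `layer_id` and returns `lr_decay_rate ** (num_layers + 1 - layer_id)`, so the patch embedding and tokens (layer_id 0) receive the strongest damping while parameters past the last block keep a multiplier of 1.0. A quick sketch of the resulting multipliers, assuming an illustrative 12-block backbone with a decay rate of 0.9:

decay_rate, num_layers = 0.9, 12   # illustrative values only
for layer_id in (0, 1, 6, 12, 13):
    # 0: patch embed / cls / register tokens, 13: anything past the last block
    print(layer_id, round(decay_rate ** (num_layers + 1 - layer_id), 4))
# 0 -> 0.2542, 1 -> 0.2824, 6 -> 0.4783, 12 -> 0.9, 13 -> 1.0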
+ +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + +# logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url( + pretrained_weights, map_location="cpu" + ) + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + # logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + _ = model.load_state_dict(state_dict, strict=False) + # logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__( + self, + base_value, + final_value, + total_iters, + warmup_iters=0, + start_warmup_value=0, + freeze_iters=0, + ): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * ( + 1 + np.cos(np.pi * iters / len(iters)) + ) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/mapanything/models/external/dust3r/__init__.py b/mapanything/models/external/dust3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c0bb39ddbcbeff5deefaefe69c570353e8e4bdb6 --- /dev/null +++ b/mapanything/models/external/dust3r/__init__.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
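A small usage sketch for `CosineScheduler` above: the full per-iteration schedule (optional freeze phase, linear warmup, then cosine decay) is precomputed in `__init__`, and indexing past `total_iters` returns `final_value`. The hyperparameters below are arbitrary and only meant to show the indexing behaviour:

from mapanything.models.external.dinov2.utils.utils import CosineScheduler

sched = CosineScheduler(
    base_value=1e-3,
    final_value=1e-6,
    total_iters=1000,
    warmup_iters=100,
    start_warmup_value=0.0,
)
print(sched[0])     # 0.0: start of the linear warmup
print(sched[100])   # ~1e-3: warmup finished, cosine decay starts at base_value
print(sched[999])   # ~1e-6: end of the cosine decay
print(sched[5000])  # 1e-6: past total_iters, clamped to final_value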
+ +""" +Inference wrapper for DUSt3R +""" + +import warnings + +import torch +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode +from dust3r.image_pairs import make_pairs +from dust3r.inference import inference +from dust3r.model import AsymmetricCroCo3DStereo # noqa + +from mapanything.models.external.vggt.utils.rotation import mat_to_quat +from mapanything.utils.geometry import ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + convert_z_depth_to_depth_along_ray, + depthmap_to_camera_frame, + get_rays_in_camera_frame, +) + +inf = float("inf") + + +def load_model(model_path, device, verbose=True): + if verbose: + print("Loading model from", model_path) + ckpt = torch.load(model_path, map_location="cpu", weights_only=False) + args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R") + if "landscape_only" not in args: + args = args[:-1] + ", landscape_only=False)" + else: + args = args.replace(" ", "").replace( + "landscape_only=True", "landscape_only=False" + ) + assert "landscape_only=False" in args + if verbose: + print(f"Instantiating: {args}") + try: + net = eval(args) + except NameError: + net = AsymmetricCroCo3DStereo( + enc_depth=24, + dec_depth=12, + enc_embed_dim=1024, + dec_embed_dim=768, + enc_num_heads=16, + dec_num_heads=12, + pos_embed="RoPE100", + patch_embed_cls="PatchEmbedDust3R", + img_size=(512, 512), + head_type="dpt", + output_mode="pts3d", + depth_mode=("exp", -inf, inf), + conf_mode=("exp", 1, inf), + landscape_only=False, + ) + s = net.load_state_dict(ckpt["model"], strict=False) + if verbose: + print(s) + return net.to(device) + + +class DUSt3RBAWrapper(torch.nn.Module): + def __init__( + self, + name, + ckpt_path, + scene_graph="complete", + inference_batch_size=32, + global_optim_schedule="cosine", + global_optim_lr=0.01, + global_optim_niter=300, + **kwargs, + ): + super().__init__() + self.name = name + self.ckpt_path = ckpt_path + self.scene_graph = scene_graph + self.inference_batch_size = inference_batch_size + self.global_optim_schedule = global_optim_schedule + self.global_optim_lr = global_optim_lr + self.global_optim_niter = global_optim_niter + + # Init the model and load the checkpoint + self.model = load_model(self.ckpt_path, device="cpu") + + # Init the global aligner mode + self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer + + def forward(self, views): + """ + Forward pass wrapper for DUSt3R using the global aligner. + + Assumption: + - The batch size of input views is 1. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys, where B is the batch size and is 1: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["dust3r"] + + Returns: + List[dict]: A list containing the final outputs for the input views. + """ + # Check the batch size of input views + batch_size_per_view, _, height, width = views[0]["img"].shape + device = views[0]["img"].device + num_views = len(views) + assert batch_size_per_view == 1, ( + f"Batch size of input views should be 1, but got {batch_size_per_view}." 
+ ) + + # Check the data norm type + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "dust3r", ( + "DUSt3R expects a normalized image with the DUSt3R normalization scheme applied" + ) + + # Convert the input views to the expected input format + images = [] + for view in views: + images.append( + dict( + img=view["img"], + idx=len(images), + instance=str(len(images)), + ) + ) + + # Make image pairs and run inference pair-wise + pairs = make_pairs( + images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + output = inference( + pairs, + self.model, + device, + batch_size=self.inference_batch_size, + verbose=False, + ) + + # Global optimization + with torch.enable_grad(): + scene = global_aligner( + output, device=device, mode=self.global_aligner_mode, verbose=False + ) + _ = scene.compute_global_alignment( + init="mst", + niter=self.global_optim_niter, + schedule=self.global_optim_schedule, + lr=self.global_optim_lr, + ) + + # Make sure scene is not None + if scene is None: + raise RuntimeError("Global optimization failed.") + + # Get the predictions + intrinsics = scene.get_intrinsics() + c2w_poses = scene.get_im_poses() + depths = scene.get_depthmaps() + + # Convert the output to the MapAnything format + with torch.autocast("cuda", enabled=False): + res = [] + for view_idx in range(num_views): + # Get the current view predictions + curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0) + curr_view_pose = c2w_poses[view_idx].unsqueeze(0) + curr_view_depth_z = depths[view_idx].unsqueeze(0) + + # Convert the pose to quaternions and translation + curr_view_cam_translations = curr_view_pose[..., :3, 3] + curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3]) + + # Get the camera frame pointmaps + curr_view_pts3d_cam, _ = depthmap_to_camera_frame( + curr_view_depth_z, curr_view_intrinsic + ) + + # Convert the z depth to depth along ray + curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray( + curr_view_depth_z, curr_view_intrinsic + ) + curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1) + + # Get the ray directions on the unit sphere in the camera frame + _, curr_view_ray_dirs = get_rays_in_camera_frame( + curr_view_intrinsic, height, width, normalize_to_unit_sphere=True + ) + + # Get the pointmaps + curr_view_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + curr_view_ray_dirs, + curr_view_depth_along_ray, + curr_view_cam_translations, + curr_view_cam_quats, + ) + ) + + # Append the outputs to the result list + res.append( + { + "pts3d": curr_view_pts3d, + "pts3d_cam": curr_view_pts3d_cam, + "ray_directions": curr_view_ray_dirs, + "depth_along_ray": curr_view_depth_along_ray, + "cam_trans": curr_view_cam_translations, + "cam_quats": curr_view_cam_quats, + } + ) + + return res diff --git a/mapanything/models/external/mast3r/__init__.py b/mapanything/models/external/mast3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b064255791b95ffaefa804f4c580cafe37d68300 --- /dev/null +++ b/mapanything/models/external/mast3r/__init__.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
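A hedged usage sketch for `DUSt3RBAWrapper` above; the checkpoint path is a placeholder, and the random tensors only stand in for DUSt3R-normalized images (the per-view batch size must be 1):

import torch

from mapanything.models.external.dust3r import DUSt3RBAWrapper

model = DUSt3RBAWrapper(
    name="dust3r_ba",
    ckpt_path="/path/to/dust3r_checkpoint.pth",  # placeholder path
    scene_graph="complete",
).cuda().eval()

views = [
    {"img": torch.randn(1, 3, 384, 512).cuda(), "data_norm_type": ["dust3r"]},
    {"img": torch.randn(1, 3, 384, 512).cuda(), "data_norm_type": ["dust3r"]},
]
preds = model(views)            # one output dict per input view
print(preds[0]["pts3d"].shape)  # world-frame pointmap for the first view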
+ +""" +Inference wrapper for MASt3R + Sparse GA +""" + +import os +import tempfile +import warnings + +import torch +from dust3r.image_pairs import make_pairs +from mast3r.cloud_opt.sparse_ga import sparse_global_alignment +from mast3r.model import load_model + +from mapanything.models.external.vggt.utils.rotation import mat_to_quat +from mapanything.utils.geometry import ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + convert_z_depth_to_depth_along_ray, + depthmap_to_camera_frame, + get_rays_in_camera_frame, +) + + +class MASt3RSGAWrapper(torch.nn.Module): + def __init__( + self, + name, + ckpt_path, + cache_dir, + scene_graph="complete", + sparse_ga_lr1=0.07, + sparse_ga_niter1=300, + sparse_ga_lr2=0.01, + sparse_ga_niter2=300, + sparse_ga_optim_level="refine+depth", + sparse_ga_shared_intrinsics=False, + sparse_ga_matching_conf_thr=5.0, + **kwargs, + ): + super().__init__() + self.name = name + self.ckpt_path = ckpt_path + self.cache_dir = cache_dir + self.scene_graph = scene_graph + self.sparse_ga_lr1 = sparse_ga_lr1 + self.sparse_ga_niter1 = sparse_ga_niter1 + self.sparse_ga_lr2 = sparse_ga_lr2 + self.sparse_ga_niter2 = sparse_ga_niter2 + self.sparse_ga_optim_level = sparse_ga_optim_level + self.sparse_ga_shared_intrinsics = sparse_ga_shared_intrinsics + self.sparse_ga_matching_conf_thr = sparse_ga_matching_conf_thr + + # Init the model and load the checkpoint + self.model = load_model(self.ckpt_path, device="cpu") + + def forward(self, views): + """ + Forward pass wrapper for MASt3R using the sparse global aligner. + + Assumption: + - The batch size of input views is 1. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys, where B is the batch size and is 1: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["dust3r"] + "label" (list): ["scene_name"] + "instance" (list): ["image_name"] + + Returns: + List[dict]: A list containing the final outputs for the input views. + """ + # Check the batch size of input views + batch_size_per_view, _, height, width = views[0]["img"].shape + device = views[0]["img"].device + num_views = len(views) + assert batch_size_per_view == 1, ( + f"Batch size of input views should be 1, but got {batch_size_per_view}." 
+ ) + + # Check the data norm type + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "dust3r", ( + "MASt3R expects a normalized image with the DUSt3R normalization scheme applied" + ) + + # Convert the input views to the expected input format + images = [] + image_paths = [] + for view in views: + images.append( + dict( + img=view["img"].cpu(), + idx=len(images), + instance=str(len(images)), + true_shape=torch.tensor(view["img"].shape[-2:])[None] + .repeat(batch_size_per_view, 1) + .numpy(), + ) + ) + view_name = os.path.join(view["label"][0], view["instance"][0]) + image_paths.append(view_name) + + # Make image pairs and run inference + # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation) + pairs = make_pairs( + images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True + ) + with torch.enable_grad(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + tempfile.mkdtemp(dir=self.cache_dir) + scene = sparse_global_alignment( + image_paths, + pairs, + self.cache_dir, + self.model, + lr1=self.sparse_ga_lr1, + niter1=self.sparse_ga_niter1, + lr2=self.sparse_ga_lr2, + niter2=self.sparse_ga_niter2, + device=device, + opt_depth="depth" in self.sparse_ga_optim_level, + shared_intrinsics=self.sparse_ga_shared_intrinsics, + matching_conf_thr=self.sparse_ga_matching_conf_thr, + verbose=False, + ) + + # Make sure scene is not None + if scene is None: + raise RuntimeError("Global optimization failed.") + + # Get the predictions + intrinsics = scene.intrinsics + c2w_poses = scene.get_im_poses() + _, depths, _ = scene.get_dense_pts3d() + + # Convert the output to the MapAnything format + with torch.autocast("cuda", enabled=False): + res = [] + for view_idx in range(num_views): + # Get the current view predictions + curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0) + curr_view_pose = c2w_poses[view_idx].unsqueeze(0) + curr_view_depth_z = ( + depths[view_idx].reshape((height, width)).unsqueeze(0) + ) + + # Convert the pose to quaternions and translation + curr_view_cam_translations = curr_view_pose[..., :3, 3] + curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3]) + + # Get the camera frame pointmaps + curr_view_pts3d_cam, _ = depthmap_to_camera_frame( + curr_view_depth_z, curr_view_intrinsic + ) + + # Convert the z depth to depth along ray + curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray( + curr_view_depth_z, curr_view_intrinsic + ) + curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1) + + # Get the ray directions on the unit sphere in the camera frame + _, curr_view_ray_dirs = get_rays_in_camera_frame( + curr_view_intrinsic, height, width, normalize_to_unit_sphere=True + ) + + # Get the pointmaps + curr_view_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + curr_view_ray_dirs, + curr_view_depth_along_ray, + curr_view_cam_translations, + curr_view_cam_quats, + ) + ) + + # Append the outputs to the result list + res.append( + { + "pts3d": curr_view_pts3d, + "pts3d_cam": curr_view_pts3d_cam, + "ray_directions": curr_view_ray_dirs, + "depth_along_ray": curr_view_depth_along_ray, + "cam_trans": curr_view_cam_translations, + "cam_quats": curr_view_cam_quats, + } + ) + + return res diff --git a/mapanything/models/external/moge/__init__.py b/mapanything/models/external/moge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b3de1d11134902026514dca26320a666f8a802 --- /dev/null +++ 
b/mapanything/models/external/moge/__init__.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Inference wrapper for MoGe +""" + +import torch + +from mapanything.models.external.moge.models.v1 import MoGeModel as MoGeModelV1 +from mapanything.models.external.moge.models.v2 import MoGeModel as MoGeModelV2 + + +class MoGeWrapper(torch.nn.Module): + def __init__( + self, + name, + model_string="Ruicheng/moge-2-vitl", + torch_hub_force_reload=False, + load_custom_ckpt=False, + custom_ckpt_path=None, + ): + super().__init__() + self.name = name + self.model_string = model_string + self.torch_hub_force_reload = torch_hub_force_reload + self.load_custom_ckpt = load_custom_ckpt + self.custom_ckpt_path = custom_ckpt_path + + # Mapping of MoGe model version to checkpoint strings + self.moge_model_map = { + "v1": ["Ruicheng/moge-vitl"], + "v2": [ + "Ruicheng/moge-2-vits-normal", + "Ruicheng/moge-2-vitb-normal", + "Ruicheng/moge-2-vitl-normal", + "Ruicheng/moge-2-vitl", + ], + } + + # Initialize the model + if self.model_string in self.moge_model_map["v1"]: + self.model = MoGeModelV1.from_pretrained(self.model_string) + elif self.model_string in self.moge_model_map["v2"]: + self.model = MoGeModelV2.from_pretrained(self.model_string) + else: + raise ValueError( + f"Invalid model string: {self.model_string}. Valid strings are: {self.moge_model_map}" + ) + + # Load custom checkpoint if requested + if self.load_custom_ckpt: + print(f"Loading checkpoint from {self.custom_ckpt_path} ...") + assert self.custom_ckpt_path is not None, ( + "custom_ckpt_path must be provided if load_custom_ckpt is set to True" + ) + custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False) + print(self.model.load_state_dict(custom_ckpt, strict=True)) + del custom_ckpt # in case it occupies memory + + def forward(self, views): + """ + Forward pass wrapper for MoGe-2. + The predicted MoGe-2 mask is not applied to the outputs. + The number of tokens for inference is determined by the image shape. + + Assumption: + - The number of input views is 1. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Length of the list should be 1. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["identity"] + + Returns: + List[dict]: A list containing the final outputs for the single view. Length of the list will be 1. + """ + # Check that the number of input views is 1 + assert len(views) == 1, "MoGe only supports 1 input view." 
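        # The token count computed below assumes the DINOv2 patch size of 14;
        # for example, a 518 x 518 input yields
        # (518 // 14) * (518 // 14) = 37 * 37 = 1369 tokens for inference.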
+ + # Get input shape of the images, number of tokens for inference, and batch size per view + _, _, height, width = views[0]["img"].shape + num_tokens = int(height // 14) * int(width // 14) + + # Check the data norm type + # MoGe expects a normalized image but without the DINOv2 mean and std applied ("identity") + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "identity", ( + "MoGe expects a normalized image but without the DINOv2 mean and std applied" + ) + + # Run MoGe inference + # Output dict contains: "points", "depth", "mask", "intrinsics", "normal" (based on model config) + model_outputs = self.model.infer( + image=views[0]["img"], num_tokens=num_tokens, apply_mask=False + ) + + # Get the ray directions and depth along ray + with torch.autocast("cuda", enabled=False): + depth_along_ray = torch.norm(model_outputs["points"], dim=-1, keepdim=True) + ray_directions = model_outputs["points"] / depth_along_ray + + # Convert the output to MapAnything format + result_dict = { + "pts3d": model_outputs["points"], + "pts3d_cam": model_outputs["points"], + "depth_z": model_outputs["depth"].unsqueeze(-1), + "intrinsics": model_outputs["intrinsics"], + "non_ambiguous_mask": model_outputs["mask"], + "ray_directions": ray_directions, + "depth_along_ray": depth_along_ray, + } + res = [result_dict] + + return res diff --git a/mapanything/models/external/moge/models/modules.py b/mapanything/models/external/moge/models/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..7a55ea3a93d6a554c0923b6796db3ce489ad73dd --- /dev/null +++ b/mapanything/models/external/moge/models/modules.py @@ -0,0 +1,493 @@ +import functools +import importlib +import itertools +from typing import List, Literal, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mapanything.models.external.dinov2.models.vision_transformer import ( + DinoVisionTransformer, +) +from mapanything.models.external.moge.models.utils import ( + wrap_dinov2_attention_with_sdpa, + wrap_module_with_gradient_checkpointing, +) + + +class ResidualConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int = None, + hidden_channels: int = None, + kernel_size: int = 3, + padding_mode: str = "replicate", + activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu", + in_norm: Literal[ + "group_norm", "layer_norm", "instance_norm", "none" + ] = "layer_norm", + hidden_norm: Literal[ + "group_norm", "layer_norm", "instance_norm" + ] = "group_norm", + ): + super(ResidualConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + + if activation == "relu": + activation_cls = nn.ReLU + elif activation == "leaky_relu": + activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2) + elif activation == "silu": + activation_cls = nn.SiLU + elif activation == "elu": + activation_cls = nn.ELU + else: + raise ValueError(f"Unsupported activation function: {activation}") + + self.layers = nn.Sequential( + ( + nn.GroupNorm(in_channels // 32, in_channels) + if in_norm == "group_norm" + else ( + nn.GroupNorm(1, in_channels) + if in_norm == "layer_norm" + else ( + nn.InstanceNorm2d(in_channels) + if in_norm == "instance_norm" + else nn.Identity() + ) + ) + ), + activation_cls(), + nn.Conv2d( + in_channels, + hidden_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + padding_mode=padding_mode, + ), + ( + 
nn.GroupNorm(hidden_channels // 32, hidden_channels) + if hidden_norm == "group_norm" + else ( + nn.GroupNorm(1, hidden_channels) + if hidden_norm == "layer_norm" + else ( + nn.InstanceNorm2d(hidden_channels) + if hidden_norm == "instance_norm" + else nn.Identity() + ) + ) + ), + activation_cls(), + nn.Conv2d( + hidden_channels, + out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + padding_mode=padding_mode, + ), + ) + + self.skip_connection = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x): + skip = self.skip_connection(x) + x = self.layers(x) + x = x + skip + return x + + +class DINOv2Encoder(nn.Module): + "Wrapped DINOv2 encoder supporting gradient checkpointing. Input is RGB image in range [0, 1]." + + backbone: DinoVisionTransformer + image_mean: torch.Tensor + image_std: torch.Tensor + dim_features: int + + def __init__( + self, + backbone: str, + intermediate_layers: Union[int, List[int]], + dim_out: int, + **deprecated_kwargs, + ): + super(DINOv2Encoder, self).__init__() + + self.intermediate_layers = intermediate_layers + + # Load the backbone + self.hub_loader = getattr( + importlib.import_module( + "mapanything.models.external.dinov2.hub.backbones", __package__ + ), + backbone, + ) + self.backbone_name = backbone + self.backbone = self.hub_loader(pretrained=False) + + self.dim_features = self.backbone.blocks[0].attn.qkv.in_features + self.num_features = ( + intermediate_layers + if isinstance(intermediate_layers, int) + else len(intermediate_layers) + ) + + self.output_projections = nn.ModuleList( + [ + nn.Conv2d( + in_channels=self.dim_features, + out_channels=dim_out, + kernel_size=1, + stride=1, + padding=0, + ) + for _ in range(self.num_features) + ] + ) + + self.register_buffer( + "image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + ) + self.register_buffer( + "image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + ) + + @property + def onnx_compatible_mode(self): + return getattr(self, "_onnx_compatible_mode", False) + + @onnx_compatible_mode.setter + def onnx_compatible_mode(self, value: bool): + self._onnx_compatible_mode = value + self.backbone.onnx_compatible_mode = value + + def init_weights(self): + pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict() + self.backbone.load_state_dict(pretrained_backbone_state_dict) + + def enable_gradient_checkpointing(self): + for i in range(len(self.backbone.blocks)): + wrap_module_with_gradient_checkpointing(self.backbone.blocks[i]) + + def enable_pytorch_native_sdpa(self): + for i in range(len(self.backbone.blocks)): + wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn) + + def forward( + self, + image: torch.Tensor, + token_rows: Union[int, torch.LongTensor], + token_cols: Union[int, torch.LongTensor], + return_class_token: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + image_14 = F.interpolate( + image, + (token_rows * 14, token_cols * 14), + mode="bilinear", + align_corners=False, + antialias=not self.onnx_compatible_mode, + ) + image_14 = (image_14 - self.image_mean) / self.image_std + + # Get intermediate layers from the backbone + features = self.backbone.get_intermediate_layers( + image_14, n=self.intermediate_layers, return_class_token=True + ) + + # Project features to the desired dimensionality + x = torch.stack( + [ + proj( + feat.permute(0, 2, 1) + .unflatten(2, (token_rows, token_cols)) + .contiguous() + ) + for proj, (feat, 
clstoken) in zip(self.output_projections, features) + ], + dim=1, + ).sum(dim=1) + + if return_class_token: + return x, features[-1][1] + else: + return x + + +class Resampler(nn.Sequential): + def __init__( + self, + in_channels: int, + out_channels: int, + type_: Literal[ + "pixel_shuffle", + "nearest", + "bilinear", + "conv_transpose", + "pixel_unshuffle", + "avg_pool", + "max_pool", + ], + scale_factor: int = 2, + ): + if type_ == "pixel_shuffle": + nn.Sequential.__init__( + self, + nn.Conv2d( + in_channels, + out_channels * (scale_factor**2), + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + nn.PixelShuffle(scale_factor), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + ) + for i in range(1, scale_factor**2): + self[0].weight.data[i :: scale_factor**2] = self[0].weight.data[ + 0 :: scale_factor**2 + ] + self[0].bias.data[i :: scale_factor**2] = self[0].bias.data[ + 0 :: scale_factor**2 + ] + elif type_ in ["nearest", "bilinear"]: + nn.Sequential.__init__( + self, + nn.Upsample( + scale_factor=scale_factor, + mode=type_, + align_corners=False if type_ == "bilinear" else None, + ), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + ) + elif type_ == "conv_transpose": + nn.Sequential.__init__( + self, + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=scale_factor, + stride=scale_factor, + ), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + ) + self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1] + elif type_ == "pixel_unshuffle": + nn.Sequential.__init__( + self, + nn.PixelUnshuffle(scale_factor), + nn.Conv2d( + in_channels * (scale_factor**2), + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + ) + elif type_ == "avg_pool": + nn.Sequential.__init__( + self, + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor), + ) + elif type_ == "max_pool": + nn.Sequential.__init__( + self, + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor), + ) + else: + raise ValueError(f"Unsupported resampler type: {type_}") + + +class MLP(nn.Sequential): + def __init__(self, dims: Sequence[int]): + nn.Sequential.__init__( + self, + *itertools.chain( + *[ + (nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True)) + for dim_in, dim_out in zip(dims[:-2], dims[1:-1]) + ] + ), + nn.Linear(dims[-2], dims[-1]), + ) + + +class ConvStack(nn.Module): + def __init__( + self, + dim_in: List[Optional[int]], + dim_res_blocks: List[int], + dim_out: List[Optional[int]], + resamplers: Union[ + Literal[ + "pixel_shuffle", + "nearest", + "bilinear", + "conv_transpose", + "pixel_unshuffle", + "avg_pool", + "max_pool", + ], + List, + ], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + res_block_in_norm: Literal[ + "layer_norm", "group_norm", "instance_norm", "none" + ] = "layer_norm", + res_block_hidden_norm: Literal[ + "layer_norm", "group_norm", "instance_norm", "none" + ] = "group_norm", + activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu", + ): + super().__init__() + self.input_blocks = nn.ModuleList( + [ + ( + nn.Conv2d( + dim_in_, dim_res_block_, 
kernel_size=1, stride=1, padding=0 + ) + if dim_in_ is not None + else nn.Identity() + ) + for dim_in_, dim_res_block_ in zip( + ( + dim_in + if isinstance(dim_in, Sequence) + else itertools.repeat(dim_in) + ), + dim_res_blocks, + ) + ] + ) + self.resamplers = nn.ModuleList( + [ + Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler) + for i, (dim_prev, dim_succ, resampler) in enumerate( + zip( + dim_res_blocks[:-1], + dim_res_blocks[1:], + ( + resamplers + if isinstance(resamplers, Sequence) + else itertools.repeat(resamplers) + ), + ) + ) + ] + ) + self.res_blocks = nn.ModuleList( + [ + nn.Sequential( + *( + ResidualConvBlock( + dim_res_block_, + dim_res_block_, + dim_times_res_block_hidden * dim_res_block_, + activation=activation, + in_norm=res_block_in_norm, + hidden_norm=res_block_hidden_norm, + ) + for _ in range( + num_res_blocks[i] + if isinstance(num_res_blocks, list) + else num_res_blocks + ) + ) + ) + for i, dim_res_block_ in enumerate(dim_res_blocks) + ] + ) + self.output_blocks = nn.ModuleList( + [ + ( + nn.Conv2d( + dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0 + ) + if dim_out_ is not None + else nn.Identity() + ) + for dim_out_, dim_res_block_ in zip( + ( + dim_out + if isinstance(dim_out, Sequence) + else itertools.repeat(dim_out) + ), + dim_res_blocks, + ) + ] + ) + + def enable_gradient_checkpointing(self): + for i in range(len(self.resamplers)): + self.resamplers[i] = wrap_module_with_gradient_checkpointing( + self.resamplers[i] + ) + for i in range(len(self.res_blocks)): + for j in range(len(self.res_blocks[i])): + self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing( + self.res_blocks[i][j] + ) + + def forward(self, in_features: List[torch.Tensor]): + out_features = [] + for i in range(len(self.res_blocks)): + feature = self.input_blocks[i](in_features[i]) + if i == 0: + x = feature + elif feature is not None: + x = x + feature + x = self.res_blocks[i](x) + out_features.append(self.output_blocks[i](x)) + if i < len(self.res_blocks) - 1: + x = self.resamplers[i](x) + return out_features diff --git a/mapanything/models/external/moge/models/utils.py b/mapanything/models/external/moge/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9075226ebc4bf05ce614af287e13f0a9fa84b059 --- /dev/null +++ b/mapanything/models/external/moge/models/utils.py @@ -0,0 +1,477 @@ +import inspect +from functools import partial, wraps +from numbers import Number +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def wrap_module_with_gradient_checkpointing(module: nn.Module): + from torch.utils.checkpoint import checkpoint + + class _CheckpointingWrapper(module.__class__): + _restore_cls = module.__class__ + + def forward(self, *args, **kwargs): + return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) + + module.__class__ = _CheckpointingWrapper + return module + + +def unwrap_module_with_gradient_checkpointing(module: nn.Module): + module.__class__ = module.__class__._restore_cls + + +def wrap_dinov2_attention_with_sdpa(module: nn.Module): + assert torch.__version__ >= "2.0", "SDPA requires PyTorch 2.0 or later" + + class _AttentionWrapper(module.__class__): + def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) # (3, B, H, N, C // H) + + q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // 
H) + + x = F.scaled_dot_product_attention(q, k, v, attn_bias) + x = x.permute(0, 2, 1, 3).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + module.__class__ = _AttentionWrapper + return module + + +def sync_ddp_hook( + state, bucket: torch.distributed.GradBucket +) -> torch.futures.Future[torch.Tensor]: + group_to_use = torch.distributed.group.WORLD + world_size = group_to_use.size() + grad = bucket.buffer() + grad.div_(world_size) + torch.distributed.all_reduce(grad, group=group_to_use) + fut = torch.futures.Future() + fut.set_result(grad) + return fut + + +def normalized_view_plane_uv( + width: int, + height: int, + aspect_ratio: float = None, + dtype: torch.dtype = None, + device: torch.device = None, +) -> torch.Tensor: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio**2) ** 0.5 + span_y = 1 / (1 + aspect_ratio**2) ** 0.5 + + u = torch.linspace( + -span_x * (width - 1) / width, + span_x * (width - 1) / width, + width, + dtype=dtype, + device=device, + ) + v = torch.linspace( + -span_y * (height - 1) / height, + span_y * (height - 1) / height, + height, + dtype=dtype, + device=device, + ) + u, v = torch.meshgrid(u, v, indexing="xy") + uv = torch.stack([u, v], dim=-1) + return uv + + +def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray): + "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal" + from scipy.optimize import least_squares + + uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) + + def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): + xy_proj = xy / (z + shift)[:, None] + f = (xy_proj * uv).sum() / np.square(xy_proj).sum() + err = (f * xy_proj - uv).ravel() + return err + + solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method="lm") + optim_shift = solution["x"].squeeze().astype(np.float32) + + xy_proj = xy / (z + optim_shift)[:, None] + optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum() + + return optim_shift, optim_focal + + +def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float): + "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift" + from scipy.optimize import least_squares + + uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) + + def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): + xy_proj = xy / (z + shift)[:, None] + err = (focal * xy_proj - uv).ravel() + return err + + solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method="lm") + optim_shift = solution["x"].squeeze().astype(np.float32) + + return optim_shift + + +def recover_focal_shift( + points: torch.Tensor, + mask: torch.Tensor = None, + focal: torch.Tensor = None, + downsample_size: Tuple[int, int] = (64, 64), +): + """ + Recover the depth map and FoV from a point map with unknown z shift and focal. + + Note that it assumes: + - the optical center is at the center of the map + - the map is undistorted + - the map is isometric in the x and y directions + + ### Parameters: + - `points: torch.Tensor` of shape (..., H, W, 3) + - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps. + + ### Returns: + - `focal`: torch.Tensor of shape (...) 
the estimated focal length, relative to the half diagonal of the map + - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space + """ + shape = points.shape + height, width = points.shape[-3], points.shape[-2] + + points = points.reshape(-1, *shape[-3:]) + mask = None if mask is None else mask.reshape(-1, *shape[-3:-1]) + focal = focal.reshape(-1) if focal is not None else None + uv = normalized_view_plane_uv( + width, height, dtype=points.dtype, device=points.device + ) # (H, W, 2) + + points_lr = F.interpolate( + points.permute(0, 3, 1, 2), downsample_size, mode="nearest" + ).permute(0, 2, 3, 1) + uv_lr = ( + F.interpolate( + uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode="nearest" + ) + .squeeze(0) + .permute(1, 2, 0) + ) + mask_lr = ( + None + if mask is None + else F.interpolate( + mask.to(torch.float32).unsqueeze(1), downsample_size, mode="nearest" + ).squeeze(1) + > 0 + ) + + uv_lr_np = uv_lr.cpu().numpy() + points_lr_np = points_lr.detach().cpu().numpy() + focal_np = focal.cpu().numpy() if focal is not None else None + mask_lr_np = None if mask is None else mask_lr.cpu().numpy() + optim_shift, optim_focal = [], [] + for i in range(points.shape[0]): + points_lr_i_np = ( + points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]] + ) + uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]] + if uv_lr_i_np.shape[0] < 2: + optim_focal.append(1) + optim_shift.append(0) + continue + if focal is None: + optim_shift_i, optim_focal_i = solve_optimal_focal_shift( + uv_lr_i_np, points_lr_i_np + ) + optim_focal.append(float(optim_focal_i)) + else: + optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i]) + optim_shift.append(float(optim_shift_i)) + optim_shift = torch.tensor( + optim_shift, device=points.device, dtype=points.dtype + ).reshape(shape[:-3]) + + if focal is None: + optim_focal = torch.tensor( + optim_focal, device=points.device, dtype=points.dtype + ).reshape(shape[:-3]) + else: + optim_focal = focal.reshape(shape[:-3]) + + return optim_focal, optim_shift + + +def suppress_traceback(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + e.__traceback__ = e.__traceback__.tb_next.tb_next + raise + + return wrapper + + +def get_device(args, kwargs): + device = None + for arg in list(args) + list(kwargs.values()): + if isinstance(arg, torch.Tensor): + if device is None: + device = arg.device + elif device != arg.device: + raise ValueError("All tensors must be on the same device.") + return device + + +def get_args_order(func, args, kwargs): + """ + Get the order of the arguments of a function. 
+ """ + names = inspect.getfullargspec(func).args + names_idx = {name: i for i, name in enumerate(names)} + args_order = [] + kwargs_order = {} + for name, arg in kwargs.items(): + if name in names: + kwargs_order[name] = names_idx[name] + names.remove(name) + for i, arg in enumerate(args): + if i < len(names): + args_order.append(names_idx[names[i]]) + return args_order, kwargs_order + + +def broadcast_args(args, kwargs, args_dim, kwargs_dim): + spatial = [] + for arg, arg_dim in zip( + args + list(kwargs.values()), args_dim + list(kwargs_dim.values()) + ): + if isinstance(arg, torch.Tensor) and arg_dim is not None: + arg_spatial = arg.shape[: arg.ndim - arg_dim] + if len(arg_spatial) > len(spatial): + spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial + for j in range(len(arg_spatial)): + if spatial[-j] < arg_spatial[-j]: + if spatial[-j] == 1: + spatial[-j] = arg_spatial[-j] + else: + raise ValueError("Cannot broadcast arguments.") + for i, arg in enumerate(args): + if isinstance(arg, torch.Tensor) and args_dim[i] is not None: + args[i] = torch.broadcast_to( + arg, [*spatial, *arg.shape[arg.ndim - args_dim[i] :]] + ) + for key, arg in kwargs.items(): + if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None: + kwargs[key] = torch.broadcast_to( + arg, [*spatial, *arg.shape[arg.ndim - kwargs_dim[key] :]] + ) + return args, kwargs, spatial + + +@suppress_traceback +def batched(*dims): + """ + Decorator that allows a function to be called with batched arguments. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, device=torch.device("cpu"), **kwargs): + args = list(args) + # get arguments dimensions + args_order, kwargs_order = get_args_order(func, args, kwargs) + args_dim = [dims[i] for i in args_order] + kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()} + # convert to torch tensor + device = get_device(args, kwargs) or device + for i, arg in enumerate(args): + if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None: + args[i] = torch.tensor(arg, device=device) + for key, arg in kwargs.items(): + if ( + isinstance(arg, (Number, list, tuple)) + and kwargs_dim[key] is not None + ): + kwargs[key] = torch.tensor(arg, device=device) + # broadcast arguments + args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim) + for i, (arg, arg_dim) in enumerate(zip(args, args_dim)): + if isinstance(arg, torch.Tensor) and arg_dim is not None: + args[i] = arg.reshape([-1, *arg.shape[arg.ndim - arg_dim :]]) + for key, arg in kwargs.items(): + if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None: + kwargs[key] = arg.reshape( + [-1, *arg.shape[arg.ndim - kwargs_dim[key] :]] + ) + # call function + results = func(*args, **kwargs) + type_results = type(results) + results = list(results) if isinstance(results, (tuple, list)) else [results] + # restore spatial dimensions + for i, result in enumerate(results): + results[i] = result.reshape([*spatial, *result.shape[1:]]) + if type_results is tuple: + results = tuple(results) + elif type_results is list: + results = list(results) + else: + results = results[0] + return results + + return wrapper + + return decorator + + +def image_uv( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + device: torch.device = None, + dtype: torch.dtype = None, +) -> torch.Tensor: + """ + Get image space UV grid, ranging in [0, 1]. 
+ + >>> image_uv(10, 10): + [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]], + [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]], + ... ... ... + [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]] + + Args: + width (int): image width + height (int): image height + + Returns: + torch.Tensor: shape (height, width, 2) + """ + if left is None: + left = 0 + if top is None: + top = 0 + if right is None: + right = width + if bottom is None: + bottom = height + u = torch.linspace( + (left + 0.5) / width, + (right - 0.5) / width, + right - left, + device=device, + dtype=dtype, + ) + v = torch.linspace( + (top + 0.5) / height, + (bottom - 0.5) / height, + bottom - top, + device=device, + dtype=dtype, + ) + u, v = torch.meshgrid(u, v, indexing="xy") + uv = torch.stack([u, v], dim=-1) + + return uv + + +@batched(2, 1, 2, 2) +def unproject_cv( + uv_coord: torch.Tensor, + depth: torch.Tensor = None, + extrinsics: torch.Tensor = None, + intrinsics: torch.Tensor = None, +) -> torch.Tensor: + """ + Unproject uv coordinates to 3D view space following the OpenCV convention + + Args: + uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + depth (torch.Tensor): [..., N] depth value + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix + + Returns: + points (torch.Tensor): [..., N, 3] 3d points + """ + assert intrinsics is not None, "intrinsics matrix is required" + points = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1) + points = points @ torch.inverse(intrinsics).transpose(-2, -1) + if depth is not None: + points = points * depth[..., None] + if extrinsics is not None: + points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + points = (points @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3] + return points + + +def depth_to_points( + depth: torch.Tensor, intrinsics: torch.Tensor, extrinsics: torch.Tensor = None +): + height, width = depth.shape[-2:] + uv = image_uv(width=width, height=height, dtype=depth.dtype, device=depth.device) + pts = unproject_cv( + uv, + depth, + intrinsics=intrinsics[..., None, :, :], + extrinsics=extrinsics[..., None, :, :] if extrinsics is not None else None, + ) + + return pts + + +@batched(0, 0, 0, 0, 0, 0) +def intrinsics_from_focal_center( + fx: Union[float, torch.Tensor], + fy: Union[float, torch.Tensor], + cx: Union[float, torch.Tensor], + cy: Union[float, torch.Tensor], +) -> torch.Tensor: + """ + Get OpenCV intrinsics matrix + + Args: + focal_x (float | torch.Tensor): focal length in x axis + focal_y (float | torch.Tensor): focal length in y axis + cx (float | torch.Tensor): principal point in x axis + cy (float | torch.Tensor): principal point in y axis + + Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + """ + N = fx.shape[0] + ret = torch.zeros((N, 3, 3), dtype=fx.dtype, device=fx.device) + zeros, ones = ( + torch.zeros(N, dtype=fx.dtype, device=fx.device), + torch.ones(N, dtype=fx.dtype, device=fx.device), + ) + ret = torch.stack( + [fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1 + ).unflatten(-1, (3, 3)) + return ret diff --git a/mapanything/models/external/moge/models/v1.py b/mapanything/models/external/moge/models/v1.py new file mode 100644 index 0000000000000000000000000000000000000000..3f2f0d1cee59a94d80f9c5a07bc10e40de36a42f --- /dev/null +++ b/mapanything/models/external/moge/models/v1.py @@ -0,0 +1,597 @@ +import importlib +from numbers import Number 
+from pathlib import Path +from typing import Any, Dict, IO, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.version +from huggingface_hub import hf_hub_download + +from mapanything.models.external.moge.models.utils import ( + depth_to_points, + intrinsics_from_focal_center, + normalized_view_plane_uv, + recover_focal_shift, + wrap_module_with_gradient_checkpointing, +) + + +class ResidualConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int = None, + hidden_channels: int = None, + padding_mode: str = "replicate", + activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu", + norm: Literal["group_norm", "layer_norm"] = "group_norm", + ): + super(ResidualConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + + if activation == "relu": + activation_cls = lambda: nn.ReLU(inplace=True) # noqa + elif activation == "leaky_relu": + activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True) # noqa + elif activation == "silu": + activation_cls = lambda: nn.SiLU(inplace=True) # noqa + elif activation == "elu": + activation_cls = lambda: nn.ELU(inplace=True) # noqa + else: + raise ValueError(f"Unsupported activation function: {activation}") + + self.layers = nn.Sequential( + nn.GroupNorm(1, in_channels), + activation_cls(), + nn.Conv2d( + in_channels, + hidden_channels, + kernel_size=3, + padding=1, + padding_mode=padding_mode, + ), + nn.GroupNorm( + hidden_channels // 32 if norm == "group_norm" else 1, hidden_channels + ), + activation_cls(), + nn.Conv2d( + hidden_channels, + out_channels, + kernel_size=3, + padding=1, + padding_mode=padding_mode, + ), + ) + + self.skip_connection = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x): + skip = self.skip_connection(x) + x = self.layers(x) + x = x + skip + return x + + +class Head(nn.Module): + def __init__( + self, + num_features: int, + dim_in: int, + dim_out: List[int], + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm", + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1, + ): + super().__init__() + + self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=dim_in, + out_channels=dim_proj, + kernel_size=1, + stride=1, + padding=0, + ) + for _ in range(num_features) + ] + ) + + self.upsample_blocks = nn.ModuleList( + [ + nn.Sequential( + self._make_upsampler(in_ch + 2, out_ch), + *( + ResidualConvBlock( + out_ch, + out_ch, + dim_times_res_block_hidden * out_ch, + activation="relu", + norm=res_block_norm, + ) + for _ in range(num_res_blocks) + ), + ) + for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample) + ] + ) + + self.output_block = nn.ModuleList( + [ + self._make_output_block( + dim_upsample[-1] + 2, + dim_out_, + dim_times_res_block_hidden, + last_res_blocks, + last_conv_channels, + last_conv_size, + res_block_norm, + ) + for dim_out_ in dim_out + ] + ) + + def _make_upsampler(self, in_channels: int, out_channels: int): + upsampler = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), + nn.Conv2d( + out_channels, + out_channels, + 
kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + ) + upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1] + return upsampler + + def _make_output_block( + self, + dim_in: int, + dim_out: int, + dim_times_res_block_hidden: int, + last_res_blocks: int, + last_conv_channels: int, + last_conv_size: int, + res_block_norm: Literal["group_norm", "layer_norm"], + ): + return nn.Sequential( + nn.Conv2d( + dim_in, + last_conv_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + *( + ResidualConvBlock( + last_conv_channels, + last_conv_channels, + dim_times_res_block_hidden * last_conv_channels, + activation="relu", + norm=res_block_norm, + ) + for _ in range(last_res_blocks) + ), + nn.ReLU(inplace=True), + nn.Conv2d( + last_conv_channels, + dim_out, + kernel_size=last_conv_size, + stride=1, + padding=last_conv_size // 2, + padding_mode="replicate", + ), + ) + + def forward(self, hidden_states: torch.Tensor, image: torch.Tensor): + img_h, img_w = image.shape[-2:] + patch_h, patch_w = img_h // 14, img_w // 14 + + # Process the hidden states + x = torch.stack( + [ + proj( + feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous() + ) + for proj, (feat, clstoken) in zip(self.projects, hidden_states) + ], + dim=1, + ).sum(dim=1) + + # Upsample stage + # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8) + for i, block in enumerate(self.upsample_blocks): + # UV coordinates is for awareness of image aspect ratio + uv = normalized_view_plane_uv( + width=x.shape[-1], + height=x.shape[-2], + aspect_ratio=img_w / img_h, + dtype=x.dtype, + device=x.device, + ) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + for layer in block: + x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False) + + # (patch_h * 8, patch_w * 8) -> (img_h, img_w) + x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False) + uv = normalized_view_plane_uv( + width=x.shape[-1], + height=x.shape[-2], + aspect_ratio=img_w / img_h, + dtype=x.dtype, + device=x.device, + ) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + + if isinstance(self.output_block, nn.ModuleList): + output = [ + torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) + for block in self.output_block + ] + else: + output = torch.utils.checkpoint.checkpoint( + self.output_block, x, use_reentrant=False + ) + + return output + + +class MoGeModel(nn.Module): + image_mean: torch.Tensor + image_std: torch.Tensor + + def __init__( + self, + encoder: str = "dinov2_vitb14", + intermediate_layers: Union[int, List[int]] = 4, + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + remap_output: Literal[ + False, True, "linear", "sinh", "exp", "sinh_exp" + ] = "linear", + res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm", + num_tokens_range: Tuple[Number, Number] = [1200, 2500], + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1, + mask_threshold: float = 0.5, + **deprecated_kwargs, + ): + super(MoGeModel, self).__init__() + + if deprecated_kwargs: + # Process legacy arguments + if "trained_area_range" in deprecated_kwargs: + num_tokens_range = [ + deprecated_kwargs["trained_area_range"][0] // 14**2, + deprecated_kwargs["trained_area_range"][1] // 14**2, + ] + del 
deprecated_kwargs["trained_area_range"] + # warnings.warn( + # f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}" + # ) + + self.encoder = encoder + self.remap_output = remap_output + self.intermediate_layers = intermediate_layers + self.num_tokens_range = num_tokens_range + self.mask_threshold = mask_threshold + + # NOTE: We have copied the DINOv2 code in torchhub to this repository. + # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues. + hub_loader = getattr( + importlib.import_module( + "mapanything.models.external.dinov2.hub.backbones", __package__ + ), + encoder, + ) + self.backbone = hub_loader(pretrained=False) + dim_feature = self.backbone.blocks[0].attn.qkv.in_features + + self.head = Head( + num_features=( + intermediate_layers + if isinstance(intermediate_layers, int) + else len(intermediate_layers) + ), + dim_in=dim_feature, + dim_out=[3, 1], + dim_proj=dim_proj, + dim_upsample=dim_upsample, + dim_times_res_block_hidden=dim_times_res_block_hidden, + num_res_blocks=num_res_blocks, + res_block_norm=res_block_norm, + last_res_blocks=last_res_blocks, + last_conv_channels=last_conv_channels, + last_conv_size=last_conv_size, + ) + + image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + + self.register_buffer("image_mean", image_mean) + self.register_buffer("image_std", image_std) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, Path, IO[bytes]], + model_kwargs: Optional[Dict[str, Any]] = None, + **hf_kwargs, + ) -> "MoGeModel": + """ + Load a model from a checkpoint file. + + ### Parameters: + - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. + - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. + - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. + + ### Returns: + - A new instance of `MoGe` with the parameters loaded from the checkpoint. 
+ """ + if Path(pretrained_model_name_or_path).exists(): + checkpoint = torch.load( + pretrained_model_name_or_path, map_location="cpu", weights_only=True + ) + else: + cached_checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.pt", + **hf_kwargs, + ) + checkpoint = torch.load( + cached_checkpoint_path, map_location="cpu", weights_only=True + ) + model_config = checkpoint["model_config"] + if model_kwargs is not None: + model_config.update(model_kwargs) + model = cls(**model_config) + model.load_state_dict(checkpoint["model"]) + return model + + def init_weights(self): + "Load the backbone with pretrained dinov2 weights from torch hub" + state_dict = torch.hub.load( + "facebookresearch/dinov2", self.encoder, pretrained=True + ).state_dict() + self.backbone.load_state_dict(state_dict) + + def enable_gradient_checkpointing(self): + for i in range(len(self.backbone.blocks)): + self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing( + self.backbone.blocks[i] + ) + + def _remap_points(self, points: torch.Tensor) -> torch.Tensor: + if self.remap_output == "linear": + pass + elif self.remap_output == "sinh": + points = torch.sinh(points) + elif self.remap_output == "exp": + xy, z = points.split([2, 1], dim=-1) + z = torch.exp(z) + points = torch.cat([xy * z, z], dim=-1) + elif self.remap_output == "sinh_exp": + xy, z = points.split([2, 1], dim=-1) + points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) + else: + raise ValueError(f"Invalid remap output type: {self.remap_output}") + return points + + def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]: + original_height, original_width = image.shape[-2:] + + # Resize to expected resolution defined by num_tokens + resize_factor = ( + (num_tokens * 14**2) / (original_height * original_width) + ) ** 0.5 + resized_width, resized_height = ( + int(original_width * resize_factor), + int(original_height * resize_factor), + ) + image = F.interpolate( + image, + (resized_height, resized_width), + mode="bicubic", + align_corners=False, + antialias=True, + ) + + # Apply image transformation for DINOv2 + image = (image - self.image_mean) / self.image_std + image_14 = F.interpolate( + image, + (resized_height // 14 * 14, resized_width // 14 * 14), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + # Get intermediate layers from the backbone + features = self.backbone.get_intermediate_layers( + image_14, self.intermediate_layers, return_class_token=True + ) + + # Predict points (and mask) + output = self.head(features, image) + points, mask = output + + # Make sure fp32 precision for output + with torch.autocast(device_type=image.device.type, dtype=torch.float32): + # Resize to original resolution + points = F.interpolate( + points, + (original_height, original_width), + mode="bilinear", + align_corners=False, + antialias=False, + ) + mask = F.interpolate( + mask, + (original_height, original_width), + mode="bilinear", + align_corners=False, + antialias=False, + ) + + # Post-process points and mask + points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1) + points = self._remap_points( + points + ) # slightly improves the performance in case of very large output values + + return_dict = {"points": points, "mask": mask} + return return_dict + + # @torch.inference_mode() + def infer( + self, + image: torch.Tensor, + fov_x: Union[Number, torch.Tensor] = None, + resolution_level: int = 9, + num_tokens: int = None, + apply_mask: bool = 
True, + force_projection: bool = True, + use_fp16: bool = True, + ) -> Dict[str, torch.Tensor]: + """ + User-friendly inference function + + ### Parameters + - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)\ + - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None + - `resolution_level`: An integer [0-9] for the resolution level for inference. + The higher, the finer details will be captured, but slower. Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size. + `resolution_level` actually controls `num_tokens`. See `num_tokens` for more details. + - `num_tokens`: number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`. + `resolution_level` will be ignored if `num_tokens` is provided. Default: None + - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True + - `force_projection`: if True, the output point map will be recomputed to match the projection constraint. Default: True + - `use_fp16`: if True, use mixed precision to speed up inference. Default: True + + ### Returns + + A dictionary containing the following keys: + - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3). + - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map. + - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics. + """ + if image.dim() == 3: + omit_batch_dim = True + image = image.unsqueeze(0) + else: + omit_batch_dim = False + image = image.to(dtype=self.dtype, device=self.device) + + original_height, original_width = image.shape[-2:] + aspect_ratio = original_width / original_height + + if num_tokens is None: + min_tokens, max_tokens = self.num_tokens_range + num_tokens = int( + min_tokens + (resolution_level / 9) * (max_tokens - min_tokens) + ) + + with torch.autocast( + device_type=self.device.type, + dtype=torch.float16, + enabled=use_fp16 and self.dtype != torch.float16, + ): + output = self.forward(image, num_tokens) + points, mask = output["points"], output["mask"] + + # Always process the output in fp32 precision + with torch.autocast(device_type=self.device.type, dtype=torch.float32): + points, mask, fov_x = map( + lambda x: x.float() if isinstance(x, torch.Tensor) else x, + [points, mask, fov_x], + ) + + mask_binary = mask > self.mask_threshold + + # Get camera-space point map. 
(Focal here is the focal length relative to half the image diagonal) + if fov_x is None: + focal, shift = recover_focal_shift(points, mask_binary) + else: + focal = ( + aspect_ratio + / (1 + aspect_ratio**2) ** 0.5 + / torch.tan( + torch.deg2rad( + torch.as_tensor( + fov_x, device=points.device, dtype=points.dtype + ) + / 2 + ) + ) + ) + if focal.ndim == 0: + focal = focal[None].expand(points.shape[0]) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + fx = focal / 2 * (1 + aspect_ratio**2) ** 0.5 / aspect_ratio + fy = focal / 2 * (1 + aspect_ratio**2) ** 0.5 + intrinsics = intrinsics_from_focal_center(fx, fy, 0.5, 0.5) + depth = points[..., 2] + shift[..., None, None] + + # If projection constraint is forced, recompute the point map using the actual depth map + if force_projection: + points = depth_to_points(depth, intrinsics=intrinsics) + else: + points = ( + points + + torch.stack( + [torch.zeros_like(shift), torch.zeros_like(shift), shift], + dim=-1, + )[..., None, None, :] + ) + + # Apply mask if needed + if apply_mask: + points = torch.where(mask_binary[..., None], points, torch.inf) + depth = torch.where(mask_binary, depth, torch.inf) + + return_dict = { + "points": points, + "intrinsics": intrinsics, + "depth": depth, + "mask": mask_binary, + } + + if omit_batch_dim: + return_dict = {k: v.squeeze(0) for k, v in return_dict.items()} + + return return_dict diff --git a/mapanything/models/external/moge/models/v2.py b/mapanything/models/external/moge/models/v2.py new file mode 100644 index 0000000000000000000000000000000000000000..521295e4ea208a73b0f2ea57e95d87fb90972aa6 --- /dev/null +++ b/mapanything/models/external/moge/models/v2.py @@ -0,0 +1,385 @@ +import warnings +from numbers import Number +from pathlib import Path +from typing import Any, Dict, IO, List, Literal, Optional, Union + +import torch +import torch.amp +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.version +from huggingface_hub import hf_hub_download + +from mapanything.models.external.moge.models.modules import ( + ConvStack, + DINOv2Encoder, + MLP, +) +from mapanything.models.external.moge.models.utils import ( + depth_to_points, + intrinsics_from_focal_center, + normalized_view_plane_uv, + recover_focal_shift, +) + + +class MoGeModel(nn.Module): + encoder: DINOv2Encoder + neck: ConvStack + points_head: ConvStack + mask_head: ConvStack + scale_head: MLP + onnx_compatible_mode: bool + + def __init__( + self, + encoder: Dict[str, Any], + neck: Dict[str, Any], + points_head: Dict[str, Any] = None, + mask_head: Dict[str, Any] = None, + normal_head: Dict[str, Any] = None, + scale_head: Dict[str, Any] = None, + remap_output: Literal["linear", "sinh", "exp", "sinh_exp"] = "linear", + num_tokens_range: List[int] = [1200, 3600], + **deprecated_kwargs, + ): + super(MoGeModel, self).__init__() + if deprecated_kwargs: + warnings.warn( + f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}" + ) + + self.remap_output = remap_output + self.num_tokens_range = num_tokens_range + + self.encoder = DINOv2Encoder(**encoder) + self.neck = ConvStack(**neck) + if points_head is not None: + self.points_head = ConvStack(**points_head) + if mask_head is not None: + self.mask_head = ConvStack(**mask_head) + if normal_head is not None: + self.normal_head = ConvStack(**normal_head) + if scale_head is not None: + self.scale_head = MLP(**scale_head) + + @property + def device(self) -> torch.device: + return 
next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @property + def onnx_compatible_mode(self) -> bool: + return getattr(self, "_onnx_compatible_mode", False) + + @onnx_compatible_mode.setter + def onnx_compatible_mode(self, value: bool): + self._onnx_compatible_mode = value + self.encoder.onnx_compatible_mode = value + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, Path, IO[bytes]], + model_kwargs: Optional[Dict[str, Any]] = None, + **hf_kwargs, + ) -> "MoGeModel": + """ + Load a model from a checkpoint file. + + ### Parameters: + - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. + - `compiled` + - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. + - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. + + ### Returns: + - A new instance of `MoGe` with the parameters loaded from the checkpoint. + """ + if Path(pretrained_model_name_or_path).exists(): + checkpoint_path = pretrained_model_name_or_path + else: + checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.pt", + **hf_kwargs, + ) + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + + model_config = checkpoint["model_config"] + if model_kwargs is not None: + model_config.update(model_kwargs) + model = cls(**model_config) + model.load_state_dict(checkpoint["model"], strict=False) + + return model + + def init_weights(self): + self.encoder.init_weights() + + def enable_gradient_checkpointing(self): + self.encoder.enable_gradient_checkpointing() + self.neck.enable_gradient_checkpointing() + for head in ["points_head", "normal_head", "mask_head"]: + if hasattr(self, head): + getattr(self, head).enable_gradient_checkpointing() + + def enable_pytorch_native_sdpa(self): + self.encoder.enable_pytorch_native_sdpa() + + def _remap_points(self, points: torch.Tensor) -> torch.Tensor: + if self.remap_output == "linear": + pass + elif self.remap_output == "sinh": + points = torch.sinh(points) + elif self.remap_output == "exp": + xy, z = points.split([2, 1], dim=-1) + z = torch.exp(z) + points = torch.cat([xy * z, z], dim=-1) + elif self.remap_output == "sinh_exp": + xy, z = points.split([2, 1], dim=-1) + points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) + else: + raise ValueError(f"Invalid remap output type: {self.remap_output}") + return points + + def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]: + batch_size, _, img_h, img_w = image.shape + device, dtype = image.device, image.dtype + + aspect_ratio = img_w / img_h + base_h, base_w = ( + int((num_tokens / aspect_ratio) ** 0.5), + int((num_tokens * aspect_ratio) ** 0.5), + ) + num_tokens = base_h * base_w + + # Backbones encoding + features, cls_token = self.encoder( + image, base_h, base_w, return_class_token=True + ) + features = [features, None, None, None, None] + + # Concat UVs for aspect ratio input + for level in range(5): + uv = normalized_view_plane_uv( + width=base_w * 2**level, + height=base_h * 2**level, + aspect_ratio=aspect_ratio, + dtype=dtype, + device=device, + ) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1) + if features[level] is None: + features[level] = uv + else: + features[level] = torch.concat([features[level], uv], dim=1) + + # Shared neck + 
features = self.neck(features) + + # Heads decoding + points, normal, mask = ( + getattr(self, head)(features)[-1] if hasattr(self, head) else None + for head in ["points_head", "normal_head", "mask_head"] + ) + metric_scale = ( + self.scale_head(cls_token) if hasattr(self, "scale_head") else None + ) + + # Resize + points, normal, mask = ( + ( + F.interpolate( + v, + (img_h, img_w), + mode="bilinear", + align_corners=False, + antialias=False, + ) + if v is not None + else None + ) + for v in [points, normal, mask] + ) + + # Remap output + if points is not None: + points = points.permute(0, 2, 3, 1) + points = self._remap_points( + points + ) # slightly improves the performance in case of very large output values + if normal is not None: + normal = normal.permute(0, 2, 3, 1) + normal = F.normalize(normal, dim=-1) + if mask is not None: + mask = mask.squeeze(1).sigmoid() + if metric_scale is not None: + metric_scale = metric_scale.squeeze(1).exp() + + return_dict = { + "points": points, + "normal": normal, + "mask": mask, + "metric_scale": metric_scale, + } + return_dict = {k: v for k, v in return_dict.items() if v is not None} + + return return_dict + + # @torch.inference_mode() + def infer( + self, + image: torch.Tensor, + num_tokens: int = None, + resolution_level: int = 9, + force_projection: bool = True, + apply_mask: Literal[False, True, "blend"] = True, + fov_x: Optional[Union[Number, torch.Tensor]] = None, + use_fp16: bool = True, + ) -> Dict[str, torch.Tensor]: + """ + User-friendly inference function + + ### Parameters + - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W) + - `num_tokens`: the number of base ViT tokens to use for inference, `'least'` or `'most'` or an integer. Suggested range: 1200 ~ 2500. + More tokens will result in significantly higher accuracy and finer details, but slower inference time. Default: `'most'`. + - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True + - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True + - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None + - `use_fp16`: if True, use mixed precision to speed up inference. Default: True + + ### Returns + + A dictionary containing the following keys: + - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3). + - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map. + - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics. 
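+
+        ### Example (illustrative sketch; the repo id and the image preprocessing are
+        assumptions, not taken from this file; `rgb` is a hypothetical HxWx3 uint8 array):
+            >>> model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl").to("cuda").eval()
+            >>> image = torch.from_numpy(rgb).permute(2, 0, 1).float().cuda() / 255  # (3, H, W), RGB in [0, 1]
+            >>> output = model.infer(image)
+            >>> points, depth, intrinsics = output["points"], output["depth"], output["intrinsics"]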
+ """ + if image.dim() == 3: + omit_batch_dim = True + image = image.unsqueeze(0) + else: + omit_batch_dim = False + image = image.to(dtype=self.dtype, device=self.device) + + original_height, original_width = image.shape[-2:] + aspect_ratio = original_width / original_height + + # Determine the number of base tokens to use + if num_tokens is None: + min_tokens, max_tokens = self.num_tokens_range + num_tokens = int( + min_tokens + (resolution_level / 9) * (max_tokens - min_tokens) + ) + + # Forward pass + with torch.autocast( + device_type=self.device.type, + dtype=torch.float16, + enabled=use_fp16 and self.dtype != torch.float16, + ): + output = self.forward(image, num_tokens=num_tokens) + points, normal, mask, metric_scale = ( + output.get(k, None) for k in ["points", "normal", "mask", "metric_scale"] + ) + + # Always process the output in fp32 precision + points, normal, mask, metric_scale, fov_x = map( + lambda x: x.float() if isinstance(x, torch.Tensor) else x, + [points, normal, mask, metric_scale, fov_x], + ) + with torch.autocast(device_type=self.device.type, dtype=torch.float32): + if mask is not None: + mask_binary = mask > 0.5 + else: + mask_binary = None + + if points is not None: + # Convert affine point map to camera-space. Recover depth and intrinsics from point map. + # NOTE: Focal here is the focal length relative to half the image diagonal + if fov_x is None: + # Recover focal and shift from predicted point map + focal, shift = recover_focal_shift(points, mask_binary) + else: + # Focal is known, recover shift only + focal = ( + aspect_ratio + / (1 + aspect_ratio**2) ** 0.5 + / torch.tan( + torch.deg2rad( + torch.as_tensor( + fov_x, device=points.device, dtype=points.dtype + ) + / 2 + ) + ) + ) + if focal.ndim == 0: + focal = focal[None].expand(points.shape[0]) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + fx, fy = ( + focal / 2 * (1 + aspect_ratio**2) ** 0.5 / aspect_ratio, + focal / 2 * (1 + aspect_ratio**2) ** 0.5, + ) + intrinsics = intrinsics_from_focal_center(fx, fy, 0.5, 0.5) + points[..., 2] += shift[..., None, None] + if mask_binary is not None: + mask_binary &= ( + points[..., 2] > 0 + ) # in case depth is contains negative values (which should never happen in practice) + depth = points[..., 2].clone() + else: + depth, intrinsics = None, None + + # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics + if force_projection and depth is not None: + points = depth_to_points(depth, intrinsics=intrinsics) + + # Apply metric scale + if metric_scale is not None: + if points is not None: + points *= metric_scale[:, None, None, None] + if depth is not None: + depth *= metric_scale[:, None, None] + + # Apply mask + if apply_mask and mask_binary is not None: + points = ( + torch.where(mask_binary[..., None], points, torch.inf) + if points is not None + else None + ) + depth = ( + torch.where(mask_binary, depth, torch.inf) + if depth is not None + else None + ) + normal = ( + torch.where( + mask_binary[..., None], normal, torch.zeros_like(normal) + ) + if normal is not None + else None + ) + + return_dict = { + "points": points, + "intrinsics": intrinsics, + "depth": depth, + "mask": mask_binary, + "normal": normal, + } + return_dict = {k: v for k, v in return_dict.items() if v is not None} + + if omit_batch_dim: + return_dict = {k: v.squeeze(0) for k, v in return_dict.items()} + + return return_dict diff --git a/mapanything/models/external/must3r/__init__.py 
b/mapanything/models/external/must3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..022036a34727844f4886b382740747b7e509dc6e --- /dev/null +++ b/mapanything/models/external/must3r/__init__.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Inference wrapper for MUSt3R +""" + +import datetime +import os + +import numpy as np +import torch +from dust3r.viz import rgb +from must3r.demo.inference import SceneState +from must3r.engine.inference import inference_multi_ar, postprocess +from must3r.model import get_pointmaps_activation, load_model + +from mapanything.models.external.vggt.utils.rotation import mat_to_quat + + +def must3r_inference( + views, + filelist, + model, + retrieval, + device, + amp, + num_mem_images, + max_bs, + init_num_images=2, + batch_num_views=1, + render_once=False, + is_sequence=False, + viser_server=None, + num_refinements_iterations=2, + verbose=True, +): + if amp == "fp16": + dtype = torch.float16 + elif amp == "bf16": + assert torch.cuda.is_bf16_supported() + dtype = torch.bfloat16 + else: + assert not amp + dtype = torch.float32 + + max_bs = None if max_bs == 0 else max_bs + encoder, decoder = model + pointmaps_activation = get_pointmaps_activation(decoder, verbose=verbose) + + def post_process_function(x): + return postprocess( + x, pointmaps_activation=pointmaps_activation, compute_cam=True + ) + + if verbose: + print("loading images") + time_start = datetime.datetime.now() + nimgs = len(views) + + ellapsed = datetime.datetime.now() - time_start + if verbose: + print(f"loaded in {ellapsed}") + print("running inference") + time_start = datetime.datetime.now() + if viser_server is not None: + viser_server.reset(nimgs) + + imgs = [b["img"].to("cpu") for b in views] + true_shape = [torch.from_numpy(b["true_shape"]).to("cpu") for b in views] + true_shape = torch.stack(true_shape, dim=0) + nimgs = true_shape.shape[0] + + # Use all images as keyframes + keyframes = np.linspace(0, len(imgs) - 1, num_mem_images, dtype=int).tolist() + encoder_precomputed_features = None + + not_keyframes = sorted(set(range(nimgs)).difference(set(keyframes))) + assert (len(keyframes) + len(not_keyframes)) == nimgs + # reorder images + views = [views[i] for i in keyframes] + [views[i] for i in not_keyframes] + imgs = [b["img"].to(device) for b in views] + true_shape = [torch.from_numpy(b["true_shape"]).to(device) for b in views] + filenames = [filelist[i] for i in keyframes + not_keyframes] + img_ids = [torch.tensor(v) for v in keyframes + not_keyframes] + + if encoder_precomputed_features is not None: + x_start, pos_start = encoder_precomputed_features + x = [x_start[i] for i in keyframes] + [x_start[i] for i in not_keyframes] + pos = [pos_start[i] for i in keyframes] + [pos_start[i] for i in not_keyframes] + encoder_precomputed_features = (x, pos) + + mem_batches = [init_num_images] + while (sum_b := sum(mem_batches)) != max(num_mem_images, init_num_images): + size_b = min(batch_num_views, num_mem_images - sum_b) + mem_batches.append(size_b) + + if render_once: + to_render = list(range(num_mem_images, nimgs)) + else: + to_render = None + + with torch.autocast("cuda", dtype=dtype): + x_out_0, x_out = inference_multi_ar( + encoder, + decoder, + imgs, + img_ids, + true_shape, + mem_batches, + max_bs=max_bs, + verbose=verbose, + to_render=to_render, + 
encoder_precomputed_features=encoder_precomputed_features, + device=device, + preserve_gpu_mem=True, + post_process_function=post_process_function, + viser_server=viser_server, + num_refinements_iterations=num_refinements_iterations, + ) + if to_render is not None: + x_out = x_out_0 + x_out + + ellapsed = datetime.datetime.now() - time_start + if verbose: + print(f"inference in {ellapsed}") + try: + print(str(int(torch.cuda.max_memory_reserved(device) / (1024**2))) + " MB") + except Exception: + pass + + if viser_server is not None: + viser_server.reset_cam_visility() + viser_server.send_message("Finished") + + if verbose: + print("preparing pointcloud") + time_start = datetime.datetime.now() + focals = [] + cams2world = [] + for i in range(nimgs): + focals.append(float(x_out[i]["focal"].cpu())) + cams2world.append(x_out[i]["c2w"].cpu()) + + # x_out to cpu + for i in range(len(x_out)): + for k in x_out[i].keys(): + x_out[i][k] = x_out[i][k].cpu() + + rgbimg = [rgb(imgs[i], true_shape[i]) for i in range(nimgs)] + scene = SceneState(x_out, rgbimg, true_shape, focals, cams2world, filenames) + + ellapsed = datetime.datetime.now() - time_start + if verbose: + print(f"pointcloud prepared in {ellapsed}") + + return scene + + +class MUSt3RWrapper(torch.nn.Module): + def __init__( + self, + name, + ckpt_path, + retrieval_ckpt_path, + img_size=512, + amp="bf16", + max_bs=1, + **kwargs, + ): + super().__init__() + self.name = name + self.ckpt_path = ckpt_path + self.retrieval_ckpt_path = retrieval_ckpt_path + self.amp = amp + self.max_bs = max_bs + + # Init the model and load the checkpoint + self.model = load_model(self.ckpt_path, img_size=512) + + def forward(self, views): + """ + Forward pass wrapper for MUSt3R. + + Assumption: + - The batch size of input views is 1. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys, where B is the batch size and is 1: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["dust3r"] + "label" (list): ["scene_name"] + "instance" (list): ["image_name"] + + Returns: + List[dict]: A list containing the final outputs for the input views. + """ + # Check the batch size of input views + batch_size_per_view, _, height, width = views[0]["img"].shape + device = views[0]["img"].device + num_views = len(views) + assert batch_size_per_view == 1, ( + f"Batch size of input views should be 1, but got {batch_size_per_view}." 
+ ) + + # Check the data norm type + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "dust3r", ( + "MUSt3R expects a normalized image with the DUSt3R normalization scheme applied" + ) + + # Convert the input views to the expected input format + images = [] + image_paths = [] + for view in views: + images.append( + dict( + img=view["img"][0].cpu(), + idx=len(images), + instance=str(len(images)), + true_shape=np.int32([view["img"].shape[-2], view["img"].shape[-1]]), + ) + ) + view_name = os.path.join(view["label"][0], view["instance"][0]) + image_paths.append(view_name) + + # Run MUSt3R inference + scene = must3r_inference( + images, + image_paths, + self.model, + self.retrieval_ckpt_path, + device, + self.amp, + num_views, + self.max_bs, + verbose=False, + ) + + # Make sure scene is not None + if scene is None: + raise RuntimeError("MUSt3R failed.") + + # Get the predictions + predictions = scene.x_out + + # Convert the output to the MapAnything format + with torch.autocast("cuda", enabled=False): + res = [] + for view_idx in range(num_views): + # Get the current view predictions + curr_view_prediction = predictions[view_idx] + curr_view_conf = curr_view_prediction["conf"] + curr_view_pose = curr_view_prediction["c2w"].unsqueeze(0) + + # Convert the pose to quaternions and translation + curr_view_cam_translations = curr_view_pose[..., :3, 3] + curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3]) + + # Get the camera frame pointmaps + curr_view_pts3d_cam = curr_view_prediction["pts3d_local"].unsqueeze(0) + + # Get the depth along ray and ray directions + curr_view_depth_along_ray = torch.norm( + curr_view_pts3d_cam, dim=-1, keepdim=True + ) + curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray + + # Get the pointmaps + curr_view_pts3d = curr_view_prediction["pts3d"].unsqueeze(0) + + # Append the outputs to the result list + res.append( + { + "pts3d": curr_view_pts3d.to(device), + "pts3d_cam": curr_view_pts3d_cam.to(device), + "ray_directions": curr_view_ray_dirs.to(device), + "depth_along_ray": curr_view_depth_along_ray.to(device), + "cam_trans": curr_view_cam_translations.to(device), + "cam_quats": curr_view_cam_quats.to(device), + "conf": curr_view_conf.to(device), + } + ) + + return res diff --git a/mapanything/models/external/pi3/__init__.py b/mapanything/models/external/pi3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3814640daf72b98e699f70a250215cd34039dadf --- /dev/null +++ b/mapanything/models/external/pi3/__init__.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +""" +Inference wrapper for Pi3 +""" + +import torch + +from mapanything.models.external.pi3.models.pi3 import Pi3 +from mapanything.models.external.vggt.utils.rotation import mat_to_quat + + +class Pi3Wrapper(torch.nn.Module): + def __init__( + self, + name, + torch_hub_force_reload, + load_pretrained_weights=True, + pos_type="rope100", + decoder_size="large", + ): + super().__init__() + self.name = name + self.torch_hub_force_reload = torch_hub_force_reload + + if load_pretrained_weights: + # Load pre-trained weights + if not torch_hub_force_reload: + # Initialize the Pi3 model from huggingface hub cache + print("Loading Pi3 from huggingface cache ...") + self.model = Pi3.from_pretrained( + "yyfz233/Pi3", + ) + else: + # Initialize the Pi3 model + self.model = Pi3.from_pretrained("yyfz233/Pi3", force_download=True) + else: + # Load the Pi3 class + self.model = Pi3( + pos_type=pos_type, + decoder_size=decoder_size, + ) + + # Get the dtype for Pi3 inference + # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) + self.dtype = ( + torch.bfloat16 + if torch.cuda.get_device_capability()[0] >= 8 + else torch.float16 + ) + + def forward(self, views): + """ + Forward pass wrapper for Pi3 + + Assumption: + - All the input views have the same image shape. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["identity"] + + Returns: + List[dict]: A list containing the final outputs for all N views. + """ + # Get input shape of the images, number of views, and batch size per view + batch_size_per_view, _, height, width = views[0]["img"].shape + num_views = len(views) + + # Check the data norm type + # Pi3 expects a normalized image but without the DINOv2 mean and std applied ("identity") + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "identity", ( + "Pi3 expects a normalized image but without the DINOv2 mean and std applied" + ) + + # Concatenate the images to create a single (B, V, C, H, W) tensor + img_list = [view["img"] for view in views] + images = torch.stack(img_list, dim=1) + + # Run the Pi3 aggregator + with torch.autocast("cuda", dtype=self.dtype): + results = self.model(images) + + # Need high precision for transformations + with torch.autocast("cuda", enabled=False): + # Convert the output to MapAnything format + res = [] + for view_idx in range(num_views): + # Get the extrinsics + curr_view_extrinsic = results["camera_poses"][:, view_idx, ...] + curr_view_cam_translations = curr_view_extrinsic[..., :3, 3] + curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3]) + + # Get the depth along ray, ray directions, local point cloud & global point cloud + curr_view_pts3d_cam = results["local_points"][:, view_idx, ...] + curr_view_depth_along_ray = torch.norm( + curr_view_pts3d_cam, dim=-1, keepdim=True + ) + curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray + curr_view_pts3d = results["points"][:, view_idx, ...] + + # Get the confidence + curr_view_confidence = results["conf"][:, view_idx, ...] 
+ + # Append the outputs to the result list + res.append( + { + "pts3d": curr_view_pts3d, + "pts3d_cam": curr_view_pts3d_cam, + "ray_directions": curr_view_ray_dirs, + "depth_along_ray": curr_view_depth_along_ray, + "cam_trans": curr_view_cam_translations, + "cam_quats": curr_view_cam_quats, + "conf": curr_view_confidence, + } + ) + + return res diff --git a/mapanything/models/external/pi3/layers/__init__.py b/mapanything/models/external/pi3/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/models/external/pi3/layers/attention.py b/mapanything/models/external/pi3/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..28ca9413f3aa61bf55c88c29e734b230844f6bb6 --- /dev/null +++ b/mapanything/models/external/pi3/layers/attention.py @@ -0,0 +1,429 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + + +import os + +import torch +from torch import nn, Tensor +from torch.nn.attention import SDPBackend +from torch.nn.functional import scaled_dot_product_attention + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:, :, i] for i in range(3)] + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class FlashAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 
3, self.num_heads, C // self.num_heads) + .transpose(1, 3) + ) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:, :, i] for i in range(3)] + + if q.dtype == torch.bfloat16: + with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + x = scaled_dot_product_attention(q, k, v) + else: + with nn.attention.sdpa_kernel( + [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION] + ): + x = scaled_dot_product_attention(q, k, v) + + x = x.transpose(1, 2).reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +""" +Following is written by GPT-4o +""" + + +class CrossAttentionRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + qk_norm: bool = False, + norm_layer: nn.Module = nn.LayerNorm, + rope=None, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + # Separate projection layers for query, key, and value + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) + + self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.rope = rope + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias=None, + qpos=None, + kpos=None, + ) -> Tensor: + """ + Args: + query: Tensor of shape (B, N, C), input query + key: Tensor of shape (B, M, C), input key + value: Tensor of shape (B, M, C), input value + attn_bias: Optional tensor for attention bias + Returns: + Tensor of shape (B, N, C), output of cross-attention + """ + B, N, C = query.shape + _, M, _ = key.shape + + # Project query, key, and value + q = ( + self.q_proj(query) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + k = ( + self.k_proj(key) + .reshape(B, M, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + v = ( + self.v_proj(value) + .reshape(B, M, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, qpos) + k = self.rope(k, kpos) + + # Scale query + q = q * self.scale + + # Compute attention scores + attn = q @ k.transpose(-2, -1) # (B, num_heads, N, M) + if attn_bias is not None: + attn = attn + attn_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # Compute attention output + x = (attn @ v).transpose(1, 2).reshape(B, N, C) # (B, N, C) + + # Final projection + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffCrossAttentionRope(CrossAttentionRope): + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias=None, + qpos=None, + kpos=None, + ) -> Tensor: + """ + Args: + query: Tensor of shape (B, N, C), input query + key: Tensor of shape (B, M, C), input key + value: Tensor of shape (B, M, C), input value + attn_bias: Optional tensor for attention bias + Returns: + Tensor of shape (B, N, C), output of cross-attention + """ + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(query, key, value, attn_bias) + + B, N, C = query.shape + _, M, _ = key.shape + + # 
Project query, key, and value + q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads) + k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads) + v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, qpos) + k = self.rope(k, kpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + # Compute memory-efficient attention + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape(B, N, C) + + # Final projection + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class AttentionRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + qk_norm: bool = False, + norm_layer: nn.Module = nn.LayerNorm, + rope=None, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + + self.rope = rope + + def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttentionRope(AttentionRope): + def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + qkv = qkv.transpose(1, 3) + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:, :, i] for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(frame_num, 261, frame_num, 261).mean(dim=[1, 3]).sum(1) # for frame attention matrix + # global_valid_id = torch.where(score_matrix > 0) + # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class FlashAttentionRope(AttentionRope): + def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .transpose(1, 3) + ) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:, :, i] 
for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + if q.dtype == torch.bfloat16: + with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + x = scaled_dot_product_attention(q, k, v) + else: + with nn.attention.sdpa_kernel( + [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION] + ): + x = scaled_dot_product_attention(q, k, v) + + x = x.transpose(1, 2).reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def get_attn_score(blk_class, x, frame_num, token_length, xpos=None): + x = blk_class.norm1(x) + + B, N, C = x.shape + qkv = blk_class.attn.qkv(x).reshape( + B, N, 3, blk_class.attn.num_heads, C // blk_class.attn.num_heads + ) + + qkv = qkv.transpose(1, 3) + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:, :, i] for i in range(3)] + q, k = blk_class.attn.q_norm(q).to(v.dtype), blk_class.attn.k_norm(k).to(v.dtype) + + if blk_class.attn.rope is not None: + q = blk_class.attn.rope(q, xpos) + k = blk_class.attn.rope(k, xpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + score = ( + ( + q.permute(0, 2, 1, 3) + * blk_class.attn.scale + @ k.permute(0, 2, 1, 3).transpose(-2, -1) + ) + .sum(dim=1) + .reshape(B, frame_num, token_length, frame_num, token_length) + .mean(dim=[2, 4]) + .sum(-1) + ) + + return score diff --git a/mapanything/models/external/pi3/layers/block.py b/mapanything/models/external/pi3/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..1eac717874a6b10dee909c88c16faa8e23059525 --- /dev/null +++ b/mapanything/models/external/pi3/layers/block.py @@ -0,0 +1,448 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
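
A minimal standalone sketch (not part of the patch) of the kernel dispatch used by FlashAttentionRope above: scaled_dot_product_attention is routed to the flash backend only for bfloat16 inputs, and to the math/memory-efficient backends otherwise. Assumes PyTorch >= 2.3 for torch.nn.attention.sdpa_kernel; shapes are made up.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

B, H, N, D = 2, 8, 196, 64               # batch, heads, tokens, head dim
q = k = v = torch.randn(B, H, N, D)

if q.dtype == torch.bfloat16:
    # flash attention kernels require half-precision inputs
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        x = F.scaled_dot_product_attention(q, k, v)
else:
    # fall back to the math / memory-efficient kernels otherwise
    with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
        x = F.scaled_dot_product_attention(q, k, v)

print(x.shape)                           # torch.Size([2, 8, 196, 64])
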
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import os +from typing import Any, Callable, Dict, List, Tuple + +import torch +from torch import nn, Tensor + +from mapanything.models.external.dinov2.layers.drop_path import DropPath +from mapanything.models.external.dinov2.layers.layer_scale import LayerScale +from mapanything.models.external.dinov2.layers.mlp import Mlp +from mapanything.models.external.pi3.layers.attention import ( + Attention, + CrossAttentionRope, + MemEffAttention, +) + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, index_select_cat, scaled_index_add + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - 
sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + else: + x_plus_residual = scaled_index_add( + x, + brange, + residual.to(dtype=x.dtype), + scaling=scaling_vector, + alpha=residual_scale_factor, + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + 1, -1, x_list[0].shape[-1] + ) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + 
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma + if isinstance(self.ls1, LayerScale) + else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma + if isinstance(self.ls1, LayerScale) + else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError + + +class BlockRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + rope=None, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + rope=rope, + ) + + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, xpos=None) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x), xpos=xpos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + 
self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +class CrossBlockRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + cross_attn_class: Callable[..., nn.Module] = CrossAttentionRope, + ffn_layer: Callable[..., nn.Module] = Mlp, + init_values=None, + qk_norm: bool = False, + rope=None, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + rope=rope, + qk_norm=qk_norm, + ) + + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.ls_y = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.norm2 = norm_layer(dim) + self.norm_y = norm_layer(dim) + self.cross_attn = cross_attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + rope=rope, + qk_norm=qk_norm, + ) + + self.norm3 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + bias=ffn_bias, + ) + + def forward(self, x: Tensor, y: Tensor, xpos=None, ypos=None) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x), xpos=xpos)) + + def cross_attn_residual_func(x: Tensor, y: Tensor) -> Tensor: + return self.ls_y(self.cross_attn(self.norm2(x), y, y, qpos=xpos, kpos=ypos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm3(x))) + + x = x + attn_residual_func(x) + y_ = self.norm_y(y) + x = x + cross_attn_residual_func(x, y_) + x = x + ffn_residual_func(x) + + return x diff --git a/mapanything/models/external/pi3/layers/camera_head.py b/mapanything/models/external/pi3/layers/camera_head.py new file mode 100644 index 0000000000000000000000000000000000000000..37e12354b2e7a5a799c883b12710d6528f39241b --- /dev/null +++ b/mapanything/models/external/pi3/layers/camera_head.py @@ -0,0 +1,106 @@ +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172' +class ResConvBlock(nn.Module): + """ + 1x1 convolution residual block + """ + + def __init__(self, in_channels, out_channels): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.head_skip = ( + nn.Identity() + if self.in_channels == self.out_channels + else nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0) + ) + # self.res_conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0) + # self.res_conv2 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0) + # self.res_conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0) + + # change 1x1 convolution to linear + self.res_conv1 = nn.Linear(self.in_channels, self.out_channels) + self.res_conv2 = nn.Linear(self.out_channels, 
self.out_channels) + self.res_conv3 = nn.Linear(self.out_channels, self.out_channels) + + def forward(self, res): + x = F.relu(self.res_conv1(res)) + x = F.relu(self.res_conv2(x)) + x = F.relu(self.res_conv3(x)) + res = self.head_skip(res) + x + return res + + +class CameraHead(nn.Module): + def __init__(self, dim=512): + super().__init__() + output_dim = dim + self.res_conv = nn.ModuleList( + [deepcopy(ResConvBlock(output_dim, output_dim)) for _ in range(2)] + ) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.more_mlps = nn.Sequential( + nn.Linear(output_dim, output_dim), + nn.ReLU(), + nn.Linear(output_dim, output_dim), + nn.ReLU(), + ) + self.fc_t = nn.Linear(output_dim, 3) + self.fc_rot = nn.Linear(output_dim, 9) + + def forward(self, feat, patch_h, patch_w): + BN, hw, c = feat.shape + + for i in range(2): + feat = self.res_conv[i](feat) + + # feat = self.avgpool(feat) + feat = self.avgpool( + feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous() + ) ########## + feat = feat.view(feat.size(0), -1) + + feat = self.more_mlps(feat) # [B, D_] + with torch.amp.autocast(device_type="cuda", enabled=False): + out_t = self.fc_t(feat.float()) # [B,3] + out_r = self.fc_rot(feat.float()) # [B,9] + pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device) + + return pose + + def convert_pose_to_4x4(self, B, out_r, out_t, device): + out_r = self.svd_orthogonalize(out_r) # [N,3,3] + pose = torch.zeros((B, 4, 4), device=device) + pose[:, :3, :3] = out_r + pose[:, :3, 3] = out_t + pose[:, 3, 3] = 1.0 + return pose + + def svd_orthogonalize(self, m): + """Convert 9D representation to SO(3) using SVD orthogonalization. + + Args: + m: [BATCH, 3, 3] 3x3 matrices. + + Returns: + [BATCH, 3, 3] SO(3) rotation matrices. + """ + if m.dim() < 3: + m = m.reshape((-1, 3, 3)) + m_transpose = torch.transpose( + torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2 + ) + u, s, v = torch.svd(m_transpose) + det = torch.det(torch.matmul(v, u.transpose(-2, -1))) + # Check orientation reflection. + r = torch.matmul( + torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2), + u.transpose(-2, -1), + ) + return r diff --git a/mapanything/models/external/pi3/layers/pos_embed.py b/mapanything/models/external/pi3/layers/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd500f4b353646c3ba8cafe2598186e8c97031e --- /dev/null +++ b/mapanything/models/external/pi3/layers/pos_embed.py @@ -0,0 +1,190 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
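
A standalone sketch of the idea behind CameraHead.svd_orthogonalize above: project an arbitrary 3x3 matrix (the 9D output of fc_rot) onto SO(3) via SVD, flipping a column when the factorization yields a reflection. This is the textbook variant for illustration, not a drop-in replacement for the method above.

import torch

def project_to_so3(m: torch.Tensor) -> torch.Tensor:
    """Map a batch of 9D vectors / 3x3 matrices to the nearest rotation."""
    m = m.reshape(-1, 3, 3)
    u, _, vt = torch.linalg.svd(m)
    # if u @ vt is a reflection (det = -1), flip the last column of u
    det = torch.det(u @ vt)
    u = torch.cat([u[..., :-1], u[..., -1:] * det.view(-1, 1, 1)], dim=-1)
    return u @ vt

r = project_to_so3(torch.randn(4, 9))
eye = torch.eye(3).expand(4, 3, 3)
assert torch.allclose(r @ r.transpose(-1, -2), eye, atol=1e-5)  # orthonormal
assert torch.allclose(torch.det(r), torch.ones(4), atol=1e-5)   # det = +1
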
+ + +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + + +import numpy as np +import torch + + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if n_cls_token > 0: + pos_embed = np.concatenate( + [np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size, orig_size, new_size, new_size) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = 
torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +# ---------------------------------------------------------- +# RoPE2D: RoPE implementation in 2D +# ---------------------------------------------------------- + +try: + from models.curope import cuRoPE2D + + RoPE2D = cuRoPE2D +except ImportError: + + class RoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, D, 2).float().to(device) / D) + ) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert tokens.size(3) % 2 == 0, ( + "number of dimensions should be a multiple of two" + ) + D = tokens.size(3) // 2 + assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2 + cos, sin = self.get_cos_sin( + D, int(positions.max()) + 1, tokens.device, tokens.dtype + ) + # split features into two along the feature dimension, and apply rope1d on each half + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:, :, 0], cos, sin) + x = self.apply_rope1d(x, positions[:, :, 1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens + + +# patch embedding +class PositionGetter(object): + """return positions of patches""" + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, h, w, device): + if (h, w) not in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + self.cache_positions[h, w] = torch.cartesian_prod(y, x) # (h, w, 2) + pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone() + return pos diff --git a/mapanything/models/external/pi3/layers/transformer_head.py b/mapanything/models/external/pi3/layers/transformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d411938709d11c68b6618bff9205d001feb91829 --- /dev/null +++ b/mapanything/models/external/pi3/layers/transformer_head.py @@ -0,0 +1,98 @@ +from functools import partial + +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from mapanything.models.external.dinov2.layers import Mlp +from mapanything.models.external.pi3.layers.attention import FlashAttentionRope +from 
mapanything.models.external.pi3.layers.block import BlockRope + + +class TransformerDecoder(nn.Module): + def __init__( + self, + in_dim, + out_dim, + dec_embed_dim=512, + depth=5, + dec_num_heads=8, + mlp_ratio=4, + rope=None, + need_project=True, + use_checkpoint=False, + ): + super().__init__() + + self.projects = ( + nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity() + ) + self.use_checkpoint = use_checkpoint + + self.blocks = nn.ModuleList( + [ + BlockRope( + dim=dec_embed_dim, + num_heads=dec_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + ffn_layer=Mlp, + init_values=None, + qk_norm=False, + # attn_class=MemEffAttentionRope, + attn_class=FlashAttentionRope, + rope=rope, + ) + for _ in range(depth) + ] + ) + + self.linear_out = nn.Linear(dec_embed_dim, out_dim) + + def forward(self, hidden, xpos=None): + hidden = self.projects(hidden) + for i, blk in enumerate(self.blocks): + if self.use_checkpoint and self.training: + hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False) + else: + hidden = blk(hidden, xpos=xpos) + out = self.linear_out(hidden) + return out + + +class LinearPts3d(nn.Module): + """ + Linear head for dust3r + Each token outputs: - 16x16 3D points (+ confidence) + """ + + def __init__( + self, + patch_size, + dec_embed_dim, + output_dim=3, + ): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Linear(dec_embed_dim, (output_dim) * self.patch_size**2) + + def forward(self, decout, img_shape): + H, W = img_shape + tokens = decout[-1] + B, S, D = tokens.shape + + # extract 3D points + feat = self.proj(tokens) # B,S,D + feat = feat.transpose(-1, -2).view( + B, -1, H // self.patch_size, W // self.patch_size + ) + feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W + + # permute + norm depth + return feat.permute(0, 2, 3, 1) diff --git a/mapanything/models/external/pi3/models/__init__.py b/mapanything/models/external/pi3/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/models/external/pi3/models/pi3.py b/mapanything/models/external/pi3/models/pi3.py new file mode 100644 index 0000000000000000000000000000000000000000..1b79e71bf21e468f966eb11841febf6b23ec7160 --- /dev/null +++ b/mapanything/models/external/pi3/models/pi3.py @@ -0,0 +1,251 @@ +from copy import deepcopy +from functools import partial + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin + +from mapanything.models.external.dinov2.hub.backbones import dinov2_vitl14_reg +from mapanything.models.external.dinov2.layers import Mlp +from mapanything.models.external.pi3.layers.attention import FlashAttentionRope +from mapanything.models.external.pi3.layers.block import BlockRope +from mapanything.models.external.pi3.layers.camera_head import CameraHead +from mapanything.models.external.pi3.layers.pos_embed import PositionGetter, RoPE2D +from mapanything.models.external.pi3.layers.transformer_head import ( + LinearPts3d, + TransformerDecoder, +) + + +def homogenize_points( + points, +): + """Convert batched points (xyz) to (xyz1).""" + return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + + +class Pi3(nn.Module, PyTorchModelHubMixin): + def __init__( + self, + pos_type="rope100", + decoder_size="large", + ): + super().__init__() + + # ---------------------- + # Encoder + # ---------------------- 
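
A shape walk-through (illustrative only, sizes made up) of the LinearPts3d head above: every patch token predicts a patch_size x patch_size tile of values, and F.pixel_shuffle rearranges those tiles into a dense per-pixel map.

import torch
import torch.nn.functional as F

B, patch, out_dim = 2, 14, 3
H, W = 224, 308                                # multiples of the patch size
S = (H // patch) * (W // patch)                # tokens per image
feat = torch.randn(B, S, out_dim * patch**2)   # per-token tile predictions

feat = feat.transpose(-1, -2).view(B, -1, H // patch, W // patch)
dense = F.pixel_shuffle(feat, patch)           # (B, out_dim, H, W)
print(dense.permute(0, 2, 3, 1).shape)         # torch.Size([2, 224, 308, 3])
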
+ self.encoder = dinov2_vitl14_reg(pretrained=False) + self.patch_size = 14 + del self.encoder.mask_token + + # ---------------------- + # Positonal Encoding + # ---------------------- + self.pos_type = pos_type if pos_type is not None else "none" + self.rope = None + if self.pos_type.startswith("rope"): # eg rope100 + if RoPE2D is None: + raise ImportError( + "Cannot find cuRoPE2D, please install it following the README instructions" + ) + freq = float(self.pos_type[len("rope") :]) + self.rope = RoPE2D(freq=freq) + self.position_getter = PositionGetter() + else: + raise NotImplementedError + + # ---------------------- + # Decoder + # ---------------------- + if decoder_size == "small": + dec_embed_dim = 384 + dec_num_heads = 6 + mlp_ratio = 4 + dec_depth = 24 + elif decoder_size == "base": + dec_embed_dim = 768 + dec_num_heads = 12 + mlp_ratio = 4 + dec_depth = 24 + elif decoder_size == "large": + dec_embed_dim = 1024 + dec_num_heads = 16 + mlp_ratio = 4 + dec_depth = 36 + else: + raise NotImplementedError + self.decoder = nn.ModuleList( + [ + BlockRope( + dim=dec_embed_dim, + num_heads=dec_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + ffn_layer=Mlp, + init_values=0.01, + qk_norm=True, + attn_class=FlashAttentionRope, + rope=self.rope, + ) + for _ in range(dec_depth) + ] + ) + self.dec_embed_dim = dec_embed_dim + + # ---------------------- + # Register_token + # ---------------------- + num_register_tokens = 5 + self.patch_start_idx = num_register_tokens + self.register_token = nn.Parameter( + torch.randn(1, 1, num_register_tokens, self.dec_embed_dim) + ) + nn.init.normal_(self.register_token, std=1e-6) + + # ---------------------- + # Local Points Decoder + # ---------------------- + self.point_decoder = TransformerDecoder( + in_dim=2 * self.dec_embed_dim, + dec_embed_dim=1024, + dec_num_heads=16, + out_dim=1024, + rope=self.rope, + ) + self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3) + + # ---------------------- + # Conf Decoder + # ---------------------- + self.conf_decoder = deepcopy(self.point_decoder) + self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1) + + # ---------------------- + # Camera Pose Decoder + # ---------------------- + self.camera_decoder = TransformerDecoder( + in_dim=2 * self.dec_embed_dim, + dec_embed_dim=1024, + dec_num_heads=16, # 8 + out_dim=512, + rope=self.rope, + use_checkpoint=False, + ) + self.camera_head = CameraHead(dim=512) + + # For ImageNet Normalize + image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + + self.register_buffer("image_mean", image_mean) + self.register_buffer("image_std", image_std) + + def decode(self, hidden, N, H, W): + BN, hw, _ = hidden.shape + B = BN // N + + final_output = [] + + hidden = hidden.reshape(B * N, hw, -1) + + register_token = self.register_token.repeat(B, N, 1, 1).reshape( + B * N, *self.register_token.shape[-2:] + ) + + # Concatenate special tokens with patch tokens + hidden = torch.cat([register_token, hidden], dim=1) + hw = hidden.shape[1] + + if self.pos_type.startswith("rope"): + pos = self.position_getter( + B * N, H // self.patch_size, W // self.patch_size, hidden.device + ) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos = pos + 1 + 
pos_special = ( + torch.zeros(B * N, self.patch_start_idx, 2) + .to(hidden.device) + .to(pos.dtype) + ) + pos = torch.cat([pos_special, pos], dim=1) + + for i in range(len(self.decoder)): + blk = self.decoder[i] + + if i % 2 == 0: + pos = pos.reshape(B * N, hw, -1) + hidden = hidden.reshape(B * N, hw, -1) + else: + pos = pos.reshape(B, N * hw, -1) + hidden = hidden.reshape(B, N * hw, -1) + + hidden = blk(hidden, xpos=pos) + + if i + 1 in [len(self.decoder) - 1, len(self.decoder)]: + final_output.append(hidden.reshape(B * N, hw, -1)) + + return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape( + B * N, hw, -1 + ) + + def forward(self, imgs): + imgs = (imgs - self.image_mean) / self.image_std + + B, N, _, H, W = imgs.shape + patch_h, patch_w = H // 14, W // 14 + + # encode by dinov2 + imgs = imgs.reshape(B * N, _, H, W) + hidden = self.encoder(imgs, is_training=True) + + if isinstance(hidden, dict): + hidden = hidden["x_norm_patchtokens"] + + hidden, pos = self.decode(hidden, N, H, W) + + point_hidden = self.point_decoder(hidden, xpos=pos) + conf_hidden = self.conf_decoder(hidden, xpos=pos) + camera_hidden = self.camera_decoder(hidden, xpos=pos) + + with torch.amp.autocast(device_type="cuda", enabled=False): + # local points + point_hidden = point_hidden.float() + ret = self.point_head( + [point_hidden[:, self.patch_start_idx :]], (H, W) + ).reshape(B, N, H, W, -1) + xy, z = ret.split([2, 1], dim=-1) + z = torch.exp(z) + local_points = torch.cat([xy * z, z], dim=-1) + + # confidence + conf_hidden = conf_hidden.float() + conf = self.conf_head( + [conf_hidden[:, self.patch_start_idx :]], (H, W) + ).reshape(B, N, H, W, -1) + + # camera + camera_hidden = camera_hidden.float() + camera_poses = self.camera_head( + camera_hidden[:, self.patch_start_idx :], patch_h, patch_w + ).reshape(B, N, 4, 4) + + # unproject local points using camera poses + points = torch.einsum( + "bnij, bnhwj -> bnhwi", camera_poses, homogenize_points(local_points) + )[..., :3] + + return dict( + points=points, + local_points=local_points, + conf=conf, + camera_poses=camera_poses, + ) diff --git a/mapanything/models/external/pow3r/__init__.py b/mapanything/models/external/pow3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25d9c09f83bb37fda12c767bdb1f8f110463a792 --- /dev/null +++ b/mapanything/models/external/pow3r/__init__.py @@ -0,0 +1,865 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
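
An illustrative sketch of the final unprojection in Pi3.forward above: per-view 4x4 camera poses lift the camera-frame pointmap into the world frame via homogeneous coordinates. Toy sizes, identity poses so the assertion is checkable.

import torch

def homogenize(p):  # (..., 3) -> (..., 4), append a trailing 1
    return torch.cat([p, torch.ones_like(p[..., :1])], dim=-1)

B, N, H, W = 1, 2, 4, 4
local_points = torch.randn(B, N, H, W, 3)       # camera-frame pointmap
poses = torch.eye(4).expand(B, N, 4, 4)         # cam2world; identity here

world = torch.einsum("bnij,bnhwj->bnhwi", poses, homogenize(local_points))[..., :3]
assert torch.allclose(world, local_points)      # identity pose leaves points unchanged
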
+ +""" +Inference wrapper for Pow3R +""" + +import warnings +from copy import deepcopy + +import pow3r.model.blocks # noqa +import roma +import torch +import torch.nn as nn +import tqdm +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode +from dust3r.image_pairs import make_pairs +from dust3r.inference import check_if_same_size +from dust3r.model import CroCoNet +from dust3r.patch_embed import get_patch_embed as dust3r_patch_embed +from dust3r.utils.device import collate_with_cat, to_cpu +from dust3r.utils.misc import ( + fill_default_args, + freeze_all_params, + interleave, + is_symmetrized, + transpose_to_landscape, +) +from pow3r.model.blocks import Block, BlockInject, DecoderBlock, DecoderBlockInject, Mlp +from pow3r.model.heads import head_factory +from pow3r.model.inference import ( + add_depth, + add_intrinsics, + add_relpose, +) +from pow3r.model.patch_embed import get_patch_embed + +from mapanything.models.external.vggt.utils.rotation import mat_to_quat +from mapanything.utils.geometry import ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + convert_z_depth_to_depth_along_ray, + depthmap_to_camera_frame, + get_rays_in_camera_frame, +) + + +class Pow3R(CroCoNet): + """Two siamese encoders, followed by two decoders. + The goal is to output 3d points directly, both images in view1's frame + (hence the asymmetry). + """ + + def __init__( + self, + mode="embed", + head_type="linear", + patch_embed_cls="PatchEmbedDust3R", + freeze="none", + landscape_only=True, + **croco_kwargs, + ): + # retrieve all default arguments using python magic + self.croco_args = fill_default_args(croco_kwargs, super().__init__) + super().__init__(**croco_kwargs) + del self.mask_token # useless + del self.prediction_head + + dec_dim, enc_dim = self.decoder_embed.weight.shape + self.enc_embed_dim = enc_dim + self.dec_embed_dim = dec_dim + + self.mode = mode + # additional parameters in the encoder + img_size = self.patch_embed.img_size + patch_size = self.patch_embed.patch_size[0] + self.patch_embed = dust3r_patch_embed( + patch_embed_cls, img_size, patch_size, self.enc_embed_dim + ) + self.patch_embed_rays = get_patch_embed( + patch_embed_cls + "_Mlp", + img_size, + patch_size, + self.enc_embed_dim, + in_chans=3, + ) + self.patch_embed_depth = get_patch_embed( + patch_embed_cls + "_Mlp", + img_size, + patch_size, + self.enc_embed_dim, + in_chans=2, + ) + self.pose_embed = Mlp(12, 4 * dec_dim, dec_dim) + + # additional parameters in the decoder + self.dec_cls = "_cls" in self.mode + self.dec_num_cls = 0 + if self.dec_cls: + # use a CLS token in the decoder only + self.mode = self.mode.replace("_cls", "") + self.cls_token1 = nn.Parameter(torch.zeros((dec_dim,))) + self.cls_token2 = nn.Parameter(torch.zeros((dec_dim,))) + self.dec_num_cls = 1 # affects all blocks + + use_ln = "_ln" in self.mode # TODO remove? 
+ self.patch_ln = nn.LayerNorm(enc_dim) if use_ln else nn.Identity() + self.dec1_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity() + self.dec2_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity() + + self.dec_blocks2 = deepcopy(self.dec_blocks) + + # here we modify some of the blocks + self.replace_some_blocks() + + self.set_downstream_head(head_type, landscape_only, **croco_kwargs) + self.set_freeze(freeze) + + def replace_some_blocks(self): + assert self.mode.startswith("inject") # inject[0,0.5] + NewBlock = BlockInject + DecoderNewBlock = DecoderBlockInject + + all_layers = { + i / n + for i in range(len(self.enc_blocks)) + for n in [len(self.enc_blocks), len(self.dec_blocks)] + } + which_layers = eval(self.mode[self.mode.find("[") :]) or all_layers + assert isinstance(which_layers, (set, list)) + + n = 0 + for i, block in enumerate(self.enc_blocks): + if i / len(self.enc_blocks) in which_layers: + block.__class__ = NewBlock + block.init(self.enc_embed_dim) + n += 1 + else: + block.__class__ = Block + assert n == len(which_layers), breakpoint() + + n = 0 + for i in range(len(self.dec_blocks)): + for blocks in [self.dec_blocks, self.dec_blocks2]: + block = blocks[i] + if i / len(self.dec_blocks) in which_layers: + block.__class__ = DecoderNewBlock + block.init(self.dec_embed_dim) + n += 1 + else: + block.__class__ = DecoderBlock + assert n == 2 * len(which_layers), breakpoint() + + @classmethod + def from_pretrained(cls, pretrained_model_path, **kw): + return _load_model(pretrained_model_path, device="cpu") + + def load_state_dict(self, ckpt, **kw): + # duplicate all weights for the second decoder if not present + new_ckpt = dict(ckpt) + if not any(k.startswith("dec_blocks2") for k in ckpt): + for key, value in ckpt.items(): + if key.startswith("dec_blocks"): + new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value + # remove layers that have different shapes + cur_ckpt = self.state_dict() + for key, val in ckpt.items(): + if key.startswith("downstream_head2.proj"): + if key in cur_ckpt and cur_ckpt[key].shape != val.shape: + print(f" (removing ckpt[{key}] because wrong shape)") + del new_ckpt[key] + return super().load_state_dict(new_ckpt, **kw) + + def set_freeze(self, freeze): # this is for use by downstream models + self.freeze = freeze + to_be_frozen = { + "none": [], + "encoder": [self.patch_embed, self.enc_blocks], + } + freeze_all_params(to_be_frozen[freeze]) + + def set_prediction_head(self, *args, **kwargs): + """No prediction head""" + return + + def set_downstream_head( + self, + head_type, + landscape_only, + patch_size, + img_size, + mlp_ratio, + dec_depth, + **kw, + ): + assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, ( + f"{img_size=} must be multiple of {patch_size=}" + ) + + # split heads if different + heads = head_type.split(";") + assert len(heads) in (1, 2) + head1_type, head2_type = (heads + heads)[:2] + + # allocate heads + self.downstream_head1 = head_factory(head1_type, self) + self.downstream_head2 = head_factory(head2_type, self) + + # magic wrapper + self.head1 = transpose_to_landscape( + self.downstream_head1, activate=landscape_only + ) + self.head2 = transpose_to_landscape( + self.downstream_head2, activate=landscape_only + ) + + def _encode_image(self, image, true_shape, rays=None, depth=None): + # embed the image into patches (x has size B x Npatches x C) + x, pos = self.patch_embed(image, true_shape=true_shape) + + if rays is not None: # B,3,H,W + rays_emb, pos2 = self.patch_embed_rays(rays, 
true_shape=true_shape) + assert (pos == pos2).all() + if self.mode.startswith("embed"): + x = x + rays_emb + else: + rays_emb = None + + if depth is not None: # B,2,H,W + depth_emb, pos2 = self.patch_embed_depth(depth, true_shape=true_shape) + assert (pos == pos2).all() + if self.mode.startswith("embed"): + x = x + depth_emb + else: + depth_emb = None + + x = self.patch_ln(x) + + # add positional embedding without cls token + assert self.enc_pos_embed is None + + # now apply the transformer encoder and normalization + for blk in self.enc_blocks: + x = blk(x, pos, rays=rays_emb, depth=depth_emb) + + x = self.enc_norm(x) + return x, pos + + def encode_symmetrized(self, view1, view2): + img1 = view1["img"] + img2 = view2["img"] + B = img1.shape[0] + # Recover true_shape when available, otherwise assume that the img shape is the true one + shape1 = view1.get( + "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1) + ) + shape2 = view2.get( + "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1) + ) + # warning! maybe the images have different portrait/landscape orientations + + # privileged information + rays1 = view1.get("known_rays", None) + rays2 = view2.get("known_rays", None) + depth1 = view1.get("known_depth", None) + depth2 = view2.get("known_depth", None) + + if is_symmetrized(view1, view2): + # computing half of forward pass!' + def hsub(x): + return None if x is None else x[::2] + + feat1, pos1 = self._encode_image( + img1[::2], shape1[::2], rays=hsub(rays1), depth=hsub(depth1) + ) + feat2, pos2 = self._encode_image( + img2[::2], shape2[::2], rays=hsub(rays2), depth=hsub(depth2) + ) + + feat1, feat2 = interleave(feat1, feat2) + pos1, pos2 = interleave(pos1, pos2) + else: + feat1, pos1 = self._encode_image(img1, shape1, rays=rays1, depth=depth1) + feat2, pos2 = self._encode_image(img2, shape2, rays=rays2, depth=depth2) + + return (shape1, shape2), (feat1, feat2), (pos1, pos2) + + def _decoder(self, f1, pos1, f2, pos2, relpose1=None, relpose2=None): + final_output = [(f1, f2)] # before projection + + # project to decoder dim + f1 = self.decoder_embed(f1) + f2 = self.decoder_embed(f2) + + # add CLS token for the decoder + if self.dec_cls: + cls1 = self.cls_token1[None, None].expand(len(f1), 1, -1).clone() + cls2 = self.cls_token2[None, None].expand(len(f2), 1, -1).clone() + + if relpose1 is not None: # shape = (B, 4, 4) + pose_emb1 = self.pose_embed(relpose1[:, :3].flatten(1)).unsqueeze(1) + if self.mode.startswith("embed"): + if self.dec_cls: + cls1 = cls1 + pose_emb1 + else: + f1 = f1 + pose_emb1 + else: + pose_emb1 = None + + if relpose2 is not None: # shape = (B, 4, 4) + pose_emb2 = self.pose_embed(relpose2[:, :3].flatten(1)).unsqueeze(1) + if self.mode.startswith("embed"): + if self.dec_cls: + cls2 = cls2 + pose_emb2 + else: + f2 = f2 + pose_emb2 + else: + pose_emb2 = None + + if self.dec_cls: + f1, pos1 = cat_cls(cls1, f1, pos1) + f2, pos2 = cat_cls(cls2, f2, pos2) + + f1 = self.dec1_pre_ln(f1) + f2 = self.dec2_pre_ln(f2) + + final_output.append((f1, f2)) # to be removed later + for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2): + # img1 side + f1, _ = blk1( + *final_output[-1][::+1], + pos1, + pos2, + relpose=pose_emb1, + num_cls=self.dec_num_cls, + ) + # img2 side + f2, _ = blk2( + *final_output[-1][::-1], + pos2, + pos1, + relpose=pose_emb2, + num_cls=self.dec_num_cls, + ) + # store the result + final_output.append((f1, f2)) + + del final_output[1] # duplicate with final_output[0] (after decoder proj) + if self.dec_cls: # remove cls token for decoder 
layers + final_output[1:] = [(f1[:, 1:], f2[:, 1:]) for f1, f2 in final_output[1:]] + # normalize last output + final_output[-1] = tuple(map(self.dec_norm, final_output[-1])) + return zip(*final_output) + + def _downstream_head(self, head_num, decout, img_shape): + B, S, D = decout[-1].shape + head = getattr(self, f"head{head_num}") + return head(decout, img_shape) + + def forward(self, view1, view2): + # encode the two images --> B,S,D + (shape1, shape2), (feat1, feat2), (pos1, pos2) = self.encode_symmetrized( + view1, view2 + ) + + # combine all ref images into object-centric representation + dec1, dec2 = self._decoder( + feat1, + pos1, + feat2, + pos2, + relpose1=view1.get("known_pose"), + relpose2=view2.get("known_pose"), + ) + with torch.autocast("cuda", enabled=False): + res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1) + res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2) + + res2["pts3d_in_other_view"] = res2.pop( + "pts3d" + ) # predict view2's pts3d in view1's frame + return res1, res2 + + +def convert_release_dust3r_args(args): + args.model = ( + args.model.replace("patch_embed_cls", "patch_embed") + .replace("AsymmetricMASt3R", "AsymmetricCroCo3DStereo") + .replace("PatchEmbedDust3R", "convManyAR") + .replace( + "pos_embed='RoPE100'", + "enc_pos_embed='cuRoPE100', dec_pos_embed='cuRoPE100'", + ) + ) + return args + + +def _load_model(model_path, device): + print("... loading model from", model_path) + ckpt = torch.load(model_path, map_location="cpu") + try: + net = eval( + ckpt["args"].model[:-1].replace("convManyAR", "convP") + + ", landscape_only=False)" + ) + except Exception: + args = convert_release_dust3r_args(ckpt["args"]) + net = eval( + args.model[:-1].replace("convManyAR", "convP") + ", landscape_only=False)" + ) + ckpt["model"] = { + k.replace("_downstream_head", "downstream_head"): v + for k, v in ckpt["model"].items() + } + print(net.load_state_dict(ckpt["model"], strict=False)) + return net.to(device) + + +def cat_cls(cls, tokens, pos): + tokens = torch.cat((cls, tokens), dim=1) + pos = torch.cat((-pos.new_ones(len(cls), 1, 2), pos), dim=1) + return tokens, pos + + +class Pow3RWrapper(torch.nn.Module): + def __init__( + self, + name, + ckpt_path, + geometric_input_config, + **kwargs, + ): + super().__init__() + self.name = name + self.ckpt_path = ckpt_path + self.geometric_input_config = geometric_input_config + + # Init the model and load the checkpoint + print(f"Loading checkpoint from {self.ckpt_path} ...") + ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False) + model = ckpt["definition"] + print(f"Creating model = {model}") + self.model = eval(model) + print(self.model.load_state_dict(ckpt["weights"])) + + def forward(self, views): + """ + Forward pass wrapper for Pow3R. + + Assumption: + - The number of input views is 2. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Length of the list should be 2. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["dust3r"] + Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs: + "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3). + "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation. + "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1). 
+ + Returns: + List[dict]: A list containing the final outputs for the two views. Length of the list will be 2. + """ + # Check that the number of input views is 2 + assert len(views) == 2, "Pow3R requires 2 input views." + + # Check the data norm type + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "dust3r", ( + "Pow3R expects a normalized image with the DUSt3R normalization scheme applied" + ) + + # Get the batch size per view, device and two views + batch_size_per_view = views[0]["img"].shape[0] + device = views[0]["img"].device + view1, view2 = views + + # Decide if we need to use the geometric inputs + if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]: + # Decide if we need to use the camera intrinsics + if ( + torch.rand(1, device=device) + < self.geometric_input_config["ray_dirs_prob"] + ): + add_intrinsics(view1, view1.get("camera_intrinsics")) + add_intrinsics(view2, view2.get("camera_intrinsics")) + + # Decide if we need to use the depth map + if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]: + depthmap1 = view1.get("depthmap") + depthmap2 = view2.get("depthmap") + if depthmap1 is not None: + depthmap1 = depthmap1.squeeze(-1).to(device) + if depthmap2 is not None: + depthmap2 = depthmap2.squeeze(-1).to(device) + add_depth(view1, depthmap1) + add_depth(view2, depthmap2) + + # Decide if we need to use the camera pose + if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]: + cam1 = view1.get("camera_pose") + cam2 = view2.get("camera_pose") + add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1) + add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1) + + # Get the model predictions + preds = self.model(view1, view2) + + # Convert the output to MapAnything format + with torch.autocast("cuda", enabled=False): + res = [] + for view_idx in range(2): + # Get the model predictions for the current view + curr_view_pred = preds[view_idx] + + # For the first view + if view_idx == 0: + # Get the global frame and camera frame pointmaps + global_pts = curr_view_pred["pts3d"] + cam_pts = curr_view_pred["pts3d"] + conf = curr_view_pred["conf"] + + # Get the ray directions and depth along ray + depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True) + ray_directions = cam_pts / depth_along_ray + + # Initalize identity camera pose + cam_rot = torch.eye(3, device=device) + cam_quat = mat_to_quat(cam_rot) + cam_trans = torch.zeros(3, device=device) + cam_quat = cam_quat.unsqueeze(0).repeat(batch_size_per_view, 1) + cam_trans = cam_trans.unsqueeze(0).repeat(batch_size_per_view, 1) + # For the second view + elif view_idx == 1: + # Get the global frame and camera frame pointmaps + pred_global_pts = curr_view_pred["pts3d_in_other_view"] + cam_pts = curr_view_pred["pts3d2"] + conf = (curr_view_pred["conf"] * curr_view_pred["conf2"]).sqrt() + + # Get the ray directions and depth along ray + depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True) + ray_directions = cam_pts / depth_along_ray + + # Compute the camera pose using the pointmaps + cam_rot, cam_trans, scale = roma.rigid_points_registration( + cam_pts.reshape(batch_size_per_view, -1, 3), + pred_global_pts.reshape(batch_size_per_view, -1, 3), + weights=conf.reshape(batch_size_per_view, -1), + compute_scaling=True, + ) + cam_quat = mat_to_quat(cam_rot) + + # Scale the predicted camera frame pointmap and compute the new global frame pointmap + cam_pts = scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * cam_pts + global_pts = 
cam_pts.reshape( + batch_size_per_view, -1, 3 + ) @ cam_rot.permute(0, 2, 1) + cam_trans.unsqueeze(1) + global_pts = global_pts.view(pred_global_pts.shape) + + # Append the result in MapAnything format + res.append( + { + "pts3d": global_pts, + "pts3d_cam": cam_pts, + "ray_directions": ray_directions, + "depth_along_ray": depth_along_ray, + "cam_trans": cam_trans, + "cam_quats": cam_quat, + "conf": conf, + } + ) + + return res + + +class Pow3RBAWrapper(torch.nn.Module): + def __init__( + self, + name, + ckpt_path, + geometric_input_config, + scene_graph="complete", + inference_batch_size=32, + global_optim_schedule="cosine", + global_optim_lr=0.01, + global_optim_niter=300, + **kwargs, + ): + super().__init__() + self.name = name + self.ckpt_path = ckpt_path + self.geometric_input_config = geometric_input_config + self.scene_graph = scene_graph + self.inference_batch_size = inference_batch_size + self.global_optim_schedule = global_optim_schedule + self.global_optim_lr = global_optim_lr + self.global_optim_niter = global_optim_niter + + # Init the model and load the checkpoint + print(f"Loading checkpoint from {self.ckpt_path} ...") + ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False) + model = ckpt["definition"] + print(f"Creating model = {model}") + self.model = eval(model) + print(self.model.load_state_dict(ckpt["weights"])) + + # Init the global aligner mode + self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer + + def infer_two_views(self, views): + """ + Wrapper for Pow3R 2-View inference. + + Assumption: + - The number of input views is 2. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Length of the list should be 2. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["dust3r"] + Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs: + "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3). + "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation. + "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1). + + Returns: + List[dict]: A list containing the final outputs for the two views. Length of the list will be 2. + """ + # Check that the number of input views is 2 + assert len(views) == 2, "Pow3R requires 2 input views." 
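
A minimal sketch (toy data) of the registration step used in Pow3RWrapper.forward above: roma.rigid_points_registration recovers the rotation, translation, and scale aligning view 2's camera-frame pointmap with its prediction in view 1's frame. Assumes the roma package and its y ~ s * R @ x + t convention; roma.random_rotmat is used only to fabricate ground truth.

import torch
import roma

B, P = 1, 1000
R_gt = roma.random_rotmat(B)                     # (B, 3, 3)
t_gt = torch.randn(B, 3)
s_gt = torch.tensor([1.7])

x = torch.randn(B, P, 3)                         # view-2 camera-frame points
y = s_gt.view(B, 1, 1) * x @ R_gt.transpose(-1, -2) + t_gt.unsqueeze(1)

R, t, s = roma.rigid_points_registration(x, y, compute_scaling=True)
assert torch.allclose(R, R_gt, atol=1e-4)
assert torch.allclose(t, t_gt, atol=1e-4)
assert torch.allclose(s, s_gt, atol=1e-4)
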
+ + # Check the data norm type + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "dust3r", ( + "Pow3R expects a normalized image with the DUSt3R normalization scheme applied" + ) + + # Get the device and two views + device = views[0]["img"].device + view1, view2 = views + + # Decide if we need to use the geometric inputs + if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]: + # Decide if we need to use the camera intrinsics + if ( + torch.rand(1, device=device) + < self.geometric_input_config["ray_dirs_prob"] + ): + add_intrinsics(view1, view1.get("camera_intrinsics")) + add_intrinsics(view2, view2.get("camera_intrinsics")) + + # Decide if we need to use the depth map + if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]: + depthmap1 = view1.get("depthmap") + depthmap2 = view2.get("depthmap") + if depthmap1 is not None: + depthmap1 = depthmap1.squeeze(-1).to(device) + if depthmap2 is not None: + depthmap2 = depthmap2.squeeze(-1).to(device) + add_depth(view1, depthmap1) + add_depth(view2, depthmap2) + + # Decide if we need to use the camera pose + if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]: + cam1 = view1.get("camera_pose") + cam2 = view2.get("camera_pose") + add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1) + add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1) + + # Get the model predictions + preds = self.model(view1, view2) + + return preds + + def loss_of_one_batch(self, batch, device): + """ + Compute prediction for two views. + """ + view1, view2 = batch + ignore_keys = set( + [ + "dataset", + "label", + "instance", + "idx", + "true_shape", + "rng", + "name", + "data_norm_type", + ] + ) + for view in batch: + for name in view.keys(): # pseudo_focal + if name in ignore_keys: + continue + view[name] = view[name].to(device, non_blocking=True) + + pred1, pred2 = self.infer_two_views([view1, view2]) + + result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2) + + return result + + @torch.no_grad() + def inference(self, pairs, device, verbose=False): + """ + Wrapper for multi-pair inference using Pow3R. + """ + if verbose: + print(f">> Inference with model on {len(pairs)} image pairs") + result = [] + + multiple_shapes = not (check_if_same_size(pairs)) + if multiple_shapes: + self.inference_batch_size = 1 + + for i in tqdm.trange( + 0, len(pairs), self.inference_batch_size, disable=not verbose + ): + res = self.loss_of_one_batch( + collate_with_cat(pairs[i : i + self.inference_batch_size]), device + ) + result.append(to_cpu(res)) + + result = collate_with_cat(result, lists=multiple_shapes) + + return result + + def forward(self, views): + """ + Forward pass wrapper for Pow3R using the global aligner. + + Assumption: + - The batch size of input views is 1. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys, where B is the batch size and is 1: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["dust3r"] + + Returns: + List[dict]: A list containing the final outputs for the input views. + """ + # Check the batch size of input views + batch_size_per_view, _, height, width = views[0]["img"].shape + device = views[0]["img"].device + num_views = len(views) + assert batch_size_per_view == 1, ( + f"Batch size of input views should be 1, but got {batch_size_per_view}." 
+ ) + + # Check the data norm type + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "dust3r", ( + "Pow3R-BA expects a normalized image with the DUSt3R normalization scheme applied" + ) + + # Convert the input views to the expected input format + images = [] + for view in views: + images.append( + dict( + img=view["img"], + camera_intrinsics=view["camera_intrinsics"], + depthmap=view["depthmap"], + camera_pose=view["camera_pose"], + data_norm_type=view["data_norm_type"], + true_shape=view["true_shape"], + idx=len(images), + instance=str(len(images)), + ) + ) + + # Make image pairs and run inference pair-wise + pairs = make_pairs( + images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + output = self.inference( + pairs, + device, + verbose=False, + ) + + # Global optimization + with torch.enable_grad(): + scene = global_aligner( + output, device=device, mode=self.global_aligner_mode, verbose=False + ) + _ = scene.compute_global_alignment( + init="mst", + niter=self.global_optim_niter, + schedule=self.global_optim_schedule, + lr=self.global_optim_lr, + ) + + # Make sure scene is not None + if scene is None: + raise RuntimeError("Global optimization failed.") + + # Get the predictions + intrinsics = scene.get_intrinsics() + c2w_poses = scene.get_im_poses() + depths = scene.get_depthmaps() + + # Convert the output to the MapAnything format + with torch.autocast("cuda", enabled=False): + res = [] + for view_idx in range(num_views): + # Get the current view predictions + curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0) + curr_view_pose = c2w_poses[view_idx].unsqueeze(0) + curr_view_depth_z = depths[view_idx].unsqueeze(0) + + # Convert the pose to quaternions and translation + curr_view_cam_translations = curr_view_pose[..., :3, 3] + curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3]) + + # Get the camera frame pointmaps + curr_view_pts3d_cam, _ = depthmap_to_camera_frame( + curr_view_depth_z, curr_view_intrinsic + ) + + # Convert the z depth to depth along ray + curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray( + curr_view_depth_z, curr_view_intrinsic + ) + curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1) + + # Get the ray directions on the unit sphere in the camera frame + _, curr_view_ray_dirs = get_rays_in_camera_frame( + curr_view_intrinsic, height, width, normalize_to_unit_sphere=True + ) + + # Get the pointmaps + curr_view_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + curr_view_ray_dirs, + curr_view_depth_along_ray, + curr_view_cam_translations, + curr_view_cam_quats, + ) + ) + + # Append the outputs to the result list + res.append( + { + "pts3d": curr_view_pts3d, + "pts3d_cam": curr_view_pts3d_cam, + "ray_directions": curr_view_ray_dirs, + "depth_along_ray": curr_view_depth_along_ray, + "cam_trans": curr_view_cam_translations, + "cam_quats": curr_view_cam_quats, + } + ) + + return res diff --git a/mapanything/models/external/vggt/__init__.py b/mapanything/models/external/vggt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a182f2e6af72f1fb5667fce91ef83c86cb1fb0cf --- /dev/null +++ b/mapanything/models/external/vggt/__init__.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +""" +Inference wrapper for VGGT +""" + +import torch + +from mapanything.models.external.vggt.models.vggt import VGGT +from mapanything.models.external.vggt.utils.geometry import closed_form_inverse_se3 +from mapanything.models.external.vggt.utils.pose_enc import pose_encoding_to_extri_intri +from mapanything.models.external.vggt.utils.rotation import mat_to_quat +from mapanything.utils.geometry import ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + convert_z_depth_to_depth_along_ray, + depthmap_to_camera_frame, + get_rays_in_camera_frame, +) + + +class VGGTWrapper(torch.nn.Module): + def __init__( + self, + name, + torch_hub_force_reload, + load_pretrained_weights=True, + depth=24, + num_heads=16, + intermediate_layer_idx=[4, 11, 17, 23], + load_custom_ckpt=False, + custom_ckpt_path=None, + ): + super().__init__() + self.name = name + self.torch_hub_force_reload = torch_hub_force_reload + self.load_custom_ckpt = load_custom_ckpt + self.custom_ckpt_path = custom_ckpt_path + + if load_pretrained_weights: + # Load pre-trained weights + if not torch_hub_force_reload: + # Initialize the 1B VGGT model from huggingface hub cache + print("Loading facebook/VGGT-1B from huggingface cache ...") + self.model = VGGT.from_pretrained( + "facebook/VGGT-1B", + ) + else: + # Initialize the 1B VGGT model + print("Re-downloading facebook/VGGT-1B ...") + self.model = VGGT.from_pretrained( + "facebook/VGGT-1B", force_download=True + ) + else: + # Load the VGGT class + self.model = VGGT( + depth=depth, + num_heads=num_heads, + intermediate_layer_idx=intermediate_layer_idx, + ) + + # Get the dtype for VGGT inference + # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) + self.dtype = ( + torch.bfloat16 + if torch.cuda.get_device_capability()[0] >= 8 + else torch.float16 + ) + + # Load custom checkpoint if requested + if self.load_custom_ckpt: + print(f"Loading checkpoint from {self.custom_ckpt_path} ...") + assert self.custom_ckpt_path is not None, ( + "custom_ckpt_path must be provided if load_custom_ckpt is set to True" + ) + custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False) + print(self.model.load_state_dict(custom_ckpt, strict=True)) + del custom_ckpt # in case it occupies memory + + def forward(self, views): + """ + Forward pass wrapper for VGGT + + Assumption: + - All the input views have the same image shape. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). + "data_norm_type" (list): ["identity"] + + Returns: + List[dict]: A list containing the final outputs for all N views. 
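        Example (illustrative sketch only; tensor values and sizes are hypothetical):
            views = [
                {"img": torch.rand(1, 3, 518, 518), "data_norm_type": ["identity"]}
                for _ in range(4)
            ]
            res = vggt_wrapper(views)
            # res[0] contains "pts3d", "pts3d_cam", "ray_directions",
            # "depth_along_ray", "cam_trans", "cam_quats" and "conf" for view 0.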
+ """ + # Get input shape of the images, number of views, and batch size per view + batch_size_per_view, _, height, width = views[0]["img"].shape + num_views = len(views) + + # Check the data norm type + # VGGT expects a normalized image but without the DINOv2 mean and std applied ("identity") + data_norm_type = views[0]["data_norm_type"][0] + assert data_norm_type == "identity", ( + "VGGT expects a normalized image but without the DINOv2 mean and std applied" + ) + + # Concatenate the images to create a single (B, V, C, H, W) tensor + img_list = [view["img"] for view in views] + images = torch.stack(img_list, dim=1) + + # Run the VGGT aggregator + with torch.autocast("cuda", dtype=self.dtype): + aggregated_tokens_list, ps_idx = self.model.aggregator(images) + + # Run the Camera + Pose Branch of VGGT + with torch.autocast("cuda", enabled=False): + # Predict Cameras + pose_enc = self.model.camera_head(aggregated_tokens_list)[-1] + # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world) + # Extrinsics Shape: (B, V, 3, 4) + # Intrinsics Shape: (B, V, 3, 3) + extrinsic, intrinsic = pose_encoding_to_extri_intri( + pose_enc, images.shape[-2:] + ) + + # Predict Depth Maps + # Depth Shape: (B, V, H, W, 1) + # Depth Confidence Shape: (B, V, H, W) + depth_map, depth_conf = self.model.depth_head( + aggregated_tokens_list, images, ps_idx + ) + + # Convert the output to MapAnything format + res = [] + for view_idx in range(num_views): + # Get the extrinsics, intrinsics, depth map for the current view + curr_view_extrinsic = extrinsic[:, view_idx, ...] + curr_view_extrinsic = closed_form_inverse_se3( + curr_view_extrinsic + ) # Convert to cam2world + curr_view_intrinsic = intrinsic[:, view_idx, ...] + curr_view_depth_z = depth_map[:, view_idx, ...] + curr_view_depth_z = curr_view_depth_z.squeeze(-1) + curr_view_confidence = depth_conf[:, view_idx, ...] 
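            # At this point curr_view_extrinsic has been inverted to a cam2world
            # transform, curr_view_depth_z is the (B, H, W) z-depth map, and
            # curr_view_confidence is the (B, H, W) per-pixel depth confidence.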
+ + # Get the camera frame pointmaps + curr_view_pts3d_cam, _ = depthmap_to_camera_frame( + curr_view_depth_z, curr_view_intrinsic + ) + + # Convert the extrinsics to quaternions and translations + curr_view_cam_translations = curr_view_extrinsic[..., :3, 3] + curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3]) + + # Convert the z depth to depth along ray + curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray( + curr_view_depth_z, curr_view_intrinsic + ) + curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1) + + # Get the ray directions on the unit sphere in the camera frame + _, curr_view_ray_dirs = get_rays_in_camera_frame( + curr_view_intrinsic, height, width, normalize_to_unit_sphere=True + ) + + # Get the pointmaps + curr_view_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + curr_view_ray_dirs, + curr_view_depth_along_ray, + curr_view_cam_translations, + curr_view_cam_quats, + ) + ) + + # Append the outputs to the result list + res.append( + { + "pts3d": curr_view_pts3d, + "pts3d_cam": curr_view_pts3d_cam, + "ray_directions": curr_view_ray_dirs, + "depth_along_ray": curr_view_depth_along_ray, + "cam_trans": curr_view_cam_translations, + "cam_quats": curr_view_cam_quats, + "conf": curr_view_confidence, + } + ) + + return res diff --git a/mapanything/models/external/vggt/heads/__init__.py b/mapanything/models/external/vggt/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/models/external/vggt/heads/camera_head.py b/mapanything/models/external/vggt/heads/camera_head.py new file mode 100644 index 0000000000000000000000000000000000000000..17efa894a60d8842ff6c5e789c3f95fc6331c791 --- /dev/null +++ b/mapanything/models/external/vggt/heads/camera_head.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn as nn + +from mapanything.models.external.vggt.heads.head_act import activate_pose +from mapanything.models.external.vggt.layers import Mlp +from mapanything.models.external.vggt.layers.block import Block + + +class CameraHead(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. + """ + + def __init__( + self, + dim_in: int = 2048, + trunk_depth: int = 4, + pose_encoding_type: str = "absT_quaR_FoV", + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + trans_act: str = "linear", + quat_act: str = "linear", + fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. + ): + super().__init__() + + if pose_encoding_type == "absT_quaR_FoV": + self.target_dim = 9 + else: + raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") + + self.trans_act = trans_act + self.quat_act = quat_act + self.fl_act = fl_act + self.trunk_depth = trunk_depth + + # Build the trunk using a sequence of transformer blocks. + self.trunk = nn.Sequential( + *[ + Block( + dim=dim_in, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + init_values=init_values, + ) + for _ in range(trunk_depth) + ] + ) + + # Normalizations for camera token and trunk output. 
+ self.token_norm = nn.LayerNorm(dim_in) + self.trunk_norm = nn.LayerNorm(dim_in) + + # Learnable empty camera pose token. + self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) + self.embed_pose = nn.Linear(self.target_dim, dim_in) + + # Module for producing modulation parameters: shift, scale, and a gate. + self.poseLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True) + ) + + # Adaptive layer normalization without affine parameters. + self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) + self.pose_branch = Mlp( + in_features=dim_in, + hidden_features=dim_in // 2, + out_features=self.target_dim, + drop=0, + ) + + def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list: + """ + Forward pass to predict camera parameters. + + Args: + aggregated_tokens_list (list): List of token tensors from the network; + the last tensor is used for prediction. + num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. + + Returns: + list: A list of predicted camera encodings (post-activation) from each iteration. + """ + # Use tokens from the last block for camera prediction. + tokens = aggregated_tokens_list[-1] + + # Extract the camera tokens + pose_tokens = tokens[:, :, 0] + pose_tokens = self.token_norm(pose_tokens) + + pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) + return pred_pose_enc_list + + def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: + """ + Iteratively refine camera pose predictions. + + Args: + pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C]. + num_iterations (int): Number of refinement iterations. + + Returns: + list: List of activated camera encodings from each iteration. + """ + B, S, C = pose_tokens.shape # S is expected to be 1. + pred_pose_enc = None + pred_pose_enc_list = [] + + for _ in range(num_iterations): + # Use a learned empty pose for the first iteration. + if pred_pose_enc is None: + module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) + else: + # Detach the previous prediction to avoid backprop through time. + pred_pose_enc = pred_pose_enc.detach() + module_input = self.embed_pose(pred_pose_enc) + + # Generate modulation parameters and split them into shift, scale, and gate components. + shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk( + 3, dim=-1 + ) + + # Adaptive layer normalization and modulation. + pose_tokens_modulated = gate_msa * modulate( + self.adaln_norm(pose_tokens), shift_msa, scale_msa + ) + pose_tokens_modulated = pose_tokens_modulated + pose_tokens + + pose_tokens_modulated = self.trunk(pose_tokens_modulated) + # Compute the delta update for the pose encoding. + pred_pose_enc_delta = self.pose_branch( + self.trunk_norm(pose_tokens_modulated) + ) + + if pred_pose_enc is None: + pred_pose_enc = pred_pose_enc_delta + else: + pred_pose_enc = pred_pose_enc + pred_pose_enc_delta + + # Apply final activation functions for translation, quaternion, and field-of-view. + activated_pose = activate_pose( + pred_pose_enc, + trans_act=self.trans_act, + quat_act=self.quat_act, + fl_act=self.fl_act, + ) + pred_pose_enc_list.append(activated_pose) + + return pred_pose_enc_list + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Modulate the input tensor using scaling and shifting parameters. 
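    Computes x * (1 + scale) + shift, the adaLN-style modulation used in DiT.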
+ """ + # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 + return x * (1 + scale) + shift diff --git a/mapanything/models/external/vggt/heads/dpt_head.py b/mapanything/models/external/vggt/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7c138c67c73cbadf615a96ff13bc1d7a09dac980 --- /dev/null +++ b/mapanything/models/external/vggt/heads/dpt_head.py @@ -0,0 +1,600 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Inspired by https://github.com/DepthAnything/Depth-Anything-V2 + + +from typing import List, Tuple, Union + +import torch +import torch.nn as nn + +from .head_act import activate_head +from .utils import create_uv_grid, position_grid_to_embed + + +class DPTHead(nn.Module): + """ + DPT Head for dense prediction tasks. + + This implementation follows the architecture described in "Vision Transformers for Dense Prediction" + (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer + backbone and produces dense predictions by fusing multi-scale features. + + Args: + dim_in (int): Input dimension (channels). + patch_size (int, optional): Patch size. Default is 14. + output_dim (int, optional): Number of output channels. Default is 4. + activation (str, optional): Activation type. Default is "inv_log". + conf_activation (str, optional): Confidence activation type. Default is "expp1". + features (int, optional): Feature channels for intermediate representations. Default is 256. + out_channels (List[int], optional): Output channels for each intermediate layer. + intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. + pos_embed (bool, optional): Whether to use positional embedding. Default is True. + feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. + down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. + """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "inv_log", + conf_activation: str = "expp1", + features: int = 256, + out_channels: List[int] = [256, 512, 1024, 1024], + intermediate_layer_idx: List[int] = [4, 11, 17, 23], + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + ) -> None: + super(DPTHead, self).__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.feature_only = feature_only + self.down_ratio = down_ratio + self.intermediate_layer_idx = intermediate_layer_idx + + self.norm = nn.LayerNorm(dim_in) + + # Projection layers for each output channel from tokens. + self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=dim_in, + out_channels=oc, + kernel_size=1, + stride=1, + padding=0, + ) + for oc in out_channels + ] + ) + + # Resize layers for upsampling feature maps. 
+ self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0, + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1, + ), + ] + ) + + self.scratch = _make_scratch( + out_channels, + features, + expand=False, + ) + + # Attach additional modules to scratch. + self.scratch.stem_transpose = None + self.scratch.refinenet1 = _make_fusion_block(features) + self.scratch.refinenet2 = _make_fusion_block(features) + self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + head_features_1 = features + head_features_2 = 32 + + if feature_only: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1, kernel_size=3, stride=1, padding=1 + ) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, + head_features_1 // 2, + kernel_size=3, + stride=1, + padding=1, + ) + conv2_in_channels = head_features_1 // 2 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d( + conv2_in_channels, + head_features_2, + kernel_size=3, + stride=1, + padding=1, + ), + nn.ReLU(inplace=True), + nn.Conv2d( + head_features_2, output_dim, kernel_size=1, stride=1, padding=0 + ), + ) + + def forward( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_chunk_size: int = 8, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass through the DPT head, supports processing by chunking frames. + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + patch_start_idx (int): Starting index for patch tokens in the token sequence. + Used to separate patch tokens from other tokens (e.g., camera or register tokens). + frames_chunk_size (int, optional): Number of frames to process in each chunk. + If None or larger than S, all frames are processed at once. Default: 8. 
+ + Returns: + Tensor or Tuple[Tensor, Tensor]: + - If feature_only=True: Feature maps with shape [B, S, C, H, W] + - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] + """ + B, S, _, H, W = images.shape + + # If frames_chunk_size is not specified or greater than S, process all frames at once + if frames_chunk_size is None or frames_chunk_size >= S: + return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) + + # Otherwise, process frames in chunks to manage memory usage + assert frames_chunk_size > 0 + + # Process frames in batches + all_preds = [] + all_conf = [] + + for frames_start_idx in range(0, S, frames_chunk_size): + frames_end_idx = min(frames_start_idx + frames_chunk_size, S) + + # Process batch of frames + if self.feature_only: + chunk_output = self._forward_impl( + aggregated_tokens_list, + images, + patch_start_idx, + frames_start_idx, + frames_end_idx, + ) + all_preds.append(chunk_output) + else: + chunk_preds, chunk_conf = self._forward_impl( + aggregated_tokens_list, + images, + patch_start_idx, + frames_start_idx, + frames_end_idx, + ) + all_preds.append(chunk_preds) + all_conf.append(chunk_conf) + + # Concatenate results along the sequence dimension + if self.feature_only: + return torch.cat(all_preds, dim=1) + else: + return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) + + def _forward_impl( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_start_idx: int = None, + frames_end_idx: int = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Implementation of the forward pass through the DPT head. + + This method processes a specific chunk of frames from the sequence. + + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W]. + patch_start_idx (int): Starting index for patch tokens. + frames_start_idx (int, optional): Starting index for frames to process. + frames_end_idx (int, optional): Ending index for frames to process. + + Returns: + Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). + """ + if frames_start_idx is not None and frames_end_idx is not None: + images = images[:, frames_start_idx:frames_end_idx].contiguous() + + B, S, _, H, W = images.shape + + patch_h, patch_w = H // self.patch_size, W // self.patch_size + + out = [] + dpt_idx = 0 + + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + + # Select frames if processing a chunk + if frames_start_idx is not None and frames_end_idx is not None: + x = x[:, frames_start_idx:frames_end_idx] + + x = x.reshape(B * S, -1, x.shape[-1]) + + x = self.norm(x) + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[dpt_idx](x) + if self.pos_embed: + x = self._apply_pos_embed(x, W, H) + x = self.resize_layers[dpt_idx](x) + + out.append(x) + dpt_idx += 1 + + # Fuse features from multiple layers. + out = self.scratch_forward(out) + # Interpolate fused output to match target image resolution. 
+ out = custom_interpolate( + out, + ( + int(patch_h * self.patch_size / self.down_ratio), + int(patch_w * self.patch_size / self.down_ratio), + ), + mode="bilinear", + align_corners=True, + ) + + if self.pos_embed: + out = self._apply_pos_embed(out, W, H) + + if self.feature_only: + return out.view(B, S, *out.shape[1:]) + + out = self.scratch.output_conv2(out) + preds, conf = activate_head( + out, activation=self.activation, conf_activation=self.conf_activation + ) + + preds = preds.view(B, S, *preds.shape[1:]) + conf = conf.view(B, S, *conf.shape[1:]) + return preds, conf + + def _apply_pos_embed( + self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1 + ) -> torch.Tensor: + """ + Apply positional embedding to tensor x. + """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid( + patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device + ) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: + """ + Forward pass through the fusion blocks. + + Args: + features (List[Tensor]): List of feature maps from different layers. + + Returns: + Tensor: Fused feature map. + """ + layer_1, layer_2, layer_3, layer_4 = features + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + del layer_4_rn, layer_4 + + out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) + del layer_3_rn, layer_3 + + out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) + del layer_2_rn, layer_2 + + out = self.scratch.refinenet1(out, layer_1_rn) + del layer_1_rn, layer_1 + + out = self.scratch.output_conv1(out) + return out + + +################################################################################ +# Modules +################################################################################ + + +def _make_fusion_block( + features: int, size: int = None, has_residual: bool = True, groups: int = 1 +) -> nn.Module: + return FeatureFusionBlock( + features, + nn.ReLU(inplace=True), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch( + in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False +) -> nn.Module: + scratch = nn.Module() + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + return scratch + + +class 
ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=self.groups, + ) + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=self.groups, + ) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=self.groups, + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit( + features, activation, bn, groups=self.groups + ) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit( + features, activation, bn, groups=self.groups + ) + + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. 
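    If the requested output would exceed roughly 1.6e9 elements, the input is split
    into chunks along the batch dimension, each chunk is interpolated separately,
    and the results are concatenated back together.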
+ """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate( + chunk, size=size, mode=mode, align_corners=align_corners + ) + for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate( + x, size=size, mode=mode, align_corners=align_corners + ) diff --git a/mapanything/models/external/vggt/heads/head_act.py b/mapanything/models/external/vggt/heads/head_act.py new file mode 100644 index 0000000000000000000000000000000000000000..acf073f9a59d5901422a342ea61372591f90fe57 --- /dev/null +++ b/mapanything/models/external/vggt/heads/head_act.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + + +def activate_pose( + pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear" +): + """ + Activate pose parameters with specified activation functions. + + Args: + pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] + trans_act: Activation type for translation component + quat_act: Activation type for quaternion component + fl_act: Activation type for focal length component + + Returns: + Activated pose parameters tensor + """ + T = pred_pose_enc[..., :3] + quat = pred_pose_enc[..., 3:7] + fl = pred_pose_enc[..., 7:] # or fov + + T = base_pose_act(T, trans_act) + quat = base_pose_act(quat, quat_act) + fl = base_pose_act(fl, fl_act) # or fov + + pred_pose_enc = torch.cat([T, quat, fl], dim=-1) + + return pred_pose_enc + + +def base_pose_act(pose_enc, act_type="linear"): + """ + Apply basic activation function to pose parameters. + + Args: + pose_enc: Tensor containing encoded pose parameters + act_type: Activation type ("linear", "inv_log", "exp", "relu") + + Returns: + Activated pose parameters + """ + if act_type == "linear": + return pose_enc + elif act_type == "inv_log": + return inverse_log_transform(pose_enc) + elif act_type == "exp": + return torch.exp(pose_enc) + elif act_type == "relu": + return F.relu(pose_enc) + else: + raise ValueError(f"Unknown act_type: {act_type}") + + +def activate_head(out, activation="norm_exp", conf_activation="expp1"): + """ + Process network output to extract 3D points and confidence values. 
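
    Note:
        With the default "norm_exp" activation, the xyz channels are interpreted as
        a normalized direction whose magnitude is mapped through expm1, and with the
        default "expp1" the confidence is 1 + exp(conf), i.e. always greater than 1.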
+ + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor, confidence tensor) + """ + # Move channels from last dim to the 4th dimension => (B, H, W, C) + fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected + + # Split into xyz (first C-1 channels) and confidence (last channel) + xyz = fmap[:, :, :, :-1] + conf = fmap[:, :, :, -1] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + xyz_normed = xyz / d + pts3d = xyz_normed * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "inv_log": + pts3d = inverse_log_transform(xyz) + elif activation == "xy_inv_log": + xy, z = xyz.split([2, 1], dim=-1) + z = inverse_log_transform(z) + pts3d = torch.cat([xy * z, z], dim=-1) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + else: + raise ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +def inverse_log_transform(y): + """ + Apply inverse log transform: sign(y) * (exp(|y|) - 1) + + Args: + y: Input tensor + + Returns: + Transformed tensor + """ + return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/mapanything/models/external/vggt/heads/track_head.py b/mapanything/models/external/vggt/heads/track_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3464b871a7b7ad3976daf51f4d08022bd6d71ac8 --- /dev/null +++ b/mapanything/models/external/vggt/heads/track_head.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + +from .dpt_head import DPTHead +from .track_modules.base_track_predictor import BaseTrackerPredictor + + +class TrackHead(nn.Module): + """ + Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking. + The tracking is performed iteratively, refining predictions over multiple iterations. + """ + + def __init__( + self, + dim_in, + patch_size=14, + features=128, + iters=4, + predict_conf=True, + stride=2, + corr_levels=7, + corr_radius=4, + hidden_size=384, + ): + """ + Initialize the TrackHead module. + + Args: + dim_in (int): Input dimension of tokens from the backbone. + patch_size (int): Size of image patches used in the vision transformer. + features (int): Number of feature channels in the feature extractor output. + iters (int): Number of refinement iterations for tracking predictions. + predict_conf (bool): Whether to predict confidence scores for tracked points. + stride (int): Stride value for the tracker predictor. + corr_levels (int): Number of correlation pyramid levels + corr_radius (int): Radius for correlation computation, controlling the search area. + hidden_size (int): Size of hidden layers in the tracker network. 
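
        Note:
            The DPT feature extractor below is configured with down_ratio=2, so the
            tracker operates on feature maps at half the input image resolution.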
+ """ + super().__init__() + + self.patch_size = patch_size + + # Feature extractor based on DPT architecture + # Processes tokens into feature maps for tracking + self.feature_extractor = DPTHead( + dim_in=dim_in, + patch_size=patch_size, + features=features, + feature_only=True, # Only output features, no activation + down_ratio=2, # Reduces spatial dimensions by factor of 2 + pos_embed=False, + ) + + # Tracker module that predicts point trajectories + # Takes feature maps and predicts coordinates and visibility + self.tracker = BaseTrackerPredictor( + latent_dim=features, # Match the output_dim of feature extractor + predict_conf=predict_conf, + stride=stride, + corr_levels=corr_levels, + corr_radius=corr_radius, + hidden_size=hidden_size, + ) + + self.iters = iters + + def forward( + self, + aggregated_tokens_list, + images, + patch_start_idx, + query_points=None, + iters=None, + ): + """ + Forward pass of the TrackHead. + + Args: + aggregated_tokens_list (list): List of aggregated tokens from the backbone. + images (torch.Tensor): Input images of shape (B, S, C, H, W) where: + B = batch size, S = sequence length. + patch_start_idx (int): Starting index for patch tokens. + query_points (torch.Tensor, optional): Initial query points to track. + If None, points are initialized by the tracker. + iters (int, optional): Number of refinement iterations. If None, uses self.iters. + + Returns: + tuple: + - coord_preds (torch.Tensor): Predicted coordinates for tracked points. + - vis_scores (torch.Tensor): Visibility scores for tracked points. + - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True). + """ + B, S, _, H, W = images.shape + + # Extract features from tokens + # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2 + feature_maps = self.feature_extractor( + aggregated_tokens_list, images, patch_start_idx + ) + + # Use default iterations if not specified + if iters is None: + iters = self.iters + + # Perform tracking using the extracted features + coord_preds, vis_scores, conf_scores = self.tracker( + query_points=query_points, + fmaps=feature_maps, + iters=iters, + ) + + return coord_preds, vis_scores, conf_scores diff --git a/mapanything/models/external/vggt/heads/track_modules/__init__.py b/mapanything/models/external/vggt/heads/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4196294309799347172dba54a17360698071ca8 --- /dev/null +++ b/mapanything/models/external/vggt/heads/track_modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py b/mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..204a82ac099c7cda89c2cc66e4290b2e3da33542 --- /dev/null +++ b/mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.nn as nn +from einops import rearrange + +from .blocks import CorrBlock, EfficientUpdateFormer +from .modules import Mlp +from .utils import get_2d_embedding, get_2d_sincos_pos_embed, sample_features4d + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=1, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + max_scale=518, + predict_conf=True, + ): + super(BaseTrackerPredictor, self).__init__() + """ + The base template to create a track predictor + + Modified from https://github.com/facebookresearch/co-tracker/ + and https://github.com/facebookresearch/vggsfm + """ + + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.max_scale = max_scale + self.predict_conf = predict_conf + + self.flows_emb_dim = latent_dim // 2 + + self.corr_mlp = Mlp( + in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2, + hidden_features=self.hidden_size, + out_features=self.latent_dim, + ) + + self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4 + + self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim)) + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.fmap_norm = nn.LayerNorm(self.latent_dim) + self.ffeat_norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential( + nn.Linear(self.latent_dim, self.latent_dim), nn.GELU() + ) + + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + if predict_conf: + self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward( + self, + query_points, + fmaps=None, + iters=6, + return_feat=False, + down_ratio=1, + apply_sigmoid=True, + ): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. 
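        iters: number of iterative refinement steps.
        return_feat: if True, additionally return the track features and query features.
        down_ratio: extra downsampling factor of fmaps relative to the original images
            (on top of self.stride); query_points are rescaled accordingly.
        apply_sigmoid: if True, apply a sigmoid to the visibility/confidence logits.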
+ note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2, "Input points must be 2D coordinates" + + # apply a layernorm to fmaps here + fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)) + fmaps = fmaps.permute(0, 1, 4, 2, 3) + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + fcorr_fn = CorrBlock( + fmaps, num_levels=self.corr_levels, radius=self.corr_radius + ) + + coord_preds = [] + + # Iterative Refinement + for _ in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + fcorrs = fcorr_fn.corr_sample(track_feats, coords) + + corr_dim = fcorrs.shape[3] + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim) + fcorrs_ = self.corr_mlp(fcorrs_) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = torch.cat( + [flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1 + ) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape( + B * N, S, self.latent_dim + ) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + # 2D positional embed + # TODO: this can be much simplified + pos_embed = get_2d_sincos_pos_embed( + self.transformer_dim, grid_size=(HH, WW) + ).to(query_points.device) + sampled_pos_emb = sample_features4d( + pos_embed.expand(B, -1, -1, -1), coords[:, 0] + ) + + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze( + 1 + ) + + x = transformer_input + sampled_pos_emb + + # Add the query ref token to the track feats + query_ref_token = torch.cat( + [ + self.query_ref_token[:, 0:1], + self.query_ref_token[:, 1:2].expand(-1, S - 1, -1), + ], + dim=1, + ) + x = x + query_ref_token.to(x.device).to(x.dtype) + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta, _ = self.updateformer(x) + + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = ( + self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_ + ) + + track_feats = track_feats_.reshape(B, N, S, 
self.latent_dim).permute( + 0, 2, 1, 3 + ) # BxSxNxC + + # B x S x N x 2 + coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + vis_e = self.vis_predictor( + track_feats.reshape(B * S * N, self.latent_dim) + ).reshape(B, S, N) + if apply_sigmoid: + vis_e = torch.sigmoid(vis_e) + + if self.predict_conf: + conf_e = self.conf_predictor( + track_feats.reshape(B * S * N, self.latent_dim) + ).reshape(B, S, N) + if apply_sigmoid: + conf_e = torch.sigmoid(conf_e) + else: + conf_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat, conf_e + else: + return coord_preds, vis_e, conf_e diff --git a/mapanything/models/external/vggt/heads/track_modules/blocks.py b/mapanything/models/external/vggt/heads/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..4285ce15bac49153ead21fb9ab5ad50c5bf555ac --- /dev/null +++ b/mapanything/models/external/vggt/heads/track_modules/blocks.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Modified from https://github.com/facebookresearch/co-tracker/ + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modules import AttnBlock, CrossAttnBlock +from .utils import bilinear_sampler + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. 
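
    Tokens are refined by alternating temporal attention across frames with, when
    add_space_attn is enabled, spatial attention across tracks mediated by a set of
    learnable virtual-track tokens.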
+ """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + + # Add input LayerNorm before linear projection + self.input_norm = nn.LayerNorm(input_dim) + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + + # Add output LayerNorm before final projection + self.output_norm = nn.LayerNorm(hidden_size) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter( + torch.randn(1, num_virtual_tracks, 1, hidden_size) + ) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_class=nn.MultiheadAttention, + ) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_class=nn.MultiheadAttention, + ) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [ + CrossAttnBlock( + hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio + ) + for _ in range(space_depth) + ] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [ + CrossAttnBlock( + hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio + ) + for _ in range(space_depth) + ] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) + + self.apply(_basic_init) + + def forward(self, input_tensor, mask=None): + # Apply input LayerNorm + input_tensor = self.input_norm(input_tensor) + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and ( + i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0 + ): + space_tokens = ( + tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) + ) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j]( + virtual_tokens, point_tokens, mask=mask + ) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j]( + point_tokens, virtual_tokens, mask=mask + ) + + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute( + 0, 2, 1, 3 + ) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = 
tokens + init_tokens + + # Apply output LayerNorm before final projection + tokens = self.output_norm(tokens) + flow = self.flow_head(tokens) + + return flow, None + + +class CorrBlock: + def __init__( + self, + fmaps, + num_levels=4, + radius=4, + multiple_track_feats=False, + padding_mode="zeros", + ): + """ + Build a pyramid of feature maps from the input. + + fmaps: Tensor (B, S, C, H, W) + num_levels: number of pyramid levels (each downsampled by factor 2) + radius: search radius for sampling correlation + multiple_track_feats: if True, split the target features per pyramid level + padding_mode: passed to grid_sample / bilinear_sampler + """ + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.num_levels = num_levels + self.radius = radius + self.padding_mode = padding_mode + self.multiple_track_feats = multiple_track_feats + + # Build pyramid: each level is half the spatial resolution of the previous + self.fmaps_pyramid = [fmaps] # level 0 is full resolution + current_fmaps = fmaps + for i in range(num_levels - 1): + B, S, C, H, W = current_fmaps.shape + # Merge batch & sequence dimensions + current_fmaps = current_fmaps.reshape(B * S, C, H, W) + # Avg pool down by factor 2 + current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2) + _, _, H_new, W_new = current_fmaps.shape + current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new) + self.fmaps_pyramid.append(current_fmaps) + + # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling. + # This grid is added to the (scaled) coordinate centroids. + r = self.radius + dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + # delta: for every (dy,dx) displacement (i.e. Δx, Δy) + self.delta = torch.stack( + torch.meshgrid(dy, dx, indexing="ij"), dim=-1 + ) # shape: (2r+1, 2r+1, 2) + + def corr_sample(self, targets, coords): + """ + Instead of storing the entire correlation pyramid, we compute each level's correlation + volume, sample it immediately, then discard it. This saves GPU memory. + + Args: + targets: Tensor (B, S, N, C) — features for the current targets. + coords: Tensor (B, S, N, 2) — coordinates at full resolution. + + Returns: + Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations) + """ + B, S, N, C = targets.shape + + # If you have multiple track features, split them per level. + if self.multiple_track_feats: + targets_split = torch.split(targets, C // self.num_levels, dim=-1) + + out_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + # Get current spatial resolution H, W for this pyramid level. + B, S, C, H, W = fmaps.shape + # Reshape feature maps for correlation computation: + # fmap2s: (B, S, C, H*W) + fmap2s = fmaps.view(B, S, C, H * W) + # Choose appropriate target features. + fmap1 = ( + targets_split[i] if self.multiple_track_feats else targets + ) # shape: (B, S, N, C) + + # Compute correlation directly + corrs = compute_corr_level(fmap1, fmap2s, C) + corrs = corrs.view(B, S, N, H, W) + + # Prepare sampling grid: + # Scale down the coordinates for the current level. + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i) + # Make sure our precomputed delta grid is on the same device/dtype. 
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype) + # Now the grid for grid_sample is: + # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid) + coords_lvl = centroid_lvl + delta_lvl.view( + 1, 2 * self.radius + 1, 2 * self.radius + 1, 2 + ) + + # Sample from the correlation volume using bilinear interpolation. + # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target. + corrs_sampled = bilinear_sampler( + corrs.reshape(B * S * N, 1, H, W), + coords_lvl, + padding_mode=self.padding_mode, + ) + # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims. + corrs_sampled = corrs_sampled.view( + B, S, N, -1 + ) # Now shape: (B, S, N, (2r+1)^2) + out_pyramid.append(corrs_sampled) + + # Concatenate all levels along the last dimension. + out = torch.cat(out_pyramid, dim=-1).contiguous() + return out + + +def compute_corr_level(fmap1, fmap2s, C): + # fmap1: (B, S, N, C) + # fmap2s: (B, S, C, H*W) + corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W) + corrs = corrs.view( + fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1 + ) # (B, S, N, H*W) + return corrs / math.sqrt(C) diff --git a/mapanything/models/external/vggt/heads/track_modules/modules.py b/mapanything/models/external/vggt/heads/track_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..dbfb3825f8bafc8e33e2d33c3d3b390186d59c86 --- /dev/null +++ b/mapanything/models/external/vggt/heads/track_modules/modules.py @@ -0,0 +1,220 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import collections +from functools import partial +from itertools import repeat +from typing import Callable + +import torch.nn as nn + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, + planes, + kernel_size=kernel_size, + padding=1, + stride=stride, + padding_mode="zeros", + ) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=kernel_size, + padding=1, + padding_mode="zeros", + ) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: 
+ self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), + self.norm3, + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs, + ): + """ + Self attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.attn = attn_class( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # Prepare the mask for PyTorch's attention (it expects a different format) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__( + self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs + ): + """ + Cross attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git a/mapanything/models/external/vggt/heads/track_modules/utils.py b/mapanything/models/external/vggt/heads/track_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcb15170d6a1639004ed4a7bda8c48e9430ef1b --- /dev/null +++ 
b/mapanything/models/external/vggt/heads/track_modules/utils.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from https://github.com/facebookresearch/vggsfm +# and https://github.com/facebookresearch/co-tracker/tree/main + + +from typing import Tuple, Union + +import torch +import torch.nn.functional as F + + +def get_2d_sincos_pos_embed( + embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False +) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return ( + pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), + grid, + ) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: torch.Tensor +) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: torch.Tensor +) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. + + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. 
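+
+    Example (illustrative):
+        >>> xy = torch.rand(2, 128, 2) * 64  # (B, N, 2) pixel coordinates
+        >>> pe = get_2d_embedding(xy, C=64)  # -> (2, 128, 2*64 + 2)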
+ """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = ( + torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) + ).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. 
+ """ + coords = coords.detach().clone() + ############################################################ + # IMPORTANT: + coords = coords.to(input.device).to(input.dtype) + ############################################################ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + scale = torch.tensor( + [2 / max(size - 1, 1) for size in reversed(sizes)], + device=coords.device, + dtype=coords.dtype, + ) + else: + scale = torch.tensor( + [2 / size for size in reversed(sizes)], + device=coords.device, + dtype=coords.dtype, + ) + + coords.mul_(scale) # coords = coords * scale + coords.sub_(1) # coords = coords - 1 + + return F.grid_sample( + input, coords, align_corners=align_corners, padding_mode=padding_mode + ) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view( + B, -1, feats.shape[1] * feats.shape[3] + ) # B C R 1 -> B R C diff --git a/mapanything/models/external/vggt/heads/utils.py b/mapanything/models/external/vggt/heads/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c23a6e137f69c262a3e24f438b849c901d21158a --- /dev/null +++ b/mapanything/models/external/vggt/heads/utils.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def position_grid_to_embed( + pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100 +) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed( + embed_dim // 2, pos_flat[:, 0], omega_0=omega_0 + ) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed( + embed_dim // 2, pos_flat[:, 1], omega_0=omega_0 + ) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed( + embed_dim: int, pos: torch.Tensor, omega_0: float = 100 +) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. 
+ - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + device = pos.device + omega = torch.arange( + embed_dim // 2, + dtype=torch.float32 if device.type == "mps" else torch.double, + device=device, + ) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +# Inspired by https://github.com/microsoft/moge + + +def create_uv_grid( + width: int, + height: int, + aspect_ratio: float = None, + dtype: torch.dtype = None, + device: torch.device = None, +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. + """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid diff --git a/mapanything/models/external/vggt/layers/__init__.py b/mapanything/models/external/vggt/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b84d2d861eb72cdb1ebc5e7c001938df0b851486 --- /dev/null +++ b/mapanything/models/external/vggt/layers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed + +__all__ = [ + "Mlp", + "PatchEmbed", +] diff --git a/mapanything/models/external/vggt/layers/attention.py b/mapanything/models/external/vggt/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb536194c0bf49b744ca9c819655bf9127ad0ed --- /dev/null +++ b/mapanything/models/external/vggt/layers/attention.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + + +import torch.nn.functional as F +from torch import nn, Tensor + +XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.rope is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +# class MemEffAttention(Attention): +# def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: +# assert pos is None +# if not XFORMERS_AVAILABLE: +# if attn_bias is not None: +# raise AssertionError("xFormers is required for using nested tensors") +# return super().forward(x) + +# B, N, C = x.shape +# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + +# q, k, v = unbind(qkv, 2) + +# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) +# x = x.reshape([B, N, C]) + +# x = self.proj(x) +# x = self.proj_drop(x) +# return x diff --git a/mapanything/models/external/vggt/layers/block.py b/mapanything/models/external/vggt/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6d562dfa009034081062216d4d98b0662c024e --- /dev/null +++ b/mapanything/models/external/vggt/layers/block.py @@ -0,0 +1,280 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
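+
+# Transformer block with optional LayerScale, stochastic depth (DropPath),
+# QK-norm, and rotary position embeddings; the nested-tensor code paths
+# below are only reachable when xFormers is available.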
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Any, Callable, Dict, List, Tuple + +import torch +from torch import nn, Tensor + +from .attention import Attention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + +XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + fused_attn=fused_attn, + rope=rope, + ) + + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, pos=None) -> Tensor: + def attn_residual_func(x: Tensor, pos=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + pos=pos, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, pos=pos)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, pos=pos) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, + pos=None, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + if pos is not None: + # if necessary, apply rope to the subset + pos = pos[brange] + residual = residual_func(x_subset, pos=pos) + else: + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add 
the residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + else: + pass + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + # attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + # attn_bias._batch_sizes = batch_sizes + # attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + pass + # cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + # 1, -1, x_list[0].shape[-1] + # ) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + # assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma + if isinstance(self.ls1, LayerScale) + else None, 
+ ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma + if isinstance(self.ls1, LayerScale) + else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/mapanything/models/external/vggt/layers/drop_path.py b/mapanything/models/external/vggt/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..04cb47af065ec3cfdbc8f59854efaa976ea717d5 --- /dev/null +++ b/mapanything/models/external/vggt/layers/drop_path.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/mapanything/models/external/vggt/layers/layer_scale.py b/mapanything/models/external/vggt/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..b32da3bd74a028795f5ee628d2636a45b756b074 --- /dev/null +++ b/mapanything/models/external/vggt/layers/layer_scale.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
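+
+# LayerScale scales each residual branch by a learnable per-channel gamma
+# (initialized to a small value such as 1e-5) to stabilize deep ViT training.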
+ +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import nn, Tensor + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/mapanything/models/external/vggt/layers/mlp.py b/mapanything/models/external/vggt/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6d19f53d8562d07d559fe6db93d45b8286420720 --- /dev/null +++ b/mapanything/models/external/vggt/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import nn, Tensor + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/mapanything/models/external/vggt/layers/patch_embed.py b/mapanything/models/external/vggt/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..493774d038c9ee7f0f63b05f80561fa61321e2b7 --- /dev/null +++ b/mapanything/models/external/vggt/layers/patch_embed.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +import torch.nn as nn +from torch import Tensor + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
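+
+    Example (illustrative):
+        >>> embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)
+        >>> tokens = embed(torch.randn(2, 3, 224, 224))  # -> (2, 196, 768)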
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, ( + f"Input image height {H} is not a multiple of patch height {patch_H}" + ) + assert W % patch_W == 0, ( + f"Input image width {W} is not a multiple of patch width: {patch_W}" + ) + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/mapanything/models/external/vggt/layers/rope.py b/mapanything/models/external/vggt/layers/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd3594bfd736752d5add0edd1986810578b253f --- /dev/null +++ b/mapanything/models/external/vggt/layers/rope.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +# Implementation of 2D Rotary Position Embeddings (RoPE). + +# This module provides a clean implementation of 2D Rotary Position Embeddings, +# which extends the original RoPE concept to handle 2D spatial positions. + +# Inspired by: +# https://github.com/meta-llama/codellama/blob/main/llama/model.py +# https://github.com/naver-ai/rope-vit + + +from typing import Dict, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PositionGetter: + """Generates and caches 2D spatial positions for patches in a grid. + + This class efficiently manages the generation of spatial coordinates for patches + in a 2D grid, caching results to avoid redundant computations. + + Attributes: + position_cache: Dictionary storing precomputed position tensors for different + grid dimensions. + """ + + def __init__(self): + """Initializes the position generator with an empty cache.""" + self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} + + def __call__( + self, batch_size: int, height: int, width: int, device: torch.device + ) -> torch.Tensor: + """Generates spatial positions for a batch of patches. + + Args: + batch_size: Number of samples in the batch. + height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. 
+ + Returns: + Tensor of shape (batch_size, height*width, 2) containing y,x coordinates + for each position in the grid, repeated for each batch item. + """ + if (height, width) not in self.position_cache: + y_coords = torch.arange(height, device=device) + x_coords = torch.arange(width, device=device) + positions = torch.cartesian_prod(y_coords, x_coords) + self.position_cache[height, width] = positions + + cached_positions = self.position_cache[height, width] + return ( + cached_positions.view(1, height * width, 2) + .expand(batch_size, -1, -1) + .clone() + ) + + +class RotaryPositionEmbedding2D(nn.Module): + """2D Rotary Position Embedding implementation. + + This module applies rotary position embeddings to input tokens based on their + 2D spatial positions. It handles the position-dependent rotation of features + separately for vertical and horizontal dimensions. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. + frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 2D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + # Compute frequency bands + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency**exponents) + + # Generate position-dependent frequencies + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + # Compute and cache frequency components + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. + """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, + tokens: torch.Tensor, + positions: torch.Tensor, + cos_comp: torch.Tensor, + sin_comp: torch.Tensor, + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. + cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. 
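+
+        Note (illustrative): this computes the standard RoPE update,
+        tokens * cos(theta) + rotate(tokens) * sin(theta), where the
+        cos/sin components are looked up per position via F.embedding.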
+ """ + # Embed positions with frequency components + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + + # Apply rotation + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 2D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 4. + positions: Position tensor of shape (batch_size, n_tokens, 2) containing + the y and x coordinates for each token. + + Returns: + Tensor of same shape as input with applied 2D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. + """ + # Validate inputs + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert positions.ndim == 3 and positions.shape[-1] == 2, ( + "Positions must have shape (batch_size, n_tokens, 2)" + ) + + # Compute feature dimension for each spatial direction + feature_dim = tokens.size(-1) // 2 + + # Get frequency components + max_position = int(positions.max()) + 1 + cos_comp, sin_comp = self._compute_frequency_components( + feature_dim, max_position, tokens.device, tokens.dtype + ) + + # Split features for vertical and horizontal processing + vertical_features, horizontal_features = tokens.chunk(2, dim=-1) + + # Apply RoPE separately for each dimension + vertical_features = self._apply_1d_rope( + vertical_features, positions[..., 0], cos_comp, sin_comp + ) + horizontal_features = self._apply_1d_rope( + horizontal_features, positions[..., 1], cos_comp, sin_comp + ) + + # Combine processed features + return torch.cat((vertical_features, horizontal_features), dim=-1) diff --git a/mapanything/models/external/vggt/layers/swiglu_ffn.py b/mapanything/models/external/vggt/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..9195a0ed27d4f06b0c0ba97c8a9bd7cc3214642e --- /dev/null +++ b/mapanything/models/external/vggt/layers/swiglu_ffn.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import os +from typing import Callable, Optional + +import torch.nn.functional as F +from torch import nn, Tensor + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +# try: +# if XFORMERS_ENABLED: +# from xformers.ops import SwiGLU + +# XFORMERS_AVAILABLE = True +# warnings.warn("xFormers is available (SwiGLU)") +# else: +# warnings.warn("xFormers is disabled (SwiGLU)") +# raise ImportError +# except ImportError: +SwiGLU = SwiGLUFFN +XFORMERS_AVAILABLE = False + +# warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/mapanything/models/external/vggt/layers/vision_transformer.py b/mapanything/models/external/vggt/layers/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..3d611608c29e528f1c6984768874e117ff660edb --- /dev/null +++ b/mapanything/models/external/vggt/layers/vision_transformer.py @@ -0,0 +1,454 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import math +from functools import partial +from typing import Callable, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.utils.checkpoint import checkpoint + +from . 
import ( + MemEffAttention, + Mlp, + NestedTensorBlock as Block, + PatchEmbed, + SwiGLUFFNFused, +) + +logger = logging.getLogger("dinov2") + + +def named_apply( + fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + qk_norm=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + # tricky but makes it work + self.use_checkpoint = False + # + + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim) + ) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) + if num_register_tokens + else None + ) + + if 
drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append( + [nn.Identity()] * i + blocks_list[i : i + chunksize] + ) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( + previous_dtype + ) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where( + masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x + ) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + 
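+            # Insert register tokens between the CLS token and the patch tokens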
x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [ + self.prepare_tokens_with_masks(x, masks) + for x, masks in zip(x_list, masks_list) + ] + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), ( + f"only {len(output)} / {len(blocks_to_take)} blocks found" + ) + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), ( + f"only {len(output)} / {len(blocks_to_take)} blocks found" + ) + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=True, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/mapanything/models/external/vggt/models/__init__.py b/mapanything/models/external/vggt/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/models/external/vggt/models/aggregator.py b/mapanything/models/external/vggt/models/aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..e082d34dc97f1dc58e21969de277086a218acdd3 --- 
/dev/null +++ b/mapanything/models/external/vggt/models/aggregator.py @@ -0,0 +1,385 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import List, Tuple + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from mapanything.models.external.vggt.layers import PatchEmbed +from mapanything.models.external.vggt.layers.block import Block +from mapanything.models.external.vggt.layers.rope import ( + PositionGetter, + RotaryPositionEmbedding2D, +) + +logger = logging.getLogger(__name__) + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +class Aggregator(nn.Module): + """ + The Aggregator applies alternating-attention over input frames, + as described in VGGT: Visual Geometry Grounded Transformer. + + + Args: + img_size (int): Image size in pixels. + patch_size (int): Size of each patch for PatchEmbed. + embed_dim (int): Dimension of the token embeddings. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. + num_register_tokens (int): Number of register tokens. + block_fn (nn.Module): The block type used for attention (Block by default). + qkv_bias (bool): Whether to include bias in QKV projections. + proj_bias (bool): Whether to include bias in the output projection. + ffn_bias (bool): Whether to include bias in MLP layers. + patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg". + aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"]. + aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1. + qk_norm (bool): Whether to apply QK normalization. + rope_freq (int): Base frequency for rotary embedding. -1 to disable. + init_values (float): Init scale for layer scale. 
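+
+    Example (editor's illustrative sketch, not part of the original
+    docstring; the default patch embed downloads DINOv2 weights via
+    torch.hub):
+        >>> agg = Aggregator(img_size=518, patch_size=14, embed_dim=1024, depth=24)
+        >>> images = torch.rand(1, 2, 3, 518, 518)  # [B, S, 3, H, W] in [0, 1]
+        >>> token_list, patch_start_idx = agg(images)
+        >>> token_list[-1].shape  # [B, S, P, 2 * embed_dim]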
+ """ + + def __init__( + self, + img_size=518, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + num_register_tokens=4, + block_fn=Block, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + patch_embed="dinov2_vitl14_reg", + aa_order=["frame", "global"], + aa_block_size=1, + qk_norm=True, + rope_freq=100, + init_values=0.01, + ): + super().__init__() + + self.__build_patch_embed__( + patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim + ) + + # Initialize rotary position embedding if frequency > 0 + self.rope = ( + RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None + ) + self.position_getter = PositionGetter() if self.rope is not None else None + + self.frame_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.global_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.depth = depth + self.aa_order = aa_order + self.patch_size = patch_size + self.aa_block_size = aa_block_size + + # Validate that depth is divisible by aa_block_size + if self.depth % self.aa_block_size != 0: + raise ValueError( + f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})" + ) + + self.aa_block_num = self.depth // self.aa_block_size + + # Note: We have two camera tokens, one for the first frame and one for the rest + # The same applies for register tokens + self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) + self.register_token = nn.Parameter( + torch.randn(1, 2, num_register_tokens, embed_dim) + ) + + # The patch tokens start after the camera and register tokens + self.patch_start_idx = 1 + num_register_tokens + + # Initialize parameters with small values + nn.init.normal_(self.camera_token, std=1e-6) + nn.init.normal_(self.register_token, std=1e-6) + + # Register normalization constants as buffers + for name, value in ( + ("_resnet_mean", _RESNET_MEAN), + ("_resnet_std", _RESNET_STD), + ): + self.register_buffer( + name, + torch.FloatTensor(value).view(1, 1, 3, 1, 1), + persistent=False, + ) + + def __build_patch_embed__( + self, + patch_embed, + img_size, + patch_size, + num_register_tokens, + interpolate_antialias=True, + interpolate_offset=0.0, + block_chunks=0, + init_values=1.0, + embed_dim=1024, + ): + """ + Build the patch embed layer. If 'conv', we use a + simple PatchEmbed conv layer. Otherwise, we use a vision transformer. 
+ """ + + if "conv" in patch_embed: + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=3, + embed_dim=embed_dim, + ) + else: + ### From original VGGT codebase: Doesn't load pre-trained DINOv2 weights + # vit_models = { + # "dinov2_vitl14_reg": vit_large, + # "dinov2_vitb14_reg": vit_base, + # "dinov2_vits14_reg": vit_small, + # "dinov2_vitg2_reg": vit_giant2, + # } + + # self.patch_embed = vit_models[patch_embed]( + # img_size=img_size, + # patch_size=patch_size, + # num_register_tokens=num_register_tokens, + # interpolate_antialias=interpolate_antialias, + # interpolate_offset=interpolate_offset, + # block_chunks=block_chunks, + # init_values=init_values, + # ) + + ### Use pre-trained DINOv2 with gradient checkpointing + self.patch_embed = torch.hub.load("facebookresearch/dinov2", patch_embed) + for i in range(len(self.patch_embed.blocks)): + self.patch_embed.blocks[i] = ( + self.wrap_module_with_gradient_checkpointing( + self.patch_embed.blocks[i] + ) + ) + + # Disable gradient updates for mask token + if hasattr(self.patch_embed, "mask_token"): + self.patch_embed.mask_token.requires_grad_(False) + + ### Gradient Checkpointing Wrapper from UniCeption: + def wrap_module_with_gradient_checkpointing(self, module: nn.Module): + """ + Wrapper for Gradient Checkpointing + References: https://github.com/microsoft/MoGe + """ + + class _CheckpointingWrapper(module.__class__): + _restore_cls = module.__class__ + + def forward(self, *args, **kwargs): + return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) + + module.__class__ = _CheckpointingWrapper + return module + + def forward( + self, + images: torch.Tensor, + ) -> Tuple[List[torch.Tensor], int]: + """ + Args: + images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + + Returns: + (list[torch.Tensor], int): + The list of outputs from the attention blocks, + and the patch_start_idx indicating where patch tokens begin. 
+ """ + B, S, C_in, H, W = images.shape + + if C_in != 3: + raise ValueError(f"Expected 3 input channels, got {C_in}") + + # Normalize images and reshape for patch embed + images = (images - self._resnet_mean) / self._resnet_std + + # Reshape to [B*S, C, H, W] for patch embedding + images = images.view(B * S, C_in, H, W) + patch_tokens = self.patch_embed.forward_features(images) + + if isinstance(patch_tokens, dict): + patch_tokens = patch_tokens["x_norm_patchtokens"] + + _, P, C = patch_tokens.shape + + # Expand camera and register tokens to match batch size and sequence length + camera_token = slice_expand_and_flatten(self.camera_token, B, S) + register_token = slice_expand_and_flatten(self.register_token, B, S) + + # Concatenate special tokens with patch tokens + tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) + + pos = None + if self.rope is not None: + pos = self.position_getter( + B * S, H // self.patch_size, W // self.patch_size, device=images.device + ) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos = pos + 1 + pos_special = ( + torch.zeros(B * S, self.patch_start_idx, 2) + .to(images.device) + .to(pos.dtype) + ) + pos = torch.cat([pos_special, pos], dim=1) + + # update P because we added special tokens + _, P, C = tokens.shape + + frame_idx = 0 + global_idx = 0 + output_list = [] + + for _ in range(self.aa_block_num): + for attn_type in self.aa_order: + if attn_type == "frame": + tokens, frame_idx, frame_intermediates = ( + self._process_frame_attention( + tokens, B, S, P, C, frame_idx, pos=pos + ) + ) + elif attn_type == "global": + tokens, global_idx, global_intermediates = ( + self._process_global_attention( + tokens, B, S, P, C, global_idx, pos=pos + ) + ) + else: + raise ValueError(f"Unknown attention type: {attn_type}") + + for i in range(len(frame_intermediates)): + # concat frame and global intermediates, [B x S x P x 2C] + concat_inter = torch.cat( + [frame_intermediates[i], global_intermediates[i]], dim=-1 + ) + output_list.append(concat_inter) + + del concat_inter + del frame_intermediates + del global_intermediates + return output_list, self.patch_start_idx + + def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None): + """ + Process frame attention blocks. We keep tokens in shape (B*S, P, C). + """ + # If needed, reshape tokens or positions: + if tokens.shape != (B * S, P, C): + tokens = tokens.view(B, S, P, C).view(B * S, P, C) + + if pos is not None and pos.shape != (B * S, P, 2): + pos = pos.view(B, S, P, 2).view(B * S, P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + tokens = self.frame_blocks[frame_idx](tokens, pos=pos) + frame_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, frame_idx, intermediates + + def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None): + """ + Process global attention blocks. We keep tokens in shape (B, S*P, C). 
+ """ + if tokens.shape != (B, S * P, C): + tokens = tokens.view(B, S, P, C).view(B, S * P, C) + + if pos is not None and pos.shape != (B, S * P, 2): + pos = pos.view(B, S, P, 2).view(B, S * P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + tokens = self.global_blocks[global_idx](tokens, pos=pos) + global_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, global_idx, intermediates + + +def slice_expand_and_flatten(token_tensor, B, S): + """ + Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: + 1) Uses the first position (index=0) for the first frame only + 2) Uses the second position (index=1) for all remaining frames (S-1 frames) + 3) Expands both to match batch size B + 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token + followed by (S-1) second-position tokens + 5) Flattens to (B*S, X, C) for processing + + Returns: + torch.Tensor: Processed tokens with shape (B*S, X, C) + """ + + # Slice out the "query" tokens => shape (1, 1, ...) + query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) + # Slice out the "other" tokens => shape (1, S-1, ...) + others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) + # Concatenate => shape (B, S, ...) + combined = torch.cat([query, others], dim=1) + + # Finally flatten => shape (B*S, ...) + combined = combined.view(B * S, *combined.shape[2:]) + return combined diff --git a/mapanything/models/external/vggt/models/vggt.py b/mapanything/models/external/vggt/models/vggt.py new file mode 100644 index 0000000000000000000000000000000000000000..dba4b6aeda98ed87eee6ab893346e037abf45095 --- /dev/null +++ b/mapanything/models/external/vggt/models/vggt.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin # used for model hub + +from mapanything.models.external.vggt.heads.camera_head import CameraHead +from mapanything.models.external.vggt.heads.dpt_head import DPTHead +from mapanything.models.external.vggt.heads.track_head import TrackHead +from mapanything.models.external.vggt.models.aggregator import Aggregator + + +class VGGT(nn.Module, PyTorchModelHubMixin): + def __init__( + self, + img_size=518, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + intermediate_layer_idx=[4, 11, 17, 23], + ): + super().__init__() + + self.aggregator = Aggregator( + img_size=img_size, + patch_size=patch_size, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + ) + self.camera_head = CameraHead(dim_in=2 * embed_dim) + self.point_head = DPTHead( + dim_in=2 * embed_dim, + output_dim=4, + activation="inv_log", + conf_activation="expp1", + intermediate_layer_idx=intermediate_layer_idx, + ) + self.depth_head = DPTHead( + dim_in=2 * embed_dim, + output_dim=2, + activation="exp", + conf_activation="expp1", + intermediate_layer_idx=intermediate_layer_idx, + ) + self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) + + def forward( + self, + images: torch.Tensor, + query_points: torch.Tensor = None, + ): + """ + Forward pass of the VGGT model. + + Args: + images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]. 
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates. + Shape: [N, 2] or [B, N, 2], where N is the number of query points. + Default: None + + Returns: + dict: A dictionary containing the following predictions: + - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration) + - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1] + - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W] + - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3] + - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W] + - images (torch.Tensor): Original input images, preserved for visualization + + If query_points is provided, also includes: + - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates + - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N] + - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N] + """ + + # If without batch dimension, add it + if len(images.shape) == 4: + images = images.unsqueeze(0) + if query_points is not None and len(query_points.shape) == 2: + query_points = query_points.unsqueeze(0) + + aggregated_tokens_list, patch_start_idx = self.aggregator(images) + + predictions = {} + + with torch.cuda.amp.autocast(enabled=False): + if self.camera_head is not None: + pose_enc_list = self.camera_head(aggregated_tokens_list) + predictions["pose_enc"] = pose_enc_list[ + -1 + ] # pose encoding of the last iteration + + if self.depth_head is not None: + depth, depth_conf = self.depth_head( + aggregated_tokens_list, + images=images, + patch_start_idx=patch_start_idx, + ) + predictions["depth"] = depth + predictions["depth_conf"] = depth_conf + + if self.point_head is not None: + pts3d, pts3d_conf = self.point_head( + aggregated_tokens_list, + images=images, + patch_start_idx=patch_start_idx, + ) + predictions["world_points"] = pts3d + predictions["world_points_conf"] = pts3d_conf + + if self.track_head is not None and query_points is not None: + track_list, vis, conf = self.track_head( + aggregated_tokens_list, + images=images, + patch_start_idx=patch_start_idx, + query_points=query_points, + ) + predictions["track"] = track_list[-1] # track of the last iteration + predictions["vis"] = vis + predictions["conf"] = conf + + predictions["images"] = images + + return predictions diff --git a/mapanything/models/external/vggt/utils/__init__.py b/mapanything/models/external/vggt/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/models/external/vggt/utils/geometry.py b/mapanything/models/external/vggt/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..17da2b863022a36afc1ad7607e6640120d04b389 --- /dev/null +++ b/mapanything/models/external/vggt/utils/geometry.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
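The full VGGT wrapper above composes the aggregator with the camera, depth, point, and track heads. A minimal inference sketch follows (editor's addition, not part of this diff; the checkpoint id `facebook/VGGT-1B` and the image file names are assumptions):

```python
# Editor's sketch: end-to-end VGGT inference. The checkpoint id and the
# image paths are assumptions, not taken from this diff.
import torch

from mapanything.models.external.vggt.models.vggt import VGGT
from mapanything.models.external.vggt.utils.load_fn import load_and_preprocess_images

device = "cuda" if torch.cuda.is_available() else "cpu"
model = VGGT.from_pretrained("facebook/VGGT-1B").to(device).eval()

# (S, 3, H, W); the batch dimension is added inside VGGT.forward
images = load_and_preprocess_images(["frame_0.png", "frame_1.png"]).to(device)

with torch.no_grad():
    preds = model(images)

print(preds["pose_enc"].shape)  # [1, S, 9]  (translation, quaternion, FoV)
print(preds["depth"].shape)     # [1, S, H, W, 1]
```

Note that `forward` adds the batch dimension itself when given a 4-D image tensor, so single sequences can be passed directly.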
+
+
+import numpy as np
+import torch
+
+
+def unproject_depth_map_to_point_map(
+    depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
+) -> np.ndarray:
+    """
+    Unproject a batch of depth maps to 3D world coordinates.
+
+    Args:
+        depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
+        extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
+        intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
+
+    Returns:
+        np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
+    """
+    if isinstance(depth_map, torch.Tensor):
+        depth_map = depth_map.cpu().numpy()
+    if isinstance(extrinsics_cam, torch.Tensor):
+        extrinsics_cam = extrinsics_cam.cpu().numpy()
+    if isinstance(intrinsics_cam, torch.Tensor):
+        intrinsics_cam = intrinsics_cam.cpu().numpy()
+
+    world_points_list = []
+    for frame_idx in range(depth_map.shape[0]):
+        cur_world_points, _, _ = depth_to_world_coords_points(
+            depth_map[frame_idx].squeeze(-1),
+            extrinsics_cam[frame_idx],
+            intrinsics_cam[frame_idx],
+        )
+        world_points_list.append(cur_world_points)
+    world_points_array = np.stack(world_points_list, axis=0)
+
+    return world_points_array
+
+
+def depth_to_world_coords_points(
+    depth_map: np.ndarray,
+    extrinsic: np.ndarray,
+    intrinsic: np.ndarray,
+    eps=1e-8,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Convert a depth map to world coordinates.
+
+    Args:
+        depth_map (np.ndarray): Depth map of shape (H, W).
+        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+        extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
+
+    Returns:
+        tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3), camera coordinates (H, W, 3), and valid depth mask (H, W).
+    """
+    if depth_map is None:
+        return None, None, None
+
+    # Valid depth mask
+    point_mask = depth_map > eps
+
+    # Convert depth map to camera coordinates
+    cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
+
+    # Multiply with the inverse of extrinsic matrix to transform to world coordinates
+    # extrinsic_inv is 4x4 (note closed_form_inverse_se3 is batched, the output is (N, 4, 4))
+    cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
+
+    R_cam_to_world = cam_to_world_extrinsic[:3, :3]
+    t_cam_to_world = cam_to_world_extrinsic[:3, 3]
+
+    # Apply the rotation and translation to the camera coordinates
+    world_coords_points = (
+        np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world
+    )  # HxWx3, 3x3 -> HxWx3
+    # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
+
+    return world_coords_points, cam_coords_points, point_mask
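+
+
+# Editor's note (illustrative, not in the original file): the closed-form
+# SE(3) inverse used by depth_to_world_coords_points (defined at the end of
+# this file) is [R | t]^-1 = [R^T | -R^T t]. A quick numeric self-check:
+#
+#   theta = 0.3
+#   R = np.array(
+#       [
+#           [np.cos(theta), -np.sin(theta), 0.0],
+#           [np.sin(theta), np.cos(theta), 0.0],
+#           [0.0, 0.0, 1.0],
+#       ]
+#   )
+#   se3 = np.eye(4)[None]
+#   se3[0, :3, :3] = R
+#   se3[0, :3, 3] = [0.5, -2.0, 1.0]
+#   assert np.allclose(closed_form_inverse_se3(se3)[0] @ se3[0], np.eye(4))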
+
+
+def depth_to_cam_coords_points(
+    depth_map: np.ndarray, intrinsic: np.ndarray
+) -> np.ndarray:
+    """
+    Convert a depth map to camera coordinates.
+
+    Args:
+        depth_map (np.ndarray): Depth map of shape (H, W).
+        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+
+    Returns:
+        np.ndarray: Camera coordinates (H, W, 3)
+    """
+    H, W = depth_map.shape
+    assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
+    assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, (
+        "Intrinsic matrix must have zero skew"
+    )
+
+    # Intrinsic parameters
+    fu, fv = intrinsic[0, 0], intrinsic[1, 1]
+    cu, cv = intrinsic[0, 2], intrinsic[1, 2]
+
+    # Generate grid of pixel coordinates
+    u, v = np.meshgrid(np.arange(W), np.arange(H))
+
+    # Unproject to camera coordinates
+    x_cam = (u - cu) * depth_map / fu
+    y_cam = (v - cv) * depth_map / fv
+    z_cam = depth_map
+
+    # Stack to form camera coordinates
+    cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+    return cam_coords
+
+
+def closed_form_inverse_se3(se3, R=None, T=None):
+    """
+    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
+
+    If `R` and `T` are provided, they must correspond to the rotation and translation
+    components of `se3`. Otherwise, they will be extracted from `se3`.
+
+    Args:
+        se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+        R (optional): Nx3x3 array or tensor of rotation matrices.
+        T (optional): Nx3x1 array or tensor of translation vectors.
+
+    Returns:
+        Inverted SE3 matrices with the same type and device as `se3`.
+
+    Shapes:
+        se3: (N, 4, 4)
+        R: (N, 3, 3)
+        T: (N, 3, 1)
+    """
+    # Check if se3 is a numpy array or a torch tensor
+    is_numpy = isinstance(se3, np.ndarray)
+
+    # Validate shapes
+    if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
+        raise ValueError(f"se3 must be of shape (N,4,4) or (N,3,4), got {se3.shape}.")
+
+    # Extract R and T if not provided
+    if R is None:
+        R = se3[:, :3, :3]  # (N,3,3)
+    if T is None:
+        T = se3[:, :3, 3:]  # (N,3,1)
+
+    # Transpose R
+    if is_numpy:
+        # Compute the transpose of the rotation for NumPy
+        R_transposed = np.transpose(R, (0, 2, 1))
+        # -R^T t for NumPy
+        top_right = -np.matmul(R_transposed, T)
+        inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
+    else:
+        R_transposed = R.transpose(1, 2)  # (N,3,3)
+        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
+        inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
+        inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
+
+    inverted_matrix[:, :3, :3] = R_transposed
+    inverted_matrix[:, :3, 3:] = top_right
+
+    return inverted_matrix
diff --git a/mapanything/models/external/vggt/utils/load_fn.py b/mapanything/models/external/vggt/utils/load_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..afa0a2d493a39759e01ee2babbd29ad19b19d215
--- /dev/null
+++ b/mapanything/models/external/vggt/utils/load_fn.py
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from PIL import Image
+from torchvision import transforms as TF
+
+
+def load_and_preprocess_images(image_path_list, mode="crop"):
+    """
+    A quick start function to load and preprocess images for model input.
+    This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
+
+    Args:
+        image_path_list (list): List of paths to image files
+        mode (str, optional): Preprocessing mode, either "crop" or "pad".
+            - "crop" (default): Sets width to 518px and center crops height if needed.
+            - "pad": Preserves all pixels by making the largest dimension 518px
+              and padding the smaller dimension to reach a square shape.
+ + Returns: + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) + + Raises: + ValueError: If the input list is empty or if mode is invalid + + Notes: + - Images with different dimensions will be padded with white (value=1.0) + - A warning is printed when images have different shapes + - When mode="crop": The function ensures width=518px while maintaining aspect ratio + and height is center-cropped if larger than 518px + - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio + and the smaller dimension is padded to reach a square shape (518x518) + - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements + """ + # Check for empty list + if len(image_path_list) == 0: + raise ValueError("At least 1 image is required") + + # Validate mode + if mode not in ["crop", "pad"]: + raise ValueError("Mode must be either 'crop' or 'pad'") + + images = [] + shapes = set() + to_tensor = TF.ToTensor() + target_size = 518 + + # First process all images and collect their shapes + for image_path in image_path_list: + # Open image + img = Image.open(image_path) + + # If there's an alpha channel, blend onto white background: + if img.mode == "RGBA": + # Create white background + background = Image.new("RGBA", img.size, (255, 255, 255, 255)) + # Alpha composite onto the white background + img = Image.alpha_composite(background, img) + + # Now convert to "RGB" (this step assigns white for transparent areas) + img = img.convert("RGB") + + width, height = img.size + + if mode == "pad": + # Make the largest dimension 518px while maintaining aspect ratio + if width >= height: + new_width = target_size + new_height = ( + round(height * (new_width / width) / 14) * 14 + ) # Make divisible by 14 + else: + new_height = target_size + new_width = ( + round(width * (new_height / height) / 14) * 14 + ) # Make divisible by 14 + else: # mode == "crop" + # Original behavior: set width to 518px + new_width = target_size + # Calculate height maintaining aspect ratio, divisible by 14 + new_height = round(height * (new_width / width) / 14) * 14 + + # Resize with new dimensions (width, height) + img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) + img = to_tensor(img) # Convert to tensor (0, 1) + + # Center crop height if it's larger than 518 (only in crop mode) + if mode == "crop" and new_height > target_size: + start_y = (new_height - target_size) // 2 + img = img[:, start_y : start_y + target_size, :] + + # For pad mode, pad to make a square of target_size x target_size + if mode == "pad": + h_padding = target_size - img.shape[1] + w_padding = target_size - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + # Pad with white (value=1.0) + img = torch.nn.functional.pad( + img, + (pad_left, pad_right, pad_top, pad_bottom), + mode="constant", + value=1.0, + ) + + shapes.add((img.shape[1], img.shape[2])) + images.append(img) + + # Check if we have different shapes + # In theory our model can also work well with different shapes + if len(shapes) > 1: + print(f"Warning: Found images with different shapes: {shapes}") + # Find maximum dimensions + max_height = max(shape[0] for shape in shapes) + max_width = max(shape[1] for shape in shapes) + + # Pad images if necessary + padded_images = [] + for img in images: + h_padding = max_height - img.shape[1] + w_padding = max_width - img.shape[2] + 
+ if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + img = torch.nn.functional.pad( + img, + (pad_left, pad_right, pad_top, pad_bottom), + mode="constant", + value=1.0, + ) + padded_images.append(img) + images = padded_images + + images = torch.stack(images) # concatenate images + + # Ensure correct shape when single image + if len(image_path_list) == 1: + # Verify shape is (1, C, H, W) + if images.dim() == 3: + images = images.unsqueeze(0) + + return images diff --git a/mapanything/models/external/vggt/utils/pose_enc.py b/mapanything/models/external/vggt/utils/pose_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..b83ad40498229f3bae2143a8fc6c742de27a2264 --- /dev/null +++ b/mapanything/models/external/vggt/utils/pose_enc.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from .rotation import mat_to_quat, quat_to_mat + + +def extri_intri_to_pose_encoding( + extrinsics, + intrinsics, + image_size_hw=None, # e.g., (256, 512) + pose_encoding_type="absT_quaR_FoV", +): + """Convert camera extrinsics and intrinsics to a compact pose encoding. + + This function transforms camera parameters into a unified pose encoding format, + which can be used for various downstream tasks like pose prediction or representation. + + Args: + extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, + where B is batch size and S is sequence length. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. + The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. + intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. + Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for computing field of view values. For example: (256, 512). + pose_encoding_type (str): Type of pose encoding to use. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + + Returns: + torch.Tensor: Encoded camera pose parameters with shape BxSx9. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + """ + + # extrinsics: BxSx3x4 + # intrinsics: BxSx3x3 + + if pose_encoding_type == "absT_quaR_FoV": + R = extrinsics[:, :, :3, :3] # BxSx3x3 + T = extrinsics[:, :, :3, 3] # BxSx3 + + quat = mat_to_quat(R) + # Note the order of h and w here + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat( + [T, quat, fov_h[..., None], fov_w[..., None]], dim=-1 + ).float() + else: + raise NotImplementedError + + return pose_encoding + + +def pose_encoding_to_extri_intri( + pose_encoding, + image_size_hw=None, # e.g., (256, 512) + pose_encoding_type="absT_quaR_FoV", + build_intrinsics=True, +): + """Convert a pose encoding back to camera extrinsics and intrinsics. 
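+
+    (Editor's note, inferred from the code below: intrinsics are rebuilt from
+    the stored fields of view via fx = (W / 2) / tan(fov_w / 2) and
+    fy = (H / 2) / tan(fov_h / 2), exactly inverting the
+    fov = 2 * atan((size / 2) / focal) mapping used in
+    extri_intri_to_pose_encoding.)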
+ + This function performs the inverse operation of extri_intri_to_pose_encoding, + reconstructing the full camera parameters from the compact encoding. + + Args: + pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, + where B is batch size and S is sequence length. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for reconstructing intrinsics from field of view values. + For example: (256, 512). + pose_encoding_type (str): Type of pose encoding used. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. + If False, only extrinsics are returned and intrinsics will be None. + + Returns: + tuple: (extrinsics, intrinsics) + - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world + transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is + a 3x1 translation vector. + - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, + or None if build_intrinsics is False. Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point, + assumed to be at the center of the image (W/2, H/2). + """ + + intrinsics = None + + if pose_encoding_type == "absT_quaR_FoV": + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + + if build_intrinsics: + H, W = image_size_hw + fy = (H / 2.0) / torch.tan(fov_h / 2.0) + fx = (W / 2.0) / torch.tan(fov_w / 2.0) + intrinsics = torch.zeros( + pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device + ) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 + else: + raise NotImplementedError + + return extrinsics, intrinsics diff --git a/mapanything/models/external/vggt/utils/rotation.py b/mapanything/models/external/vggt/utils/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..3f5e25dc36dfe7fe8472d536105fb449b7044971 --- /dev/null +++ b/mapanything/models/external/vggt/utils/rotation.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d + +import torch +import torch.nn.functional as F + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 
+ two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). 
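+
+    Example (editor's illustration of the scalar-last XYZW convention used
+    throughout this module):
+        >>> q = torch.tensor([0.0, 0.0, 0.0, -1.0])  # -identity, scalar last
+        >>> standardize_quaternion(q)  # sign is flipped so the real part >= 0
+        tensor([-0., -0., -0., 1.])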
+ """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) diff --git a/mapanything/models/external/vggt/utils/visual_track.py b/mapanything/models/external/vggt/utils/visual_track.py new file mode 100644 index 0000000000000000000000000000000000000000..0d4c314b016cc074e5d6640e4b4af07acc1a1699 --- /dev/null +++ b/mapanything/models/external/vggt/utils/visual_track.py @@ -0,0 +1,244 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import cv2 +import numpy as np +import torch + + +def color_from_xy(x, y, W, H, cmap_name="hsv"): + """ + Map (x, y) -> color in (R, G, B). + 1) Normalize x,y to [0,1]. + 2) Combine them into a single scalar c in [0,1]. + 3) Use matplotlib's colormap to convert c -> (R,G,B). + + You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). + """ + import matplotlib.cm + import matplotlib.colors + + x_norm = x / max(W - 1, 1) + y_norm = y / max(H - 1, 1) + # Simple combination: + c = (x_norm + y_norm) / 2.0 + + cmap = matplotlib.cm.get_cmap(cmap_name) + # cmap(c) -> (r,g,b,a) in [0,1] + rgba = cmap(c) + r, g, b = rgba[0], rgba[1], rgba[2] + return (r, g, b) # in [0,1], RGB order + + +def get_track_colors_by_position( + tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv" +): + """ + Given all tracks in one sample (b), compute a (N,3) array of RGB color values + in [0,255]. The color is determined by the (x,y) position in the first + visible frame for each track. + + Args: + tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. + vis_mask_b: (S, N) boolean mask; if None, assume all are visible. + image_width, image_height: used for normalizing (x, y). + cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). + + Returns: + track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. + """ + S, N, _ = tracks_b.shape + track_colors = np.zeros((N, 3), dtype=np.uint8) + + if vis_mask_b is None: + # treat all as visible + vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) + + for i in range(N): + # Find first visible frame for track i + visible_frames = torch.where(vis_mask_b[:, i])[0] + if len(visible_frames) == 0: + # track is never visible; just assign black or something + track_colors[i] = (0, 0, 0) + continue + + first_s = int(visible_frames[0].item()) + # use that frame's (x,y) + x, y = tracks_b[first_s, i].tolist() + + # map (x,y) -> (R,G,B) in [0,1] + r, g, b = color_from_xy( + x, y, W=image_width, H=image_height, cmap_name=cmap_name + ) + # scale to [0,255] + r, g, b = int(r * 255), int(g * 255), int(b * 255) + track_colors[i] = (r, g, b) + + return track_colors + + +def visualize_tracks_on_images( + images, + tracks, + track_vis_mask=None, + out_dir="track_visuals_concat_by_xy", + image_format="CHW", # "CHW" or "HWC" + normalize_mode="[0,1]", + cmap_name="hsv", # e.g. "hsv", "rainbow", "jet" + frames_per_row=4, # New parameter for grid layout + save_grid=True, # Flag to control whether to save the grid image +): + """ + Visualizes frames in a grid layout with specified frames per row. + Each track's color is determined by its (x,y) position + in the first visible frame (or frame 0 if always visible). + Finally convert the BGR result to RGB before saving. + Also saves each individual frame as a separate PNG file. 
+ + Args: + images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. + tracks: torch.Tensor (S, N, 2), last dim = (x, y). + track_vis_mask: torch.Tensor (S, N) or None. + out_dir: folder to save visualizations. + image_format: "CHW" or "HWC". + normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 + cmap_name: a matplotlib colormap name for color_from_xy. + frames_per_row: number of frames to display in each row of the grid. + save_grid: whether to save all frames in one grid image. + + Returns: + None (saves images in out_dir). + """ + + if len(tracks.shape) == 4: + tracks = tracks.squeeze(0) + images = images.squeeze(0) + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.squeeze(0) + + import matplotlib + + matplotlib.use("Agg") # for non-interactive (optional) + + os.makedirs(out_dir, exist_ok=True) + + S = images.shape[0] + _, N, _ = tracks.shape # (S, N, 2) + + # Move to CPU + images = images.cpu().clone() + tracks = tracks.cpu().clone() + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.cpu().clone() + + # Infer H, W from images shape + if image_format == "CHW": + # e.g. images[s].shape = (3, H, W) + H, W = images.shape[2], images.shape[3] + else: + # e.g. images[s].shape = (H, W, 3) + H, W = images.shape[1], images.shape[2] + + # Pre-compute the color for each track i based on first visible position + track_colors_rgb = get_track_colors_by_position( + tracks, # shape (S, N, 2) + vis_mask_b=track_vis_mask if track_vis_mask is not None else None, + image_width=W, + image_height=H, + cmap_name=cmap_name, + ) + + # We'll accumulate each frame's drawn image in a list + frame_images = [] + + for s in range(S): + # shape => either (3, H, W) or (H, W, 3) + img = images[s] + + # Convert to (H, W, 3) + if image_format == "CHW": + img = img.permute(1, 2, 0) # (H, W, 3) + # else "HWC", do nothing + + img = img.numpy().astype(np.float32) + + # Scale to [0,255] if needed + if normalize_mode == "[0,1]": + img = np.clip(img, 0, 1) * 255.0 + elif normalize_mode == "[-1,1]": + img = (img + 1.0) * 0.5 * 255.0 + img = np.clip(img, 0, 255.0) + # else no normalization + + # Convert to uint8 + img = img.astype(np.uint8) + + # For drawing in OpenCV, convert to BGR + img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + # Draw each visible track + cur_tracks = tracks[s] # shape (N, 2) + if track_vis_mask is not None: + valid_indices = torch.where(track_vis_mask[s])[0] + else: + valid_indices = range(N) + + cur_tracks_np = cur_tracks.numpy() + for i in valid_indices: + x, y = cur_tracks_np[i] + pt = (int(round(x)), int(round(y))) + + # track_colors_rgb[i] is (R,G,B). 
For OpenCV circle, we need BGR + R, G, B = track_colors_rgb[i] + color_bgr = (int(B), int(G), int(R)) + cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) + + # Convert back to RGB for consistent final saving: + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + + # Save individual frame + frame_path = os.path.join(out_dir, f"frame_{s:04d}.png") + # Convert to BGR for OpenCV imwrite + frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) + cv2.imwrite(frame_path, frame_bgr) + + frame_images.append(img_rgb) + + # Only create and save the grid image if save_grid is True + if save_grid: + # Calculate grid dimensions + num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division + + # Create a grid of images + grid_img = None + for row in range(num_rows): + start_idx = row * frames_per_row + end_idx = min(start_idx + frames_per_row, S) + + # Concatenate this row horizontally + row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1) + + # If this row has fewer than frames_per_row images, pad with black + if end_idx - start_idx < frames_per_row: + padding_width = (frames_per_row - (end_idx - start_idx)) * W + padding = np.zeros((H, padding_width, 3), dtype=np.uint8) + row_img = np.concatenate([row_img, padding], axis=1) + + # Add this row to the grid + if grid_img is None: + grid_img = row_img + else: + grid_img = np.concatenate([grid_img, row_img], axis=0) + + out_path = os.path.join(out_dir, "tracks_grid.png") + # Convert back to BGR for OpenCV imwrite + grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(out_path, grid_img_bgr) + print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}") + + print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png") diff --git a/mapanything/models/mapanything/__init__.py b/mapanything/models/mapanything/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc21520d217b751b6cb98a35674ca4af997c65f --- /dev/null +++ b/mapanything/models/mapanything/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
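As a usage note for the track visualizer defined above, a hedged sketch (editor's addition; `preds` is assumed to be the output of a VGGT forward pass with query points and batch size 1, and the 0.5 visibility threshold is an arbitrary choice):

```python
from mapanything.models.external.vggt.utils.visual_track import (
    visualize_tracks_on_images,
)

# Assumed: `preds` as returned by VGGT(images, query_points) with B == 1
visualize_tracks_on_images(
    preds["images"],                    # [1, S, 3, H, W], values in [0, 1]
    preds["track"],                     # [1, S, N, 2], pixel coordinates
    track_vis_mask=preds["vis"] > 0.5,  # assumed visibility threshold
    out_dir="track_visuals",
    frames_per_row=4,
)
```

The leading batch dimension is squeezed inside the function, so this only works for a single sequence at a time.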
+ +from mapanything.models.mapanything.ablations import MapAnythingAblations +from mapanything.models.mapanything.model import MapAnything +from mapanything.models.mapanything.modular_dust3r import ModularDUSt3R + +__all__ = [ + "MapAnything", + "MapAnythingAblations", + "ModularDUSt3R", +] diff --git a/mapanything/models/mapanything/__pycache__/__init__.cpython-312.pyc b/mapanything/models/mapanything/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa9f43ba532f38ee0f1ca02c08a524bdae34d5f6 Binary files /dev/null and b/mapanything/models/mapanything/__pycache__/__init__.cpython-312.pyc differ diff --git a/mapanything/models/mapanything/__pycache__/ablations.cpython-312.pyc b/mapanything/models/mapanything/__pycache__/ablations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d82f6c1d7a1744a8d7828e79106fe0d5ed9de61d Binary files /dev/null and b/mapanything/models/mapanything/__pycache__/ablations.cpython-312.pyc differ diff --git a/mapanything/models/mapanything/__pycache__/model.cpython-312.pyc b/mapanything/models/mapanything/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14e749123a0eb4608b34dac7cbd11b3ba319f3e1 Binary files /dev/null and b/mapanything/models/mapanything/__pycache__/model.cpython-312.pyc differ diff --git a/mapanything/models/mapanything/__pycache__/modular_dust3r.cpython-312.pyc b/mapanything/models/mapanything/__pycache__/modular_dust3r.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfa468cd3d0f290b9ef7430de5cca69c48ba9899 Binary files /dev/null and b/mapanything/models/mapanything/__pycache__/modular_dust3r.cpython-312.pyc differ diff --git a/mapanything/models/mapanything/ablations.py b/mapanything/models/mapanything/ablations.py new file mode 100644 index 0000000000000000000000000000000000000000..5b23ffe1b62acaead4eb11ad8336e3d71f4ae102 --- /dev/null +++ b/mapanything/models/mapanything/ablations.py @@ -0,0 +1,1660 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +MapAnything Ablation model classes defined using UniCeption modules. 
+""" + +from functools import partial +from typing import Callable, Dict, Type, Union + +import torch +import torch.nn as nn + +from mapanything.utils.geometry import ( + apply_log_to_norm, + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + normalize_depth_using_non_zero_pixels, + normalize_pose_translations, + transform_pose_using_quats_and_trans_2_to_1, +) +from uniception.models.encoders import ( + encoder_factory, + EncoderGlobalRepInput, + ViTEncoderInput, + ViTEncoderNonImageInput, +) +from uniception.models.info_sharing.alternating_attention_transformer import ( + MultiViewAlternatingAttentionTransformer, + MultiViewAlternatingAttentionTransformerIFR, +) +from uniception.models.info_sharing.base import MultiViewTransformerInput +from uniception.models.info_sharing.cross_attention_transformer import ( + MultiViewCrossAttentionTransformer, + MultiViewCrossAttentionTransformerIFR, +) +from uniception.models.info_sharing.global_attention_transformer import ( + MultiViewGlobalAttentionTransformer, + MultiViewGlobalAttentionTransformerIFR, +) +from uniception.models.libs.croco.pos_embed import RoPE2D +from uniception.models.prediction_heads.adaptors import ( + CamTranslationPlusQuatsAdaptor, + PointMapAdaptor, + PointMapPlusRayDirectionsPlusDepthAdaptor, + PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor, + PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor, + PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor, + PointMapWithConfidenceAdaptor, + PointMapWithConfidenceAndMaskAdaptor, + PointMapWithMaskAdaptor, + RayDirectionsPlusDepthAdaptor, + RayDirectionsPlusDepthWithConfidenceAdaptor, + RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor, + RayDirectionsPlusDepthWithMaskAdaptor, + RayMapPlusDepthAdaptor, + RayMapPlusDepthWithConfidenceAdaptor, + RayMapPlusDepthWithConfidenceAndMaskAdaptor, + RayMapPlusDepthWithMaskAdaptor, +) +from uniception.models.prediction_heads.base import ( + AdaptorInput, + PredictionHeadInput, + PredictionHeadLayeredInput, +) +from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor +from uniception.models.prediction_heads.linear import LinearFeature +from uniception.models.prediction_heads.pose_head import PoseHead + +# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12) +if hasattr(torch.backends.cuda, "matmul") and hasattr( + torch.backends.cuda.matmul, "allow_tf32" +): + torch.backends.cuda.matmul.allow_tf32 = True + + +class MapAnythingAblations(nn.Module): + "Modular MapAnything Multi-View model class with no scale token." + + def __init__( + self, + name: str, + encoder_config: Dict, + info_sharing_config: Dict, + pred_head_config: Dict, + geometric_input_config: Dict, + fusion_norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial( + nn.LayerNorm, eps=1e-6 + ), + pretrained_checkpoint_path: str = None, + load_specific_pretrained_submodules: bool = False, + specific_pretrained_submodules: list = [], + torch_hub_force_reload: bool = False, + ): + """ + Multi-view model containing an image encoder followed by a multi-view attention transformer and respective downstream heads. + The goal is to output scene representation directly in view 0's frame. + + Args: + name (str): Name of the model. + encoder_config (Dict): Configuration for the encoder. + info_sharing_config (Dict): Configuration for the multi-view attention transformer. + pred_head_config (Dict): Configuration for the prediction heads. 
+ pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None) + load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False) + specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: []) + torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. (default: False) + """ + super().__init__() + + # Initialize the attributes + self.name = name + self.encoder_config = encoder_config + self.info_sharing_config = info_sharing_config + self.pred_head_config = pred_head_config + self.geometric_input_config = geometric_input_config + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.load_specific_pretrained_submodules = load_specific_pretrained_submodules + self.specific_pretrained_submodules = specific_pretrained_submodules + self.torch_hub_force_reload = torch_hub_force_reload + self.class_init_args = { + "name": self.name, + "encoder_config": self.encoder_config, + "info_sharing_config": self.info_sharing_config, + "pred_head_config": self.pred_head_config, + "geometric_input_config": self.geometric_input_config, + "pretrained_checkpoint_path": self.pretrained_checkpoint_path, + "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules, + "specific_pretrained_submodules": self.specific_pretrained_submodules, + "torch_hub_force_reload": self.torch_hub_force_reload, + } + + # Get relevant parameters from the configs + self.info_sharing_type = info_sharing_config["model_type"] + self.info_sharing_return_type = info_sharing_config["model_return_type"] + self.pred_head_type = pred_head_config["type"] + + # Initialize image encoder + if self.encoder_config["uses_torch_hub"]: + self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload + # Create a copy of the config before deleting the key to preserve it for serialization + encoder_config_copy = self.encoder_config.copy() + del encoder_config_copy["uses_torch_hub"] + self.encoder = encoder_factory(**encoder_config_copy) + + # Initialize the encoder for ray directions + ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"] + ray_dirs_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + ray_dirs_encoder_config["patch_size"] = self.encoder.patch_size + self.ray_dirs_encoder = encoder_factory(**ray_dirs_encoder_config) + + # Initialize the encoder for depth (normalized per view and values after normalization are scaled logarithmically) + depth_encoder_config = self.geometric_input_config["depth_encoder_config"] + depth_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + depth_encoder_config["patch_size"] = self.encoder.patch_size + self.depth_encoder = encoder_factory(**depth_encoder_config) + + # Initialize the encoder for log scale factor of depth + depth_scale_encoder_config = self.geometric_input_config["scale_encoder_config"] + depth_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.depth_scale_encoder = encoder_factory(**depth_scale_encoder_config) + + # Initialize the encoder for camera rotation + cam_rot_encoder_config = self.geometric_input_config["cam_rot_encoder_config"] + cam_rot_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.cam_rot_encoder = encoder_factory(**cam_rot_encoder_config) + + # Initialize the encoder for camera translation (normalized across all provided camera translations) + 
cam_trans_encoder_config = self.geometric_input_config[ + "cam_trans_encoder_config" + ] + cam_trans_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.cam_trans_encoder = encoder_factory(**cam_trans_encoder_config) + + # Initialize the encoder for log scale factor of camera translation + cam_trans_scale_encoder_config = self.geometric_input_config[ + "scale_encoder_config" + ] + cam_trans_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.cam_trans_scale_encoder = encoder_factory(**cam_trans_scale_encoder_config) + + # Initialize the fusion norm layer + self.fusion_norm_layer = fusion_norm_layer(self.encoder.enc_embed_dim) + + # Initialize the info sharing module (Multi-View Transformer) + self._initialize_info_sharing(info_sharing_config) + + # Initialize the prediction heads + self._initialize_prediction_heads(pred_head_config) + + # Initialize the final adaptors + self._initialize_adaptors(pred_head_config) + + # Load pretrained weights + self._load_pretrained_weights() + + def _initialize_info_sharing(self, info_sharing_config): + """ + Initialize the information sharing module based on the configuration. + + This method sets up the custom positional encoding if specified and initializes + the appropriate multi-view transformer based on the configuration type. + + Args: + info_sharing_config (Dict): Configuration for the multi-view attention transformer. + Should contain 'custom_positional_encoding', 'model_type', and 'model_return_type'. + + Returns: + None + + Raises: + ValueError: If invalid configuration options are provided. + """ + # Initialize Custom Positional Encoding if required + custom_positional_encoding = info_sharing_config["custom_positional_encoding"] + if custom_positional_encoding is not None: + if isinstance(custom_positional_encoding, str): + print( + f"Using custom positional encoding for multi-view attention transformer: {custom_positional_encoding}" + ) + if custom_positional_encoding.startswith("RoPE"): + rope_freq = float(custom_positional_encoding[len("RoPE") :]) + print(f"RoPE frequency: {rope_freq}") + self.custom_positional_encoding = RoPE2D(freq=rope_freq) + else: + raise ValueError( + f"Invalid custom_positional_encoding: {custom_positional_encoding}." + ) + elif isinstance(custom_positional_encoding, Callable): + print( + "Using callable function as custom positional encoding for multi-view attention transformer." 
+ ) + self.custom_positional_encoding = custom_positional_encoding + else: + self.custom_positional_encoding = None + + # Add dependencies to info_sharing_config + info_sharing_config["module_args"]["input_embed_dim"] = ( + self.encoder.enc_embed_dim + ) + info_sharing_config["module_args"]["custom_positional_encoding"] = ( + self.custom_positional_encoding + ) + + # Initialize Multi-View Transformer + if self.info_sharing_return_type == "no_intermediate_features": + # Returns only normalized last layer features + # Initialize multi-view transformer based on type + if self.info_sharing_type == "cross_attention": + self.info_sharing = MultiViewCrossAttentionTransformer( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "global_attention": + self.info_sharing = MultiViewGlobalAttentionTransformer( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "alternating_attention": + self.info_sharing = MultiViewAlternatingAttentionTransformer( + **info_sharing_config["module_args"] + ) + else: + raise ValueError( + f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']" + ) + elif self.info_sharing_return_type == "intermediate_features": + # Returns intermediate features and normalized last layer features + # Initialize mulit-view transformer based on type + if self.info_sharing_type == "cross_attention": + self.info_sharing = MultiViewCrossAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "global_attention": + self.info_sharing = MultiViewGlobalAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "alternating_attention": + self.info_sharing = MultiViewAlternatingAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + else: + raise ValueError( + f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']" + ) + # Assess if the DPT needs to use encoder features + if len(self.info_sharing.indices) == 2: + self.use_encoder_features_for_dpt = True + elif len(self.info_sharing.indices) == 3: + self.use_encoder_features_for_dpt = False + else: + raise ValueError( + "Invalid number of indices provided for info sharing feature returner. Please provide 2 or 3 indices." + ) + else: + raise ValueError( + f"Invalid info_sharing_return_type: {self.info_sharing_return_type}. Valid options: ['no_intermediate_features', 'intermediate_features']" + ) + + def _initialize_prediction_heads(self, pred_head_config): + """ + Initialize the prediction heads based on the prediction head configuration. + + This method configures and initializes the appropriate prediction heads based on the + specified prediction head type (linear, DPT, or DPT+pose). It sets up the necessary + dependencies and creates the required model components. + + Args: + pred_head_config (Dict): Configuration for the prediction heads. + + Returns: + None + + Raises: + ValueError: If an invalid pred_head_type is provided. 
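+
+        Example:
+            Illustrative sketch of the keys this method reads for a "dpt+pose" head
+            (module-specific arguments inside each sub-dict are omitted; the adaptor
+            related keys are consumed later in _initialize_adaptors):
+
+                pred_head_config = {
+                    "type": "dpt+pose",
+                    "feature_head": {...},    # patch_size / input_feature_dims are filled in here
+                    "regressor_head": {...},  # input_feature_dim comes from feature_head["feature_dim"]
+                    "pose_head": {...},       # patch_size / input_feature_dim are filled in here
+                }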
+ """ + # Add dependencies to prediction head config + pred_head_config["feature_head"]["patch_size"] = self.encoder.patch_size + if self.pred_head_type == "linear": + pred_head_config["feature_head"]["input_feature_dim"] = ( + self.info_sharing.dim + ) + elif "dpt" in self.pred_head_type: + # Add dependencies for DPT & Regressor head + if self.use_encoder_features_for_dpt: + pred_head_config["feature_head"]["input_feature_dims"] = [ + self.encoder.enc_embed_dim + ] + [self.info_sharing.dim] * 3 + else: + pred_head_config["feature_head"]["input_feature_dims"] = [ + self.info_sharing.dim + ] * 4 + pred_head_config["regressor_head"]["input_feature_dim"] = pred_head_config[ + "feature_head" + ]["feature_dim"] + # Add dependencies for Pose head if required + if "pose" in self.pred_head_type: + pred_head_config["pose_head"]["patch_size"] = self.encoder.patch_size + pred_head_config["pose_head"]["input_feature_dim"] = ( + self.info_sharing.dim + ) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + # Initialize Prediction Heads + if self.pred_head_type == "linear": + # Initialize Dense Prediction Head for all views + self.dense_head = LinearFeature(**pred_head_config["feature_head"]) + elif "dpt" in self.pred_head_type: + # Initialize Dense Prediction Head for all views + self.dpt_feature_head = DPTFeature(**pred_head_config["feature_head"]) + self.dpt_regressor_head = DPTRegressionProcessor( + **pred_head_config["regressor_head"] + ) + self.dense_head = nn.Sequential( + self.dpt_feature_head, self.dpt_regressor_head + ) + # Initialize Pose Head for all views if required + if "pose" in self.pred_head_type: + self.pose_head = PoseHead(**pred_head_config["pose_head"]) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + def _initialize_adaptors(self, pred_head_config): + """ + Initialize the adaptors based on the prediction head configuration. + + This method sets up the appropriate adaptors for different scene representation types, + such as pointmaps, ray maps with depth, or ray directions with depth and pose. + + Args: + pred_head_config (Dict): Configuration for the prediction heads including adaptor type. + + Returns: + None + + Raises: + ValueError: If an invalid adaptor_type is provided. + AssertionError: If ray directions + depth + pose is used with an incompatible head type. 
+ """ + if pred_head_config["adaptor_type"] == "pointmap": + self.dense_adaptor = PointMapAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "pointmap" + elif pred_head_config["adaptor_type"] == "pointmap+confidence": + self.dense_adaptor = PointMapWithConfidenceAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "pointmap+confidence" + elif pred_head_config["adaptor_type"] == "pointmap+mask": + self.dense_adaptor = PointMapWithMaskAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "pointmap+mask" + elif pred_head_config["adaptor_type"] == "pointmap+confidence+mask": + self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "pointmap+confidence+mask" + elif pred_head_config["adaptor_type"] == "raymap+depth": + self.dense_adaptor = RayMapPlusDepthAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "raymap+depth" + elif pred_head_config["adaptor_type"] == "raymap+depth+confidence": + self.dense_adaptor = RayMapPlusDepthWithConfidenceAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "raymap+depth+confidence" + elif pred_head_config["adaptor_type"] == "raymap+depth+mask": + self.dense_adaptor = RayMapPlusDepthWithMaskAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "raymap+depth+mask" + elif pred_head_config["adaptor_type"] == "raymap+depth+confidence+mask": + self.dense_adaptor = RayMapPlusDepthWithConfidenceAndMaskAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "raymap+depth+confidence+mask" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = RayDirectionsPlusDepthAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose+confidence" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+mask": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = RayDirectionsPlusDepthWithMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose+mask" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence+mask": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." 
+ ) + self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose+confidence+mask" + elif pred_head_config["adaptor_type"] == "campointmap+pose": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapAdaptor(**pred_head_config["dpt_adaptor"]) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose" + elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapWithConfidenceAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose+confidence" + elif pred_head_config["adaptor_type"] == "campointmap+pose+mask": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapWithMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose+mask" + elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence+mask": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose+confidence+mask" + elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose": + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose" + elif ( + pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+confidence" + ): + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = ( + PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor( + **pred_head_config["dpt_adaptor"] + ) + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence" + elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+mask": + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." 
+ ) + self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose+mask" + elif ( + pred_head_config["adaptor_type"] + == "pointmap+raydirs+depth+pose+confidence+mask" + ): + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = ( + PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence+mask" + else: + raise ValueError( + f"Invalid adaptor_type: {pred_head_config['adaptor_type']}. \ + Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \ + 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \ + 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \ + 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']" + ) + + def _load_pretrained_weights(self): + """ + Load pretrained weights from a checkpoint file. + + If load_specific_pretrained_submodules is True, only loads weights for the specified submodules. + Otherwise, loads all weights from the checkpoint. + + Returns: + None + """ + if self.pretrained_checkpoint_path is not None: + if not self.load_specific_pretrained_submodules: + print( + f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + else: + print( + f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} for specific submodules: {self.specific_pretrained_submodules} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + filtered_ckpt = {} + for ckpt_key, ckpt_value in ckpt["model"].items(): + for submodule in self.specific_pretrained_submodules: + if ckpt_key.startswith(submodule): + filtered_ckpt[ckpt_key] = ckpt_value + print(self.load_state_dict(filtered_ckpt, strict=False)) + + def _encode_n_views(self, views): + """ + Encode all the input views (batch of images) in a single forward pass. + Assumes all the input views have the same image shape, batch size, and data normalization type. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + + Returns: + List[torch.Tensor]: A list containing the encoded features for all N views. 
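+
+        Example:
+            Illustrative shapes (channels-first patch features, as consumed by the fusion code;
+            e.g. a patch size of 14 on 518 x 518 images yields a 37 x 37 patch grid):
+
+                views[i]["img"]: (B, 3, H, W)
+                returned list: num_views tensors, each roughly (B, enc_embed_dim, H / patch_size, W / patch_size)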
+ """ + num_views = len(views) + data_norm_type = views[0]["data_norm_type"][0] + imgs_list = [view["img"] for view in views] + all_imgs_across_views = torch.cat(imgs_list, dim=0) + encoder_input = ViTEncoderInput( + image=all_imgs_across_views, data_norm_type=data_norm_type + ) + encoder_output = self.encoder(encoder_input) + all_encoder_features_across_views = encoder_output.features.chunk( + num_views, dim=0 + ) + + return all_encoder_features_across_views + + def _compute_pose_quats_and_trans_for_across_views_in_ref_view( + self, + views, + num_views, + device, + dtype, + batch_size_per_view, + per_sample_cam_input_mask, + ): + """ + Compute the pose quats and trans for all the views in the frame of the reference view 0. + Returns identity pose for views where the camera input mask is False or the pose is not provided. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + device (torch.device): Device to use for the computation. + dtype (torch.dtype): Data type to use for the computation. + per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask. + + Returns: + torch.Tensor: A tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4) + torch.Tensor: A tensor containing the pose trans for all the views in the frame of the reference view 0. (batch_size_per_view * view, 3) + torch.Tensor: A tensor containing the per sample camera input mask. + """ + # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0 + pose_quats_non_ref_views = [] + pose_trans_non_ref_views = [] + pose_quats_ref_view_0 = [] + pose_trans_ref_view_0 = [] + for view_idx in range(num_views): + per_sample_cam_input_mask_for_curr_view = per_sample_cam_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view + ] + if ( + "camera_pose_quats" in views[view_idx] + and "camera_pose_trans" in views[view_idx] + and per_sample_cam_input_mask_for_curr_view.any() + ): + # Get the camera pose quats and trans for the current view + cam_pose_quats = views[view_idx]["camera_pose_quats"][ + per_sample_cam_input_mask_for_curr_view + ] + cam_pose_trans = views[view_idx]["camera_pose_trans"][ + per_sample_cam_input_mask_for_curr_view + ] + # Append to the list + pose_quats_non_ref_views.append(cam_pose_quats) + pose_trans_non_ref_views.append(cam_pose_trans) + # Get the camera pose quats and trans for the reference view 0 + cam_pose_quats = views[0]["camera_pose_quats"][ + per_sample_cam_input_mask_for_curr_view + ] + cam_pose_trans = views[0]["camera_pose_trans"][ + per_sample_cam_input_mask_for_curr_view + ] + # Append to the list + pose_quats_ref_view_0.append(cam_pose_quats) + pose_trans_ref_view_0.append(cam_pose_trans) + else: + per_sample_cam_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] = False + + # Initialize the pose quats and trans for all views as identity + pose_quats_across_views = torch.tensor( + [0.0, 0.0, 0.0, 1.0], dtype=dtype, device=device + ).repeat(batch_size_per_view * num_views, 1) # (q_x, q_y, q_z, q_w) + pose_trans_across_views = torch.zeros( + (batch_size_per_view * num_views, 3), dtype=dtype, device=device + ) + + # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0 + if len(pose_quats_non_ref_views) > 0: + # Stack the pose quats and trans for all the 
non-reference views and reference view 0 + pose_quats_non_ref_views = torch.cat(pose_quats_non_ref_views, dim=0) + pose_trans_non_ref_views = torch.cat(pose_trans_non_ref_views, dim=0) + pose_quats_ref_view_0 = torch.cat(pose_quats_ref_view_0, dim=0) + pose_trans_ref_view_0 = torch.cat(pose_trans_ref_view_0, dim=0) + + # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0 + ( + pose_quats_non_ref_views_in_ref_view_0, + pose_trans_non_ref_views_in_ref_view_0, + ) = transform_pose_using_quats_and_trans_2_to_1( + pose_quats_ref_view_0, + pose_trans_ref_view_0, + pose_quats_non_ref_views, + pose_trans_non_ref_views, + ) + + # Update the pose quats and trans for all the non-reference views + pose_quats_across_views[per_sample_cam_input_mask] = ( + pose_quats_non_ref_views_in_ref_view_0.to(dtype=dtype) + ) + pose_trans_across_views[per_sample_cam_input_mask] = ( + pose_trans_non_ref_views_in_ref_view_0.to(dtype=dtype) + ) + + return ( + pose_quats_across_views, + pose_trans_across_views, + per_sample_cam_input_mask, + ) + + def _encode_and_fuse_ray_dirs( + self, + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_ray_dirs_input_mask, + ): + """ + Encode the ray directions for all the views and fuse it with the other encoder features in a single forward pass. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + batch_size_per_view (int): Batch size per view. + all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views. + per_sample_ray_dirs_input_mask (torch.Tensor): Tensor containing the per sample ray direction input mask. + + Returns: + torch.Tensor: A tensor containing the encoded features for all the views. 
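+
+        Example:
+            Illustrative flow (shapes only, mirroring the implementation below):
+
+                ray_dirs: (B * V, 3, H, W), stacked in view-major order (index v * B + b)
+                -> ray_dirs_encoder -> patch features of the same shape as the image encoder features
+                -> zeroed wherever per_sample_ray_dirs_input_mask is False
+                -> added to all_encoder_features_across_views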
+ """ + # Get the height and width of the images + _, _, height, width = views[0]["img"].shape + + # Get the ray directions for all the views where info is provided and the ray direction input mask is True + ray_dirs_list = [] + for view_idx in range(num_views): + per_sample_ray_dirs_input_mask_for_curr_view = ( + per_sample_ray_dirs_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] + ) + ray_dirs_for_curr_view = torch.zeros( + (batch_size_per_view, height, width, 3), + dtype=all_encoder_features_across_views.dtype, + device=all_encoder_features_across_views.device, + ) + if ( + "ray_directions_cam" in views[view_idx] + and per_sample_ray_dirs_input_mask_for_curr_view.any() + ): + ray_dirs_for_curr_view[per_sample_ray_dirs_input_mask_for_curr_view] = ( + views[view_idx]["ray_directions_cam"][ + per_sample_ray_dirs_input_mask_for_curr_view + ] + ) + else: + per_sample_ray_dirs_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] = False + ray_dirs_list.append(ray_dirs_for_curr_view) + + # Stack the ray directions for all the views and permute to (B * V, C, H, W) + ray_dirs = torch.cat(ray_dirs_list, dim=0) # (B * V, H, W, 3) + ray_dirs = ray_dirs.permute(0, 3, 1, 2).contiguous() # (B * V, 3, H, W) + + # Encode the ray directions + ray_dirs_features_across_views = self.ray_dirs_encoder( + ViTEncoderNonImageInput(data=ray_dirs) + ).features + + # Fuse the ray direction features with the other encoder features (zero out the features where the ray direction input mask is False) + ray_dirs_features_across_views = ( + ray_dirs_features_across_views + * per_sample_ray_dirs_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + ) + all_encoder_features_across_views = ( + all_encoder_features_across_views + ray_dirs_features_across_views + ) + + return all_encoder_features_across_views + + def _encode_and_fuse_depths( + self, + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_depth_input_mask, + ): + """ + Encode the z depths for all the views and fuse it with the other encoder features in a single forward pass. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + batch_size_per_view (int): Batch size per view. + all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views. + per_sample_depth_input_mask (torch.Tensor): Tensor containing the per sample depth input mask. + + Returns: + torch.Tensor: A tensor containing the encoded features for all the views. 
+ """ + # Get the device and height and width of the images + device = all_encoder_features_across_views.device + _, _, height, width = views[0]["img"].shape + + # Decide to use randomly sampled sparse depth or dense depth + if torch.rand(1) < self.geometric_input_config["sparse_depth_prob"]: + use_sparse_depth = True + else: + use_sparse_depth = False + + # Get the depths for all the views + depth_list = [] + depth_norm_factors_list = [] + metric_scale_depth_mask_list = [] + for view_idx in range(num_views): + # Get the input mask for current view + per_sample_depth_input_mask_for_curr_view = per_sample_depth_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view + ] + depth_for_curr_view = torch.zeros( + (batch_size_per_view, height, width, 1), + dtype=all_encoder_features_across_views.dtype, + device=device, + ) + depth_norm_factor_for_curr_view = torch.zeros( + (batch_size_per_view), + dtype=all_encoder_features_across_views.dtype, + device=device, + ) + metric_scale_mask_for_curr_view = torch.zeros( + (batch_size_per_view), + dtype=torch.bool, + device=device, + ) + if ( + "depth_along_ray" in views[view_idx] + ) and per_sample_depth_input_mask_for_curr_view.any(): + # Get depth for current view + depth_for_curr_view_input = views[view_idx]["depth_along_ray"][ + per_sample_depth_input_mask_for_curr_view + ] + # Get the metric scale mask + if "is_metric_scale" in views[view_idx]: + metric_scale_mask = views[view_idx]["is_metric_scale"][ + per_sample_depth_input_mask_for_curr_view + ] + else: + metric_scale_mask = torch.zeros( + depth_for_curr_view_input.shape[0], + dtype=torch.bool, + device=device, + ) + # Turn off indication of metric scale samples based on the depth_scale_norm_all_prob + depth_scale_norm_all_mask = ( + torch.rand(metric_scale_mask.shape[0]) + < self.geometric_input_config["depth_scale_norm_all_prob"] + ) + if depth_scale_norm_all_mask.any(): + metric_scale_mask[depth_scale_norm_all_mask] = False + # Assign the metric scale mask to the respective indices + metric_scale_mask_for_curr_view[ + per_sample_depth_input_mask_for_curr_view + ] = metric_scale_mask + # Sparsely sample the depth if required + if use_sparse_depth: + # Create a mask of ones + sparsification_mask = torch.ones_like( + depth_for_curr_view_input, device=device + ) + # Create a mask for valid pixels (depth > 0) + valid_pixel_mask = depth_for_curr_view_input > 0 + # Calculate the number of valid pixels + num_valid_pixels = valid_pixel_mask.sum().item() + # Calculate the number of valid pixels to set to zero + num_to_zero = int( + num_valid_pixels + * self.geometric_input_config["sparsification_removal_percent"] + ) + if num_to_zero > 0: + # Get the indices of valid pixels + valid_indices = valid_pixel_mask.nonzero(as_tuple=True) + # Randomly select indices to zero out + indices_to_zero = torch.randperm(num_valid_pixels)[:num_to_zero] + # Set selected valid indices to zero in the mask + sparsification_mask[ + valid_indices[0][indices_to_zero], + valid_indices[1][indices_to_zero], + valid_indices[2][indices_to_zero], + valid_indices[3][indices_to_zero], + ] = 0 + # Apply the mask on the depth + depth_for_curr_view_input = ( + depth_for_curr_view_input * sparsification_mask + ) + # Normalize the depth + scaled_depth_for_curr_view_input, depth_norm_factor = ( + normalize_depth_using_non_zero_pixels( + depth_for_curr_view_input, return_norm_factor=True + ) + ) + # Assign the depth and depth norm factor to the respective indices + 
depth_for_curr_view[per_sample_depth_input_mask_for_curr_view] = ( + scaled_depth_for_curr_view_input + ) + depth_norm_factor_for_curr_view[ + per_sample_depth_input_mask_for_curr_view + ] = depth_norm_factor + else: + per_sample_depth_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] = False + # Append the depths, depth norm factor and metric scale mask for the current view + depth_list.append(depth_for_curr_view) + depth_norm_factors_list.append(depth_norm_factor_for_curr_view) + metric_scale_depth_mask_list.append(metric_scale_mask_for_curr_view) + + # Stack the depths for all the views and permute to (B * V, C, H, W) + depths = torch.cat(depth_list, dim=0) # (B * V, H, W, 1) + depths = apply_log_to_norm( + depths + ) # Scale logarithimically (norm is computed along last dim) + depths = depths.permute(0, 3, 1, 2).contiguous() # (B * V, 1, H, W) + # Encode the depths using the depth encoder + depth_features_across_views = self.depth_encoder( + ViTEncoderNonImageInput(data=depths) + ).features + # Zero out the depth features where the depth input mask is False + depth_features_across_views = ( + depth_features_across_views + * per_sample_depth_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + ) + + # Stack the depth norm factors for all the views + depth_norm_factors = torch.cat(depth_norm_factors_list, dim=0) # (B * V, ) + # Encode the depth norm factors using the log scale encoder for depth + log_depth_norm_factors = torch.log(depth_norm_factors + 1e-8) # (B * V, ) + depth_scale_features_across_views = self.depth_scale_encoder( + EncoderGlobalRepInput(data=log_depth_norm_factors.unsqueeze(-1)) + ).features + # Zero out the depth scale features where the depth input mask is False + depth_scale_features_across_views = ( + depth_scale_features_across_views + * per_sample_depth_input_mask.unsqueeze(-1) + ) + # Stack the metric scale mask for all the views + metric_scale_depth_mask = torch.cat( + metric_scale_depth_mask_list, dim=0 + ) # (B * V, ) + # Zero out the depth scale features where the metric scale mask is False + # Scale encoding is only provided for metric scale samples + depth_scale_features_across_views = ( + depth_scale_features_across_views * metric_scale_depth_mask.unsqueeze(-1) + ) + + # Fuse the depth features & depth scale features with the other encoder features + all_encoder_features_across_views = ( + all_encoder_features_across_views + + depth_features_across_views + + depth_scale_features_across_views.unsqueeze(-1).unsqueeze(-1) + ) + + return all_encoder_features_across_views + + def _encode_and_fuse_cam_quats_and_trans( + self, + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + pose_quats_across_views, + pose_trans_across_views, + per_sample_cam_input_mask, + ): + """ + Encode the camera quats and trans for all the views and fuse it with the other encoder features in a single forward pass. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + batch_size_per_view (int): Batch size per view. + all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views. + pose_quats_across_views (torch.Tensor): Tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4) + pose_trans_across_views (torch.Tensor): Tensor containing the pose trans for all the views in the frame of the reference view 0. 
(batch_size_per_view * view, 3) + per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask. + + Returns: + torch.Tensor: A tensor containing the encoded features for all the views. + """ + # Encode the pose quats + pose_quats_features_across_views = self.cam_rot_encoder( + EncoderGlobalRepInput(data=pose_quats_across_views) + ).features + # Zero out the pose quat features where the camera input mask is False + pose_quats_features_across_views = ( + pose_quats_features_across_views * per_sample_cam_input_mask.unsqueeze(-1) + ) + + # Get the metric scale mask for all samples + device = all_encoder_features_across_views.device + metric_scale_pose_trans_mask = torch.zeros( + (batch_size_per_view * num_views), dtype=torch.bool, device=device + ) + for view_idx in range(num_views): + if "is_metric_scale" in views[view_idx]: + # Get the metric scale mask for the input pose priors + metric_scale_mask = views[view_idx]["is_metric_scale"] + else: + metric_scale_mask = torch.zeros( + batch_size_per_view, dtype=torch.bool, device=device + ) + metric_scale_pose_trans_mask[ + view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view + ] = metric_scale_mask + + # Turn off indication of metric scale samples based on the pose_scale_norm_all_prob + pose_norm_all_mask = ( + torch.rand(batch_size_per_view * num_views) + < self.geometric_input_config["pose_scale_norm_all_prob"] + ) + if pose_norm_all_mask.any(): + metric_scale_pose_trans_mask[pose_norm_all_mask] = False + + # Get the scale norm factor for all the samples and scale the pose translations + pose_trans_across_views = torch.split( + pose_trans_across_views, batch_size_per_view, dim=0 + ) # Split into num_views chunks + pose_trans_across_views = torch.stack( + pose_trans_across_views, dim=1 + ) # Stack the views along a new dimension (batch_size_per_view, num_views, 3) + scaled_pose_trans_across_views, pose_trans_norm_factors = ( + normalize_pose_translations( + pose_trans_across_views, return_norm_factor=True + ) + ) + + # Resize the pose translation back to (batch_size_per_view * num_views, 3) and extend the norm factor to (batch_size_per_view * num_views, 1) + scaled_pose_trans_across_views = scaled_pose_trans_across_views.unbind( + dim=1 + ) # Convert back to list of views, where each view has batch_size_per_view tensor + scaled_pose_trans_across_views = torch.cat( + scaled_pose_trans_across_views, dim=0 + ) # Concatenate back to (batch_size_per_view * num_views, 3) + pose_trans_norm_factors_across_views = pose_trans_norm_factors.unsqueeze( + -1 + ).repeat(num_views, 1) # (B, ) -> (B * V, 1) + + # Encode the pose trans + pose_trans_features_across_views = self.cam_trans_encoder( + EncoderGlobalRepInput(data=scaled_pose_trans_across_views) + ).features + # Zero out the pose trans features where the camera input mask is False + pose_trans_features_across_views = ( + pose_trans_features_across_views * per_sample_cam_input_mask.unsqueeze(-1) + ) + + # Encode the pose translation norm factors using the log scale encoder for pose trans + log_pose_trans_norm_factors_across_views = torch.log( + pose_trans_norm_factors_across_views + 1e-8 + ) + pose_trans_scale_features_across_views = self.cam_trans_scale_encoder( + EncoderGlobalRepInput(data=log_pose_trans_norm_factors_across_views) + ).features + # Zero out the pose trans scale features where the camera input mask is False + pose_trans_scale_features_across_views = ( + pose_trans_scale_features_across_views + * per_sample_cam_input_mask.unsqueeze(-1) + ) 
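+
+        # Note: pose_quats_features_across_views, pose_trans_features_across_views and
+        # pose_trans_scale_features_across_views are per-sample global feature vectors
+        # (typically of shape (batch_size_per_view * num_views, enc_embed_dim)); they are
+        # broadcast over the patch grid when fused with the dense encoder features below.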
+ # Zero out the pose trans scale features where the metric scale mask is False + # Scale encoding is only provided for metric scale samples + pose_trans_scale_features_across_views = ( + pose_trans_scale_features_across_views + * metric_scale_pose_trans_mask.unsqueeze(-1) + ) + + # Fuse the pose quat features, pose trans features, pose trans scale features and pose trans type PE features with the other encoder features + all_encoder_features_across_views = ( + all_encoder_features_across_views + + pose_quats_features_across_views.unsqueeze(-1).unsqueeze(-1) + + pose_trans_features_across_views.unsqueeze(-1).unsqueeze(-1) + + pose_trans_scale_features_across_views.unsqueeze(-1).unsqueeze(-1) + ) + + return all_encoder_features_across_views + + def _encode_and_fuse_optional_geometric_inputs( + self, views, all_encoder_features_across_views_list + ): + """ + Encode all the input optional geometric modalities and fuses it with the image encoder features in a single forward pass. + Assumes all the input views have the same shape and batch size. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + all_encoder_features_across_views (List[torch.Tensor]): List of tensors containing the encoded image features for all N views. + + Returns: + List[torch.Tensor]: A list containing the encoded features for all N views. + """ + num_views = len(views) + batch_size_per_view, _, _, _ = views[0]["img"].shape + device = all_encoder_features_across_views_list[0].device + dtype = all_encoder_features_across_views_list[0].dtype + all_encoder_features_across_views = torch.cat( + all_encoder_features_across_views_list, dim=0 + ) + + # Get the overall input mask for all the views + overall_geometric_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["overall_prob"] + ) + overall_geometric_input_mask = overall_geometric_input_mask.repeat(num_views) + + # Get the per sample input mask after dropout + # Per sample input mask is in view-major order so that index v*B + b in each mask corresponds to sample b of view v: (B * V) + per_sample_geometric_input_mask = torch.rand( + batch_size_per_view * num_views, device=device + ) < (1 - self.geometric_input_config["dropout_prob"]) + per_sample_geometric_input_mask = ( + per_sample_geometric_input_mask & overall_geometric_input_mask + ) + + # Get the ray direction input mask + per_sample_ray_dirs_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["ray_dirs_prob"] + ) + per_sample_ray_dirs_input_mask = per_sample_ray_dirs_input_mask.repeat( + num_views + ) + per_sample_ray_dirs_input_mask = ( + per_sample_ray_dirs_input_mask & per_sample_geometric_input_mask + ) + + # Get the depth input mask + per_sample_depth_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["depth_prob"] + ) + per_sample_depth_input_mask = per_sample_depth_input_mask.repeat(num_views) + per_sample_depth_input_mask = ( + per_sample_depth_input_mask & per_sample_geometric_input_mask + ) + + # Get the camera input mask + per_sample_cam_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["cam_prob"] + ) + per_sample_cam_input_mask = per_sample_cam_input_mask.repeat(num_views) + per_sample_cam_input_mask = ( + per_sample_cam_input_mask & per_sample_geometric_input_mask + ) + + # Compute the pose quats and trans for all the non-reference views in the frame of the reference 
view 0 + # Returned pose quats and trans represent identity pose for views/samples where the camera input mask is False + pose_quats_across_views, pose_trans_across_views, per_sample_cam_input_mask = ( + self._compute_pose_quats_and_trans_for_across_views_in_ref_view( + views, + num_views, + device, + dtype, + batch_size_per_view, + per_sample_cam_input_mask, + ) + ) + + # Encode the ray directions and fuse with the image encoder features + all_encoder_features_across_views = self._encode_and_fuse_ray_dirs( + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_ray_dirs_input_mask, + ) + + # Encode the depths and fuse with the image encoder features + all_encoder_features_across_views = self._encode_and_fuse_depths( + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_depth_input_mask, + ) + + # Encode the cam quat and trans and fuse with the image encoder features + all_encoder_features_across_views = self._encode_and_fuse_cam_quats_and_trans( + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + pose_quats_across_views, + pose_trans_across_views, + per_sample_cam_input_mask, + ) + + # Normalize the fused features (permute -> normalize -> permute) + all_encoder_features_across_views = all_encoder_features_across_views.permute( + 0, 2, 3, 1 + ).contiguous() + all_encoder_features_across_views = self.fusion_norm_layer( + all_encoder_features_across_views + ) + all_encoder_features_across_views = all_encoder_features_across_views.permute( + 0, 3, 1, 2 + ).contiguous() + + # Split the batched views into individual views + fused_all_encoder_features_across_views = ( + all_encoder_features_across_views.chunk(num_views, dim=0) + ) + + return fused_all_encoder_features_across_views + + def forward(self, views): + """ + Forward pass performing the following operations: + 1. Encodes the N input views (images). + 2. Encodes the optional geometric inputs (ray directions, depths, camera rotations, camera translations). + 3. Fuses the encoded features from the N input views and the optional geometric inputs using addition and normalization. + 4. Information sharing between the encoded features using a multi-view attention transformer. + 5. Passes the final features through the prediction heads. + 6. Returns the processed final outputs for N views. + + Assumption: + - All the input views have the same image shape. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). Input images must be normalized based on the data norm type of image encoder. + "data_norm_type" (list): [model.encoder.data_norm_type] + + Returns: + List[dict]: A list containing the final outputs for all N views. 
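+
+        Example:
+            Minimal illustrative call (image size and normalization must match the encoder config;
+            the returned keys beyond "pts3d" depend on the configured adaptor / scene_rep_type):
+
+                views = [
+                    {"img": img_0, "data_norm_type": [model.encoder.data_norm_type]},
+                    {"img": img_1, "data_norm_type": [model.encoder.data_norm_type]},
+                ]
+                preds = model(views)
+                pts3d_view_0 = preds[0]["pts3d"]  # (B, H, W, 3)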
+ """ + # Get input shape of the images, number of views, and batch size per view + batch_size_per_view, _, height, width = views[0]["img"].shape + img_shape = (int(height), int(width)) + num_views = len(views) + + # Run the encoder on all the input views + all_encoder_features_across_views = self._encode_n_views(views) + + # Encode the optional geometric inputs and fuse with the encoded features from the N input views + # Use high precision to prevent NaN values after layer norm in dense representation encoder (due to high variance in last dim of features) + with torch.autocast("cuda", enabled=False): + all_encoder_features_across_views = ( + self._encode_and_fuse_optional_geometric_inputs( + views, all_encoder_features_across_views + ) + ) + + # Combine all images into view-centric representation + info_sharing_input = MultiViewTransformerInput( + features=all_encoder_features_across_views + ) + if self.info_sharing_return_type == "no_intermediate_features": + final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input) + elif self.info_sharing_return_type == "intermediate_features": + ( + final_info_sharing_multi_view_feat, + intermediate_info_sharing_multi_view_feat, + ) = self.info_sharing(info_sharing_input) + + if self.pred_head_type == "linear": + # Stack the features for all views + dense_head_inputs = torch.cat( + final_info_sharing_multi_view_feat.features, dim=0 + ) + elif self.pred_head_type in ["dpt", "dpt+pose"]: + # Get the list of features for all views + dense_head_inputs_list = [] + if self.use_encoder_features_for_dpt: + # Stack all the image encoder features for all views + stacked_encoder_features = torch.cat( + all_encoder_features_across_views, dim=0 + ) + dense_head_inputs_list.append(stacked_encoder_features) + # Stack the first intermediate features for all views + stacked_intermediate_features_1 = torch.cat( + intermediate_info_sharing_multi_view_feat[0].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_1) + # Stack the second intermediate features for all views + stacked_intermediate_features_2 = torch.cat( + intermediate_info_sharing_multi_view_feat[1].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_2) + # Stack the last layer features for all views + stacked_final_features = torch.cat( + final_info_sharing_multi_view_feat.features, dim=0 + ) + dense_head_inputs_list.append(stacked_final_features) + else: + # Stack the first intermediate features for all views + stacked_intermediate_features_1 = torch.cat( + intermediate_info_sharing_multi_view_feat[0].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_1) + # Stack the second intermediate features for all views + stacked_intermediate_features_2 = torch.cat( + intermediate_info_sharing_multi_view_feat[1].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_2) + # Stack the third intermediate features for all views + stacked_intermediate_features_3 = torch.cat( + intermediate_info_sharing_multi_view_feat[2].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_3) + # Stack the last layer + stacked_final_features = torch.cat( + final_info_sharing_multi_view_feat.features, dim=0 + ) + dense_head_inputs_list.append(stacked_final_features) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. 
Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + # Downstream task prediction + with torch.autocast("cuda", enabled=False): + # Run Prediction Heads & Post-Process Outputs + if self.pred_head_type == "linear": + dense_head_outputs = self.dense_head( + PredictionHeadInput(last_feature=dense_head_inputs) + ) + dense_final_outputs = self.dense_adaptor( + AdaptorInput( + adaptor_feature=dense_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + elif self.pred_head_type == "dpt": + dense_head_outputs = self.dense_head( + PredictionHeadLayeredInput( + list_features=dense_head_inputs_list, + target_output_shape=img_shape, + ) + ) + dense_final_outputs = self.dense_adaptor( + AdaptorInput( + adaptor_feature=dense_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + elif self.pred_head_type == "dpt+pose": + dense_head_outputs = self.dense_head( + PredictionHeadLayeredInput( + list_features=dense_head_inputs_list, + target_output_shape=img_shape, + ) + ) + dense_final_outputs = self.dense_adaptor( + AdaptorInput( + adaptor_feature=dense_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + pose_head_outputs = self.pose_head( + PredictionHeadInput(last_feature=dense_head_inputs_list[-1]) + ) + pose_final_outputs = self.pose_adaptor( + AdaptorInput( + adaptor_feature=pose_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + # Prepare the final scene representation for all views + if self.scene_rep_type in [ + "pointmap", + "pointmap+confidence", + "pointmap+mask", + "pointmap+confidence+mask", + ]: + output_pts3d = dense_final_outputs.value + # Reshape final scene representation to (B * V, H, W, C) + output_pts3d = output_pts3d.permute(0, 2, 3, 1).contiguous() + # Split the predicted pointmaps back to their respective views + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append({"pts3d": output_pts3d_per_view[i]}) + elif self.scene_rep_type in [ + "raymap+depth", + "raymap+depth+confidence", + "raymap+depth+mask", + "raymap+depth+confidence+mask", + ]: + # Reshape final scene representation to (B * V, H, W, C) + output_scene_rep = dense_final_outputs.value.permute( + 0, 2, 3, 1 + ).contiguous() + # Get the predicted ray origins, directions, and depths along rays + output_ray_origins, output_ray_directions, output_depth_along_ray = ( + output_scene_rep.split([3, 3, 1], dim=-1) + ) + # Get the predicted pointmaps + output_pts3d = ( + output_ray_origins + output_ray_directions * output_depth_along_ray + ) + # Split the predicted quantities back to their respective views + output_ray_origins_per_view = output_ray_origins.chunk(num_views, dim=0) + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i], + "ray_origins": output_ray_origins_per_view[i], + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i], + } + ) + elif self.scene_rep_type in [ + "raydirs+depth+pose", + "raydirs+depth+pose+confidence", + "raydirs+depth+pose+mask", 
+ "raydirs+depth+pose+confidence+mask", + ]: + # Reshape output dense rep to (B * V, H, W, C) + output_dense_rep = dense_final_outputs.value.permute( + 0, 2, 3, 1 + ).contiguous() + # Get the predicted ray directions and depths along rays + output_ray_directions, output_depth_along_ray = output_dense_rep.split( + [3, 1], dim=-1 + ) + # Get the predicted camera translations and quaternions + output_cam_translations, output_cam_quats = ( + pose_final_outputs.value.split([3, 4], dim=-1) + ) + # Get the predicted pointmaps in world frame and camera frame + output_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + output_ray_directions, + output_depth_along_ray, + output_cam_translations, + output_cam_quats, + ) + ) + output_pts3d_cam = output_ray_directions * output_depth_along_ray + # Split the predicted quantities back to their respective views + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_cam_translations_per_view = output_cam_translations.chunk( + num_views, dim=0 + ) + output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i], + "pts3d_cam": output_pts3d_cam_per_view[i], + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i], + "cam_trans": output_cam_translations_per_view[i], + "cam_quats": output_cam_quats_per_view[i], + } + ) + elif self.scene_rep_type in [ + "campointmap+pose", + "campointmap+pose+confidence", + "campointmap+pose+mask", + "campointmap+pose+confidence+mask", + ]: + # Get the predicted camera frame pointmaps + output_pts3d_cam = dense_final_outputs.value + # Reshape final scene representation to (B * V, H, W, C) + output_pts3d_cam = output_pts3d_cam.permute(0, 2, 3, 1).contiguous() + # Get the predicted camera translations and quaternions + output_cam_translations, output_cam_quats = ( + pose_final_outputs.value.split([3, 4], dim=-1) + ) + # Get the ray directions and depths along rays + output_depth_along_ray = torch.norm( + output_pts3d_cam, dim=-1, keepdim=True + ) + output_ray_directions = output_pts3d_cam / output_depth_along_ray + # Get the predicted pointmaps in world frame + output_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + output_ray_directions, + output_depth_along_ray, + output_cam_translations, + output_cam_quats, + ) + ) + # Split the predicted quantities back to their respective views + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_cam_translations_per_view = output_cam_translations.chunk( + num_views, dim=0 + ) + output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i], + "pts3d_cam": output_pts3d_cam_per_view[i], + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": 
output_depth_along_ray_per_view[i], + "cam_trans": output_cam_translations_per_view[i], + "cam_quats": output_cam_quats_per_view[i], + } + ) + elif self.scene_rep_type in [ + "pointmap+raydirs+depth+pose", + "pointmap+raydirs+depth+pose+confidence", + "pointmap+raydirs+depth+pose+mask", + "pointmap+raydirs+depth+pose+confidence+mask", + ]: + # Reshape final scene representation to (B * V, H, W, C) + output_dense_rep = dense_final_outputs.value.permute( + 0, 2, 3, 1 + ).contiguous() + # Get the predicted pointmaps, ray directions and depths along rays + output_pts3d, output_ray_directions, output_depth_along_ray = ( + output_dense_rep.split([3, 3, 1], dim=-1) + ) + # Get the predicted camera translations and quaternions + output_cam_translations, output_cam_quats = ( + pose_final_outputs.value.split([3, 4], dim=-1) + ) + # Get the predicted pointmaps in camera frame + output_pts3d_cam = output_ray_directions * output_depth_along_ray + # Replace the predicted world-frame pointmaps if required + if self.pred_head_config["adaptor_config"][ + "use_factored_predictions_for_global_pointmaps" + ]: + output_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + output_ray_directions, + output_depth_along_ray, + output_cam_translations, + output_cam_quats, + ) + ) + # Split the predicted quantities back to their respective views + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_cam_translations_per_view = output_cam_translations.chunk( + num_views, dim=0 + ) + output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i], + "pts3d_cam": output_pts3d_cam_per_view[i], + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i], + "cam_trans": output_cam_translations_per_view[i], + "cam_quats": output_cam_quats_per_view[i], + } + ) + else: + raise ValueError( + f"Invalid scene_rep_type: {self.scene_rep_type}. 
\ + Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \ + 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \ + 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \ + 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']" + ) + + # Get the output confidences for all views (if available) and add them to the result + if "confidence" in self.scene_rep_type: + output_confidences = dense_final_outputs.confidence + # Reshape confidences to (B * V, H, W) + output_confidences = ( + output_confidences.permute(0, 2, 3, 1).squeeze(-1).contiguous() + ) + # Split the predicted confidences back to their respective views + output_confidences_per_view = output_confidences.chunk(num_views, dim=0) + # Add the confidences to the result + for i in range(num_views): + res[i]["conf"] = output_confidences_per_view[i] + + # Get the output masks (and logits) for all views (if available) and add them to the result + if "mask" in self.scene_rep_type: + # Get the output masks + output_masks = dense_final_outputs.mask + # Reshape masks to (B * V, H, W) + output_masks = output_masks.permute(0, 2, 3, 1).squeeze(-1).contiguous() + # Threshold the masks at 0.5 to get binary masks (0: ambiguous/invalid, 1: non-ambiguous/valid) + output_masks = output_masks > 0.5 + # Split the predicted masks back to their respective views + output_masks_per_view = output_masks.chunk(num_views, dim=0) + # Get the output mask logits (for loss) + output_mask_logits = dense_final_outputs.logits + # Reshape mask logits to (B * V, H, W) + output_mask_logits = ( + output_mask_logits.permute(0, 2, 3, 1).squeeze(-1).contiguous() + ) + # Split the predicted mask logits back to their respective views + output_mask_logits_per_view = output_mask_logits.chunk(num_views, dim=0) + # Add the masks and logits to the result + for i in range(num_views): + res[i]["non_ambiguous_mask"] = output_masks_per_view[i] + res[i]["non_ambiguous_mask_logits"] = output_mask_logits_per_view[i] + + return res diff --git a/mapanything/models/mapanything/model.py b/mapanything/models/mapanything/model.py new file mode 100644 index 0000000000000000000000000000000000000000..696038ec50f1b107c3ebb4b91cba26662514cffe --- /dev/null +++ b/mapanything/models/mapanything/model.py @@ -0,0 +1,2112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +MapAnything model class defined using UniCeption modules. 
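+
+The model composes UniCeption encoders, multi-view attention transformers for
+information sharing, and dense / pose prediction heads with output adaptors.
+See the MapAnything class for the expected configuration dictionaries and
+MapAnything.forward for the input/output format.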
+""" + +import warnings +from functools import partial +from typing import Any, Callable, Dict, List, Tuple, Type, Union + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin + +from mapanything.utils.geometry import ( + apply_log_to_norm, + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + normalize_depth_using_non_zero_pixels, + normalize_pose_translations, + transform_pose_using_quats_and_trans_2_to_1, +) +from mapanything.utils.inference import ( + postprocess_model_outputs_for_inference, + preprocess_input_views_for_inference, + validate_input_views_for_inference, +) +from uniception.models.encoders import ( + encoder_factory, + EncoderGlobalRepInput, + ViTEncoderInput, + ViTEncoderNonImageInput, +) +from uniception.models.info_sharing.alternating_attention_transformer import ( + MultiViewAlternatingAttentionTransformer, + MultiViewAlternatingAttentionTransformerIFR, +) +from uniception.models.info_sharing.base import MultiViewTransformerInput +from uniception.models.info_sharing.cross_attention_transformer import ( + MultiViewCrossAttentionTransformer, + MultiViewCrossAttentionTransformerIFR, +) +from uniception.models.info_sharing.global_attention_transformer import ( + MultiViewGlobalAttentionTransformer, + MultiViewGlobalAttentionTransformerIFR, +) +from uniception.models.prediction_heads.adaptors import ( + CamTranslationPlusQuatsAdaptor, + PointMapAdaptor, + PointMapPlusRayDirectionsPlusDepthAdaptor, + PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor, + PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor, + PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor, + PointMapWithConfidenceAdaptor, + PointMapWithConfidenceAndMaskAdaptor, + PointMapWithMaskAdaptor, + RayDirectionsPlusDepthAdaptor, + RayDirectionsPlusDepthWithConfidenceAdaptor, + RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor, + RayDirectionsPlusDepthWithMaskAdaptor, + RayMapPlusDepthAdaptor, + RayMapPlusDepthWithConfidenceAdaptor, + RayMapPlusDepthWithConfidenceAndMaskAdaptor, + RayMapPlusDepthWithMaskAdaptor, + ScaleAdaptor, +) +from uniception.models.prediction_heads.base import ( + AdaptorInput, + PredictionHeadInput, + PredictionHeadLayeredInput, + PredictionHeadTokenInput, +) +from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor +from uniception.models.prediction_heads.linear import LinearFeature +from uniception.models.prediction_heads.mlp_head import MLPHead +from uniception.models.prediction_heads.pose_head import PoseHead + +# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12) +if hasattr(torch.backends.cuda, "matmul") and hasattr( + torch.backends.cuda.matmul, "allow_tf32" +): + torch.backends.cuda.matmul.allow_tf32 = True + + +class MapAnything(nn.Module, PyTorchModelHubMixin): + "Modular MapAnything model class that supports input of images & optional geometric modalities (multiple reconstruction tasks)." 
+ + def __init__( + self, + name: str, + encoder_config: Dict, + info_sharing_config: Dict, + pred_head_config: Dict, + geometric_input_config: Dict, + fusion_norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial( + nn.LayerNorm, eps=1e-6 + ), + pretrained_checkpoint_path: str = None, + load_specific_pretrained_submodules: bool = False, + specific_pretrained_submodules: list = None, + torch_hub_force_reload: bool = False, + ): + """ + Multi-view model containing an image encoder fused with optional geometric modalities followed by a multi-view attention transformer and respective downstream heads. + The goal is to output scene representation. + The multi-view attention transformer also takes as input a scale token to predict the metric scaling factor for the predicted scene representation. + + Args: + name (str): Name of the model. + encoder_config (Dict): Configuration for the encoder. + info_sharing_config (Dict): Configuration for the multi-view attention transformer. + pred_head_config (Dict): Configuration for the prediction heads. + geometric_input_config (Dict): Configuration for the input of optional geometric modalities. + fusion_norm_layer (Union[Type[nn.Module], Callable[..., nn.Module]]): Normalization layer to use after fusion (addition) of encoder and geometric modalities. (default: partial(nn.LayerNorm, eps=1e-6)) + pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None) + load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False) + specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: None) + torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. 
(default: False) + """ + super().__init__() + + # Initialize the attributes + self.name = name + self.encoder_config = encoder_config + self.info_sharing_config = info_sharing_config + self.pred_head_config = pred_head_config + self.geometric_input_config = geometric_input_config + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.load_specific_pretrained_submodules = load_specific_pretrained_submodules + self.specific_pretrained_submodules = specific_pretrained_submodules + self.torch_hub_force_reload = torch_hub_force_reload + self.class_init_args = { + "name": self.name, + "encoder_config": self.encoder_config, + "info_sharing_config": self.info_sharing_config, + "pred_head_config": self.pred_head_config, + "geometric_input_config": self.geometric_input_config, + "pretrained_checkpoint_path": self.pretrained_checkpoint_path, + "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules, + "specific_pretrained_submodules": self.specific_pretrained_submodules, + "torch_hub_force_reload": self.torch_hub_force_reload, + } + + # Get relevant parameters from the configs + self.info_sharing_type = info_sharing_config["model_type"] + self.info_sharing_return_type = info_sharing_config["model_return_type"] + self.pred_head_type = pred_head_config["type"] + + # Initialize image encoder + if self.encoder_config["uses_torch_hub"]: + self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload + # Create a copy of the config before deleting the key to preserve it for serialization + encoder_config_copy = self.encoder_config.copy() + del encoder_config_copy["uses_torch_hub"] + self.encoder = encoder_factory(**encoder_config_copy) + + # Initialize the encoder for ray directions + ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"] + ray_dirs_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + ray_dirs_encoder_config["patch_size"] = self.encoder.patch_size + self.ray_dirs_encoder = encoder_factory(**ray_dirs_encoder_config) + + # Initialize the encoder for depth (normalized per view and values after normalization are scaled logarithmically) + depth_encoder_config = self.geometric_input_config["depth_encoder_config"] + depth_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + depth_encoder_config["patch_size"] = self.encoder.patch_size + self.depth_encoder = encoder_factory(**depth_encoder_config) + + # Initialize the encoder for log scale factor of depth + depth_scale_encoder_config = self.geometric_input_config["scale_encoder_config"] + depth_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.depth_scale_encoder = encoder_factory(**depth_scale_encoder_config) + + # Initialize the encoder for camera rotation + cam_rot_encoder_config = self.geometric_input_config["cam_rot_encoder_config"] + cam_rot_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.cam_rot_encoder = encoder_factory(**cam_rot_encoder_config) + + # Initialize the encoder for camera translation (normalized across all provided camera translations) + cam_trans_encoder_config = self.geometric_input_config[ + "cam_trans_encoder_config" + ] + cam_trans_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim + self.cam_trans_encoder = encoder_factory(**cam_trans_encoder_config) + + # Initialize the encoder for log scale factor of camera translation + cam_trans_scale_encoder_config = self.geometric_input_config[ + "scale_encoder_config" + ] + cam_trans_scale_encoder_config["enc_embed_dim"] = 
self.encoder.enc_embed_dim + self.cam_trans_scale_encoder = encoder_factory(**cam_trans_scale_encoder_config) + + # Initialize the fusion norm layer + self.fusion_norm_layer = fusion_norm_layer(self.encoder.enc_embed_dim) + + # Initialize the Scale Token + # Used to scale the final scene predictions to metric scale + # During inference extended to (B, C, T), where T is the number of tokens (i.e., 1) + self.scale_token = nn.Parameter(torch.zeros(self.encoder.enc_embed_dim)) + torch.nn.init.trunc_normal_(self.scale_token, std=0.02) + + # Initialize the info sharing module (multi-view transformer) + self._initialize_info_sharing(info_sharing_config) + + # Initialize the prediction heads + self._initialize_prediction_heads(pred_head_config) + + # Initialize the final adaptors + self._initialize_adaptors(pred_head_config) + + # Load pretrained weights + self._load_pretrained_weights() + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def _initialize_info_sharing(self, info_sharing_config): + """ + Initialize the information sharing module based on the configuration. + + This method sets up the custom positional encoding if specified and initializes + the appropriate multi-view transformer based on the configuration type. + + Args: + info_sharing_config (Dict): Configuration for the multi-view attention transformer. + Should contain 'custom_positional_encoding', 'model_type', and 'model_return_type'. + + Returns: + None + + Raises: + ValueError: If invalid configuration options are provided. + """ + # Initialize Custom Positional Encoding if required + custom_positional_encoding = info_sharing_config["custom_positional_encoding"] + if custom_positional_encoding is not None: + if isinstance(custom_positional_encoding, str): + print( + f"Using custom positional encoding for multi-view attention transformer: {custom_positional_encoding}" + ) + raise ValueError( + f"Invalid custom_positional_encoding: {custom_positional_encoding}. None implemented." + ) + elif isinstance(custom_positional_encoding, Callable): + print( + "Using callable function as custom positional encoding for multi-view attention transformer." + ) + self.custom_positional_encoding = custom_positional_encoding + else: + self.custom_positional_encoding = None + + # Add dependencies to info_sharing_config + info_sharing_config["module_args"]["input_embed_dim"] = ( + self.encoder.enc_embed_dim + ) + info_sharing_config["module_args"]["custom_positional_encoding"] = ( + self.custom_positional_encoding + ) + + # Initialize Multi-View Transformer + if self.info_sharing_return_type == "no_intermediate_features": + # Returns only normalized last layer features + # Initialize multi-view transformer based on type + if self.info_sharing_type == "cross_attention": + self.info_sharing = MultiViewCrossAttentionTransformer( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "global_attention": + self.info_sharing = MultiViewGlobalAttentionTransformer( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "alternating_attention": + self.info_sharing = MultiViewAlternatingAttentionTransformer( + **info_sharing_config["module_args"] + ) + else: + raise ValueError( + f"Invalid info_sharing_type: {self.info_sharing_type}. 
Valid options: ['cross_attention', 'global_attention', 'alternating_attention']" + ) + elif self.info_sharing_return_type == "intermediate_features": + # Returns intermediate features and normalized last layer features + # Initialize mulit-view transformer based on type + if self.info_sharing_type == "cross_attention": + self.info_sharing = MultiViewCrossAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "global_attention": + self.info_sharing = MultiViewGlobalAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + elif self.info_sharing_type == "alternating_attention": + self.info_sharing = MultiViewAlternatingAttentionTransformerIFR( + **info_sharing_config["module_args"] + ) + else: + raise ValueError( + f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']" + ) + # Assess if the DPT needs to use encoder features + if len(self.info_sharing.indices) == 2: + self.use_encoder_features_for_dpt = True + elif len(self.info_sharing.indices) == 3: + self.use_encoder_features_for_dpt = False + else: + raise ValueError( + "Invalid number of indices provided for info sharing feature returner. Please provide 2 or 3 indices." + ) + else: + raise ValueError( + f"Invalid info_sharing_return_type: {self.info_sharing_return_type}. Valid options: ['no_intermediate_features', 'intermediate_features']" + ) + + def _initialize_prediction_heads(self, pred_head_config): + """ + Initialize the prediction heads based on the prediction head configuration. + + This method configures and initializes the appropriate prediction heads based on the + specified prediction head type (linear, DPT, or DPT+pose). It sets up the necessary + dependencies and creates the required model components. + + Args: + pred_head_config (Dict): Configuration for the prediction heads. + + Returns: + None + + Raises: + ValueError: If an invalid pred_head_type is provided. + """ + # Add dependencies to prediction head config + pred_head_config["feature_head"]["patch_size"] = self.encoder.patch_size + if self.pred_head_type == "linear": + pred_head_config["feature_head"]["input_feature_dim"] = ( + self.info_sharing.dim + ) + elif "dpt" in self.pred_head_type: + # Add dependencies for DPT & Regressor head + if self.use_encoder_features_for_dpt: + pred_head_config["feature_head"]["input_feature_dims"] = [ + self.encoder.enc_embed_dim + ] + [self.info_sharing.dim] * 3 + else: + pred_head_config["feature_head"]["input_feature_dims"] = [ + self.info_sharing.dim + ] * 4 + pred_head_config["regressor_head"]["input_feature_dim"] = pred_head_config[ + "feature_head" + ]["feature_dim"] + # Add dependencies for Pose head if required + if "pose" in self.pred_head_type: + pred_head_config["pose_head"]["patch_size"] = self.encoder.patch_size + pred_head_config["pose_head"]["input_feature_dim"] = ( + self.info_sharing.dim + ) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. 
Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + pred_head_config["scale_head"]["input_feature_dim"] = self.info_sharing.dim + + # Initialize Prediction Heads + if self.pred_head_type == "linear": + # Initialize Dense Prediction Head for all views + self.dense_head = LinearFeature(**pred_head_config["feature_head"]) + elif "dpt" in self.pred_head_type: + # Initialize Dense Prediction Head for all views + self.dpt_feature_head = DPTFeature(**pred_head_config["feature_head"]) + self.dpt_regressor_head = DPTRegressionProcessor( + **pred_head_config["regressor_head"] + ) + self.dense_head = nn.Sequential( + self.dpt_feature_head, self.dpt_regressor_head + ) + # Initialize Pose Head for all views if required + if "pose" in self.pred_head_type: + self.pose_head = PoseHead(**pred_head_config["pose_head"]) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + self.scale_head = MLPHead(**pred_head_config["scale_head"]) + + def _initialize_adaptors(self, pred_head_config): + """ + Initialize the adaptors based on the prediction head configuration. + + This method sets up the appropriate adaptors for different scene representation types, + such as pointmaps, ray maps with depth, or ray directions with depth and pose. + + Args: + pred_head_config (Dict): Configuration for the prediction heads including adaptor type. + + Returns: + None + + Raises: + ValueError: If an invalid adaptor_type is provided. + AssertionError: If ray directions + depth + pose is used with an incompatible head type. + """ + if pred_head_config["adaptor_type"] == "pointmap": + self.dense_adaptor = PointMapAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "pointmap" + elif pred_head_config["adaptor_type"] == "pointmap+confidence": + self.dense_adaptor = PointMapWithConfidenceAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "pointmap+confidence" + elif pred_head_config["adaptor_type"] == "pointmap+mask": + self.dense_adaptor = PointMapWithMaskAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "pointmap+mask" + elif pred_head_config["adaptor_type"] == "pointmap+confidence+mask": + self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "pointmap+confidence+mask" + elif pred_head_config["adaptor_type"] == "raymap+depth": + self.dense_adaptor = RayMapPlusDepthAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "raymap+depth" + elif pred_head_config["adaptor_type"] == "raymap+depth+confidence": + self.dense_adaptor = RayMapPlusDepthWithConfidenceAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "raymap+depth+confidence" + elif pred_head_config["adaptor_type"] == "raymap+depth+mask": + self.dense_adaptor = RayMapPlusDepthWithMaskAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "raymap+depth+mask" + elif pred_head_config["adaptor_type"] == "raymap+depth+confidence+mask": + self.dense_adaptor = RayMapPlusDepthWithConfidenceAndMaskAdaptor( + **pred_head_config["adaptor"] + ) + self.scene_rep_type = "raymap+depth+confidence+mask" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." 
+ ) + self.dense_adaptor = RayDirectionsPlusDepthAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose+confidence" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+mask": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = RayDirectionsPlusDepthWithMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose+mask" + elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence+mask": + assert self.pred_head_type == "dpt+pose", ( + "Ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "raydirs+depth+pose+confidence+mask" + elif pred_head_config["adaptor_type"] == "campointmap+pose": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapAdaptor(**pred_head_config["dpt_adaptor"]) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose" + elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapWithConfidenceAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose+confidence" + elif pred_head_config["adaptor_type"] == "campointmap+pose+mask": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapWithMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose+mask" + elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence+mask": + assert self.pred_head_type == "dpt+pose", ( + "Camera pointmap + pose can only be used as scene representation with dpt + pose head." 
+ ) + self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "campointmap+pose+confidence+mask" + elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose": + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose" + elif ( + pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+confidence" + ): + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = ( + PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor( + **pred_head_config["dpt_adaptor"] + ) + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence" + elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+mask": + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose+mask" + elif ( + pred_head_config["adaptor_type"] + == "pointmap+raydirs+depth+pose+confidence+mask" + ): + assert self.pred_head_type == "dpt+pose", ( + "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head." + ) + self.dense_adaptor = ( + PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor( + **pred_head_config["dpt_adaptor"] + ) + ) + self.pose_adaptor = CamTranslationPlusQuatsAdaptor( + **pred_head_config["pose_adaptor"] + ) + self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence+mask" + else: + raise ValueError( + f"Invalid adaptor_type: {pred_head_config['adaptor_type']}. \ + Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \ + 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \ + 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \ + 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']" + ) + self.scale_adaptor = ScaleAdaptor(**pred_head_config["scale_adaptor"]) + + def _load_pretrained_weights(self): + """ + Load pretrained weights from a checkpoint file. + + If load_specific_pretrained_submodules is True, only loads weights for the specified submodules. + Otherwise, loads all weights from the checkpoint. 
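+
+ Illustrative example (the submodule names are only an assumption): with
+ specific_pretrained_submodules=["encoder", "info_sharing"], only checkpoint
+ keys starting with one of those prefixes are kept and then loaded with
+ strict=False.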
+ + Returns: + None + """ + if self.pretrained_checkpoint_path is not None: + if not self.load_specific_pretrained_submodules: + print( + f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + else: + print( + f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} for specific submodules: {self.specific_pretrained_submodules} ..." + ) + assert self.pred_head_type is not None, ( + "Specific submodules to load cannot be None." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + filtered_ckpt = {} + for ckpt_key, ckpt_value in ckpt["model"].items(): + for submodule in self.specific_pretrained_submodules: + if ckpt_key.startswith(submodule): + filtered_ckpt[ckpt_key] = ckpt_value + print(self.load_state_dict(filtered_ckpt, strict=False)) + + def _encode_n_views(self, views): + """ + Encode all the input views (batch of images) in a single forward pass. + Assumes all the input views have the same image shape, batch size, and data normalization type. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + + Returns: + List[torch.Tensor]: A list containing the encoded features for all N views. + """ + num_views = len(views) + data_norm_type = views[0]["data_norm_type"][0] + imgs_list = [view["img"] for view in views] + all_imgs_across_views = torch.cat(imgs_list, dim=0) + encoder_input = ViTEncoderInput( + image=all_imgs_across_views, data_norm_type=data_norm_type + ) + encoder_output = self.encoder(encoder_input) + all_encoder_features_across_views = encoder_output.features.chunk( + num_views, dim=0 + ) + + return all_encoder_features_across_views + + def _compute_pose_quats_and_trans_for_across_views_in_ref_view( + self, + views, + num_views, + device, + dtype, + batch_size_per_view, + per_sample_cam_input_mask, + ): + """ + Compute the pose quats and trans for all the views in the frame of the reference view 0. + Returns identity pose for views where the camera input mask is False or the pose is not provided. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + device (torch.device): Device to use for the computation. + dtype (torch.dtype): Data type to use for the computation. + per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask. + + Returns: + torch.Tensor: A tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4) + torch.Tensor: A tensor containing the pose trans for all the views in the frame of the reference view 0. (batch_size_per_view * view, 3) + torch.Tensor: A tensor containing the per sample camera input mask. 
+ """ + # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0 + pose_quats_non_ref_views = [] + pose_trans_non_ref_views = [] + pose_quats_ref_view_0 = [] + pose_trans_ref_view_0 = [] + for view_idx in range(num_views): + per_sample_cam_input_mask_for_curr_view = per_sample_cam_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view + ] + if ( + "camera_pose_quats" in views[view_idx] + and "camera_pose_trans" in views[view_idx] + and per_sample_cam_input_mask_for_curr_view.any() + ): + # Get the camera pose quats and trans for the current view + cam_pose_quats = views[view_idx]["camera_pose_quats"][ + per_sample_cam_input_mask_for_curr_view + ] + cam_pose_trans = views[view_idx]["camera_pose_trans"][ + per_sample_cam_input_mask_for_curr_view + ] + # Append to the list + pose_quats_non_ref_views.append(cam_pose_quats) + pose_trans_non_ref_views.append(cam_pose_trans) + # Get the camera pose quats and trans for the reference view 0 + cam_pose_quats = views[0]["camera_pose_quats"][ + per_sample_cam_input_mask_for_curr_view + ] + cam_pose_trans = views[0]["camera_pose_trans"][ + per_sample_cam_input_mask_for_curr_view + ] + # Append to the list + pose_quats_ref_view_0.append(cam_pose_quats) + pose_trans_ref_view_0.append(cam_pose_trans) + else: + per_sample_cam_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] = False + + # Initialize the pose quats and trans for all views as identity + pose_quats_across_views = torch.tensor( + [0.0, 0.0, 0.0, 1.0], dtype=dtype, device=device + ).repeat(batch_size_per_view * num_views, 1) # (q_x, q_y, q_z, q_w) + pose_trans_across_views = torch.zeros( + (batch_size_per_view * num_views, 3), dtype=dtype, device=device + ) + + # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0 + if len(pose_quats_non_ref_views) > 0: + # Stack the pose quats and trans for all the non-reference views and reference view 0 + pose_quats_non_ref_views = torch.cat(pose_quats_non_ref_views, dim=0) + pose_trans_non_ref_views = torch.cat(pose_trans_non_ref_views, dim=0) + pose_quats_ref_view_0 = torch.cat(pose_quats_ref_view_0, dim=0) + pose_trans_ref_view_0 = torch.cat(pose_trans_ref_view_0, dim=0) + + # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0 + ( + pose_quats_non_ref_views_in_ref_view_0, + pose_trans_non_ref_views_in_ref_view_0, + ) = transform_pose_using_quats_and_trans_2_to_1( + pose_quats_ref_view_0, + pose_trans_ref_view_0, + pose_quats_non_ref_views, + pose_trans_non_ref_views, + ) + + # Update the pose quats and trans for all the non-reference views + pose_quats_across_views[per_sample_cam_input_mask] = ( + pose_quats_non_ref_views_in_ref_view_0.to(dtype=dtype) + ) + pose_trans_across_views[per_sample_cam_input_mask] = ( + pose_trans_non_ref_views_in_ref_view_0.to(dtype=dtype) + ) + + return ( + pose_quats_across_views, + pose_trans_across_views, + per_sample_cam_input_mask, + ) + + def _encode_and_fuse_ray_dirs( + self, + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_ray_dirs_input_mask, + ): + """ + Encode the ray directions for all the views and fuse it with the other encoder features in a single forward pass. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. 
+ batch_size_per_view (int): Batch size per view. + all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views. + per_sample_ray_dirs_input_mask (torch.Tensor): Tensor containing the per sample ray direction input mask. + + Returns: + torch.Tensor: A tensor containing the encoded features for all the views. + """ + # Get the height and width of the images + _, _, height, width = views[0]["img"].shape + + # Get the ray directions for all the views where info is provided and the ray direction input mask is True + ray_dirs_list = [] + for view_idx in range(num_views): + per_sample_ray_dirs_input_mask_for_curr_view = ( + per_sample_ray_dirs_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] + ) + ray_dirs_for_curr_view = torch.zeros( + (batch_size_per_view, height, width, 3), + dtype=all_encoder_features_across_views.dtype, + device=all_encoder_features_across_views.device, + ) + if ( + "ray_directions_cam" in views[view_idx] + and per_sample_ray_dirs_input_mask_for_curr_view.any() + ): + ray_dirs_for_curr_view[per_sample_ray_dirs_input_mask_for_curr_view] = ( + views[view_idx]["ray_directions_cam"][ + per_sample_ray_dirs_input_mask_for_curr_view + ] + ) + else: + per_sample_ray_dirs_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] = False + ray_dirs_list.append(ray_dirs_for_curr_view) + + # Stack the ray directions for all the views and permute to (B * V, C, H, W) + ray_dirs = torch.cat(ray_dirs_list, dim=0) # (B * V, H, W, 3) + ray_dirs = ray_dirs.permute(0, 3, 1, 2).contiguous() # (B * V, 3, H, W) + + # Encode the ray directions + ray_dirs_features_across_views = self.ray_dirs_encoder( + ViTEncoderNonImageInput(data=ray_dirs) + ).features + + # Fuse the ray direction features with the other encoder features (zero out the features where the ray direction input mask is False) + ray_dirs_features_across_views = ( + ray_dirs_features_across_views + * per_sample_ray_dirs_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + ) + all_encoder_features_across_views = ( + all_encoder_features_across_views + ray_dirs_features_across_views + ) + + return all_encoder_features_across_views + + def _encode_and_fuse_depths( + self, + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_depth_input_mask, + ): + """ + Encode the z depths for all the views and fuse it with the other encoder features in a single forward pass. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + batch_size_per_view (int): Batch size per view. + all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views. + per_sample_depth_input_mask (torch.Tensor): Tensor containing the per sample depth input mask. + + Returns: + torch.Tensor: A tensor containing the encoded features for all the views. 
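+
+ Note (sketch of the encoding below, mirroring the code): per view, the input
+ depth_along_ray is optionally sparsified, normalized over its non-zero pixels
+ (yielding a per-view norm factor s), log-scaled via apply_log_to_norm and
+ patch-encoded; s itself is encoded as log(s + 1e-8) by the depth scale
+ encoder, and that scale encoding is kept only for metric-scale samples.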
+ """ + # Get the device and height and width of the images + device = all_encoder_features_across_views.device + _, _, height, width = views[0]["img"].shape + + # Decide to use randomly sampled sparse depth or dense depth + if torch.rand(1) < self.geometric_input_config["sparse_depth_prob"]: + use_sparse_depth = True + else: + use_sparse_depth = False + + # Get the depths for all the views + depth_list = [] + depth_norm_factors_list = [] + metric_scale_depth_mask_list = [] + for view_idx in range(num_views): + # Get the input mask for current view + per_sample_depth_input_mask_for_curr_view = per_sample_depth_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view + ] + depth_for_curr_view = torch.zeros( + (batch_size_per_view, height, width, 1), + dtype=all_encoder_features_across_views.dtype, + device=device, + ) + depth_norm_factor_for_curr_view = torch.zeros( + (batch_size_per_view), + dtype=all_encoder_features_across_views.dtype, + device=device, + ) + metric_scale_mask_for_curr_view = torch.zeros( + (batch_size_per_view), + dtype=torch.bool, + device=device, + ) + if ( + "depth_along_ray" in views[view_idx] + ) and per_sample_depth_input_mask_for_curr_view.any(): + # Get depth for current view + depth_for_curr_view_input = views[view_idx]["depth_along_ray"][ + per_sample_depth_input_mask_for_curr_view + ] + # Get the metric scale mask + if "is_metric_scale" in views[view_idx]: + metric_scale_mask = views[view_idx]["is_metric_scale"][ + per_sample_depth_input_mask_for_curr_view + ] + else: + metric_scale_mask = torch.zeros( + depth_for_curr_view_input.shape[0], + dtype=torch.bool, + device=device, + ) + # Turn off indication of metric scale samples based on the depth_scale_norm_all_prob + depth_scale_norm_all_mask = ( + torch.rand(metric_scale_mask.shape[0]) + < self.geometric_input_config["depth_scale_norm_all_prob"] + ) + if depth_scale_norm_all_mask.any(): + metric_scale_mask[depth_scale_norm_all_mask] = False + # Assign the metric scale mask to the respective indices + metric_scale_mask_for_curr_view[ + per_sample_depth_input_mask_for_curr_view + ] = metric_scale_mask + # Sparsely sample the depth if required + if use_sparse_depth: + # Create a mask of ones + sparsification_mask = torch.ones_like( + depth_for_curr_view_input, device=device + ) + # Create a mask for valid pixels (depth > 0) + valid_pixel_mask = depth_for_curr_view_input > 0 + # Calculate the number of valid pixels + num_valid_pixels = valid_pixel_mask.sum().item() + # Calculate the number of valid pixels to set to zero + num_to_zero = int( + num_valid_pixels + * self.geometric_input_config["sparsification_removal_percent"] + ) + if num_to_zero > 0: + # Get the indices of valid pixels + valid_indices = valid_pixel_mask.nonzero(as_tuple=True) + # Randomly select indices to zero out + indices_to_zero = torch.randperm(num_valid_pixels)[:num_to_zero] + # Set selected valid indices to zero in the mask + sparsification_mask[ + valid_indices[0][indices_to_zero], + valid_indices[1][indices_to_zero], + valid_indices[2][indices_to_zero], + valid_indices[3][indices_to_zero], + ] = 0 + # Apply the mask on the depth + depth_for_curr_view_input = ( + depth_for_curr_view_input * sparsification_mask + ) + # Normalize the depth + scaled_depth_for_curr_view_input, depth_norm_factor = ( + normalize_depth_using_non_zero_pixels( + depth_for_curr_view_input, return_norm_factor=True + ) + ) + # Assign the depth and depth norm factor to the respective indices + 
depth_for_curr_view[per_sample_depth_input_mask_for_curr_view] = ( + scaled_depth_for_curr_view_input + ) + depth_norm_factor_for_curr_view[ + per_sample_depth_input_mask_for_curr_view + ] = depth_norm_factor + else: + per_sample_depth_input_mask[ + view_idx * batch_size_per_view : (view_idx + 1) + * batch_size_per_view + ] = False + # Append the depths, depth norm factor and metric scale mask for the current view + depth_list.append(depth_for_curr_view) + depth_norm_factors_list.append(depth_norm_factor_for_curr_view) + metric_scale_depth_mask_list.append(metric_scale_mask_for_curr_view) + + # Stack the depths for all the views and permute to (B * V, C, H, W) + depths = torch.cat(depth_list, dim=0) # (B * V, H, W, 1) + depths = apply_log_to_norm( + depths + ) # Scale logarithimically (norm is computed along last dim) + depths = depths.permute(0, 3, 1, 2).contiguous() # (B * V, 1, H, W) + # Encode the depths using the depth encoder + depth_features_across_views = self.depth_encoder( + ViTEncoderNonImageInput(data=depths) + ).features + # Zero out the depth features where the depth input mask is False + depth_features_across_views = ( + depth_features_across_views + * per_sample_depth_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + ) + + # Stack the depth norm factors for all the views + depth_norm_factors = torch.cat(depth_norm_factors_list, dim=0) # (B * V, ) + # Encode the depth norm factors using the log scale encoder for depth + log_depth_norm_factors = torch.log(depth_norm_factors + 1e-8) # (B * V, ) + depth_scale_features_across_views = self.depth_scale_encoder( + EncoderGlobalRepInput(data=log_depth_norm_factors.unsqueeze(-1)) + ).features + # Zero out the depth scale features where the depth input mask is False + depth_scale_features_across_views = ( + depth_scale_features_across_views + * per_sample_depth_input_mask.unsqueeze(-1) + ) + # Stack the metric scale mask for all the views + metric_scale_depth_mask = torch.cat( + metric_scale_depth_mask_list, dim=0 + ) # (B * V, ) + # Zero out the depth scale features where the metric scale mask is False + # Scale encoding is only provided for metric scale samples + depth_scale_features_across_views = ( + depth_scale_features_across_views * metric_scale_depth_mask.unsqueeze(-1) + ) + + # Fuse the depth features & depth scale features with the other encoder features + all_encoder_features_across_views = ( + all_encoder_features_across_views + + depth_features_across_views + + depth_scale_features_across_views.unsqueeze(-1).unsqueeze(-1) + ) + + return all_encoder_features_across_views + + def _encode_and_fuse_cam_quats_and_trans( + self, + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + pose_quats_across_views, + pose_trans_across_views, + per_sample_cam_input_mask, + ): + """ + Encode the camera quats and trans for all the views and fuse it with the other encoder features in a single forward pass. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + num_views (int): Number of views. + batch_size_per_view (int): Batch size per view. + all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views. + pose_quats_across_views (torch.Tensor): Tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4) + pose_trans_across_views (torch.Tensor): Tensor containing the pose trans for all the views in the frame of the reference view 0. 
(batch_size_per_view * view, 3) + per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask. + + Returns: + torch.Tensor: A tensor containing the encoded features for all the views. + """ + # Encode the pose quats + pose_quats_features_across_views = self.cam_rot_encoder( + EncoderGlobalRepInput(data=pose_quats_across_views) + ).features + # Zero out the pose quat features where the camera input mask is False + pose_quats_features_across_views = ( + pose_quats_features_across_views * per_sample_cam_input_mask.unsqueeze(-1) + ) + + # Get the metric scale mask for all samples + device = all_encoder_features_across_views.device + metric_scale_pose_trans_mask = torch.zeros( + (batch_size_per_view * num_views), dtype=torch.bool, device=device + ) + for view_idx in range(num_views): + if "is_metric_scale" in views[view_idx]: + # Get the metric scale mask for the input pose priors + metric_scale_mask = views[view_idx]["is_metric_scale"] + else: + metric_scale_mask = torch.zeros( + batch_size_per_view, dtype=torch.bool, device=device + ) + metric_scale_pose_trans_mask[ + view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view + ] = metric_scale_mask + + # Turn off indication of metric scale samples based on the pose_scale_norm_all_prob + pose_norm_all_mask = ( + torch.rand(batch_size_per_view * num_views) + < self.geometric_input_config["pose_scale_norm_all_prob"] + ) + if pose_norm_all_mask.any(): + metric_scale_pose_trans_mask[pose_norm_all_mask] = False + + # Get the scale norm factor for all the samples and scale the pose translations + pose_trans_across_views = torch.split( + pose_trans_across_views, batch_size_per_view, dim=0 + ) # Split into num_views chunks + pose_trans_across_views = torch.stack( + pose_trans_across_views, dim=1 + ) # Stack the views along a new dimension (batch_size_per_view, num_views, 3) + scaled_pose_trans_across_views, pose_trans_norm_factors = ( + normalize_pose_translations( + pose_trans_across_views, return_norm_factor=True + ) + ) + + # Resize the pose translation back to (batch_size_per_view * num_views, 3) and extend the norm factor to (batch_size_per_view * num_views, 1) + scaled_pose_trans_across_views = scaled_pose_trans_across_views.unbind( + dim=1 + ) # Convert back to list of views, where each view has batch_size_per_view tensor + scaled_pose_trans_across_views = torch.cat( + scaled_pose_trans_across_views, dim=0 + ) # Concatenate back to (batch_size_per_view * num_views, 3) + pose_trans_norm_factors_across_views = pose_trans_norm_factors.unsqueeze( + -1 + ).repeat(num_views, 1) # (B, ) -> (B * V, 1) + + # Encode the pose trans + pose_trans_features_across_views = self.cam_trans_encoder( + EncoderGlobalRepInput(data=scaled_pose_trans_across_views) + ).features + # Zero out the pose trans features where the camera input mask is False + pose_trans_features_across_views = ( + pose_trans_features_across_views * per_sample_cam_input_mask.unsqueeze(-1) + ) + + # Encode the pose translation norm factors using the log scale encoder for pose trans + log_pose_trans_norm_factors_across_views = torch.log( + pose_trans_norm_factors_across_views + 1e-8 + ) + pose_trans_scale_features_across_views = self.cam_trans_scale_encoder( + EncoderGlobalRepInput(data=log_pose_trans_norm_factors_across_views) + ).features + # Zero out the pose trans scale features where the camera input mask is False + pose_trans_scale_features_across_views = ( + pose_trans_scale_features_across_views + * per_sample_cam_input_mask.unsqueeze(-1) + ) 
+ # Zero out the pose trans scale features where the metric scale mask is False + # Scale encoding is only provided for metric scale samples + pose_trans_scale_features_across_views = ( + pose_trans_scale_features_across_views + * metric_scale_pose_trans_mask.unsqueeze(-1) + ) + + # Fuse the pose quat features, pose trans features, pose trans scale features and pose trans type PE features with the other encoder features + all_encoder_features_across_views = ( + all_encoder_features_across_views + + pose_quats_features_across_views.unsqueeze(-1).unsqueeze(-1) + + pose_trans_features_across_views.unsqueeze(-1).unsqueeze(-1) + + pose_trans_scale_features_across_views.unsqueeze(-1).unsqueeze(-1) + ) + + return all_encoder_features_across_views + + def _encode_and_fuse_optional_geometric_inputs( + self, views, all_encoder_features_across_views_list + ): + """ + Encode all the input optional geometric modalities and fuses it with the image encoder features in a single forward pass. + Assumes all the input views have the same shape and batch size. + + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + all_encoder_features_across_views (List[torch.Tensor]): List of tensors containing the encoded image features for all N views. + + Returns: + List[torch.Tensor]: A list containing the encoded features for all N views. + """ + num_views = len(views) + batch_size_per_view, _, _, _ = views[0]["img"].shape + device = all_encoder_features_across_views_list[0].device + dtype = all_encoder_features_across_views_list[0].dtype + all_encoder_features_across_views = torch.cat( + all_encoder_features_across_views_list, dim=0 + ) + + # Get the overall input mask for all the views + overall_geometric_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["overall_prob"] + ) + overall_geometric_input_mask = overall_geometric_input_mask.repeat(num_views) + + # Get the per sample input mask after dropout + # Per sample input mask is in view-major order so that index v*B + b in each mask corresponds to sample b of view v: (B * V) + per_sample_geometric_input_mask = torch.rand( + batch_size_per_view * num_views, device=device + ) < (1 - self.geometric_input_config["dropout_prob"]) + per_sample_geometric_input_mask = ( + per_sample_geometric_input_mask & overall_geometric_input_mask + ) + + # Get the ray direction input mask + per_sample_ray_dirs_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["ray_dirs_prob"] + ) + per_sample_ray_dirs_input_mask = per_sample_ray_dirs_input_mask.repeat( + num_views + ) + per_sample_ray_dirs_input_mask = ( + per_sample_ray_dirs_input_mask & per_sample_geometric_input_mask + ) + + # Get the depth input mask + per_sample_depth_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["depth_prob"] + ) + per_sample_depth_input_mask = per_sample_depth_input_mask.repeat(num_views) + per_sample_depth_input_mask = ( + per_sample_depth_input_mask & per_sample_geometric_input_mask + ) + + # Get the camera input mask + per_sample_cam_input_mask = ( + torch.rand(batch_size_per_view, device=device) + < self.geometric_input_config["cam_prob"] + ) + per_sample_cam_input_mask = per_sample_cam_input_mask.repeat(num_views) + per_sample_cam_input_mask = ( + per_sample_cam_input_mask & per_sample_geometric_input_mask + ) + + # Compute the pose quats and trans for all the non-reference views in the frame of the reference 
view 0 + # Returned pose quats and trans represent identity pose for views/samples where the camera input mask is False + pose_quats_across_views, pose_trans_across_views, per_sample_cam_input_mask = ( + self._compute_pose_quats_and_trans_for_across_views_in_ref_view( + views, + num_views, + device, + dtype, + batch_size_per_view, + per_sample_cam_input_mask, + ) + ) + + # Encode the ray directions and fuse with the image encoder features + all_encoder_features_across_views = self._encode_and_fuse_ray_dirs( + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_ray_dirs_input_mask, + ) + + # Encode the depths and fuse with the image encoder features + all_encoder_features_across_views = self._encode_and_fuse_depths( + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + per_sample_depth_input_mask, + ) + + # Encode the cam quat and trans and fuse with the image encoder features + all_encoder_features_across_views = self._encode_and_fuse_cam_quats_and_trans( + views, + num_views, + batch_size_per_view, + all_encoder_features_across_views, + pose_quats_across_views, + pose_trans_across_views, + per_sample_cam_input_mask, + ) + + # Normalize the fused features (permute -> normalize -> permute) + all_encoder_features_across_views = all_encoder_features_across_views.permute( + 0, 2, 3, 1 + ).contiguous() + all_encoder_features_across_views = self.fusion_norm_layer( + all_encoder_features_across_views + ) + all_encoder_features_across_views = all_encoder_features_across_views.permute( + 0, 3, 1, 2 + ).contiguous() + + # Split the batched views into individual views + fused_all_encoder_features_across_views = ( + all_encoder_features_across_views.chunk(num_views, dim=0) + ) + + return fused_all_encoder_features_across_views + + def _compute_adaptive_minibatch_size( + self, + memory_safety_factor: float = 0.95, + ) -> int: + """ + Compute adaptive minibatch size based on available PyTorch memory. + + Args: + memory_safety_factor: Safety factor to avoid OOM (0.95 = use 95% of available memory) + + Returns: + Computed minibatch size + """ + device = self.device + + if device.type == "cuda": + # Get available GPU memory + torch.cuda.empty_cache() + available_memory = torch.cuda.mem_get_info()[0] # Free memory in bytes + usable_memory = ( + available_memory * memory_safety_factor + ) # Use safety factor to avoid OOM + else: + # For non-CUDA devices, use conservative default + print( + "Non-CUDA device detected. Using conservative default minibatch size of 1 for memory efficient dense prediction head inference." 
+ ) + return 1 + + # Determine minibatch size based on available memory + max_estimated_memory_per_sample = ( + 680 * 1024 * 1024 + ) # 680 MB per sample (upper bound profiling using a 518 x 518 input) + computed_minibatch_size = int(usable_memory / max_estimated_memory_per_sample) + if computed_minibatch_size < 1: + computed_minibatch_size = 1 + + return computed_minibatch_size + + def downstream_dense_head( + self, + dense_head_inputs: Union[torch.Tensor, List[torch.Tensor]], + img_shape: Tuple[int, int], + ): + """ + Run the downstream dense prediction head + """ + if self.pred_head_type == "linear": + dense_head_outputs = self.dense_head( + PredictionHeadInput(last_feature=dense_head_inputs) + ) + dense_final_outputs = self.dense_adaptor( + AdaptorInput( + adaptor_feature=dense_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + elif self.pred_head_type in ["dpt", "dpt+pose"]: + dense_head_outputs = self.dense_head( + PredictionHeadLayeredInput( + list_features=dense_head_inputs, + target_output_shape=img_shape, + ) + ) + dense_final_outputs = self.dense_adaptor( + AdaptorInput( + adaptor_feature=dense_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + return dense_final_outputs + + def downstream_head( + self, + dense_head_inputs: Union[torch.Tensor, List[torch.Tensor]], + scale_head_inputs: torch.Tensor, + img_shape: Tuple[int, int], + memory_efficient_inference: bool = False, + ): + """ + Run Prediction Heads & Post-Process Outputs + """ + # Get device + device = self.device + + # Use mini-batch inference to run the dense prediction head (the memory bottleneck) + # This saves memory and is slower than running the dense prediction head in one go + if memory_efficient_inference: + # Obtain the batch size of the dense head inputs + if self.pred_head_type == "linear": + batch_size = dense_head_inputs.shape[0] + elif self.pred_head_type in ["dpt", "dpt+pose"]: + batch_size = dense_head_inputs[0].shape[0] + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + # Compute the mini batch size and number of mini batches adaptively based on available memory + minibatch = self._compute_adaptive_minibatch_size() + num_batches = (batch_size + minibatch - 1) // minibatch + + # Run prediction for each mini-batch + dense_final_outputs_list = [] + pose_final_outputs_list = [] if self.pred_head_type == "dpt+pose" else None + for batch_idx in range(num_batches): + start_idx = batch_idx * minibatch + end_idx = min((batch_idx + 1) * minibatch, batch_size) + + # Get the inputs for the current mini-batch + if self.pred_head_type == "linear": + dense_head_inputs_batch = dense_head_inputs[start_idx:end_idx] + elif self.pred_head_type in ["dpt", "dpt+pose"]: + dense_head_inputs_batch = [ + x[start_idx:end_idx] for x in dense_head_inputs + ] + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. 
Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + # Dense prediction (mini-batched) + dense_final_outputs_batch = self.downstream_dense_head( + dense_head_inputs_batch, img_shape + ) + dense_final_outputs_list.append(dense_final_outputs_batch) + + # Pose prediction (mini-batched) + if self.pred_head_type == "dpt+pose": + pose_head_inputs_batch = dense_head_inputs[-1][start_idx:end_idx] + pose_head_outputs_batch = self.pose_head( + PredictionHeadInput(last_feature=pose_head_inputs_batch) + ) + pose_final_outputs_batch = self.pose_adaptor( + AdaptorInput( + adaptor_feature=pose_head_outputs_batch.decoded_channels, + output_shape_hw=img_shape, + ) + ) + pose_final_outputs_list.append(pose_final_outputs_batch) + + # Concatenate the dense prediction head outputs from all mini-batches + available_keys = dense_final_outputs_batch.__dict__.keys() + dense_pred_data_dict = { + key: torch.cat( + [getattr(output, key) for output in dense_final_outputs_list], dim=0 + ) + for key in available_keys + } + dense_final_outputs = dense_final_outputs_batch.__class__( + **dense_pred_data_dict + ) + + # Concatenate the pose prediction head outputs from all mini-batches + pose_final_outputs = None + if self.pred_head_type == "dpt+pose": + available_keys = pose_final_outputs_batch.__dict__.keys() + pose_pred_data_dict = { + key: torch.cat( + [getattr(output, key) for output in pose_final_outputs_list], + dim=0, + ) + for key in available_keys + } + pose_final_outputs = pose_final_outputs_batch.__class__( + **pose_pred_data_dict + ) + + # Clear CUDA cache for better memory efficiency + if device.type == "cuda": + torch.cuda.empty_cache() + else: + # Run prediction for all (batch_size * num_views) in one go + # Dense prediction + dense_final_outputs = self.downstream_dense_head( + dense_head_inputs, img_shape + ) + + # Pose prediction + pose_final_outputs = None + if self.pred_head_type == "dpt+pose": + pose_head_outputs = self.pose_head( + PredictionHeadInput(last_feature=dense_head_inputs[-1]) + ) + pose_final_outputs = self.pose_adaptor( + AdaptorInput( + adaptor_feature=pose_head_outputs.decoded_channels, + output_shape_hw=img_shape, + ) + ) + + # Scale prediction is lightweight, so we can run it in one go + scale_head_output = self.scale_head( + PredictionHeadTokenInput(last_feature=scale_head_inputs) + ) + scale_final_output = self.scale_adaptor( + AdaptorInput( + adaptor_feature=scale_head_output.decoded_channels, + output_shape_hw=img_shape, + ) + ) + scale_final_output = scale_final_output.value.squeeze(-1) # (B, 1, 1) -> (B, 1) + + # Clear CUDA cache for better memory efficiency + if memory_efficient_inference and device.type == "cuda": + torch.cuda.empty_cache() + + return dense_final_outputs, pose_final_outputs, scale_final_output + + def forward(self, views, memory_efficient_inference=False): + """ + Forward pass performing the following operations: + 1. Encodes the N input views (images). + 2. Encodes the optional geometric inputs (ray directions, depths, camera rotations, camera translations). + 3. Fuses the encoded features from the N input views and the optional geometric inputs using addition and normalization. + 4. Information sharing across the encoded features and a scale token using a multi-view attention transformer. + 5. Passes the final features from transformer through the prediction heads. + 6. Returns the processed final outputs for N views. + + Assumption: + - All the input views and dense geometric inputs have the same image shape. 
+ + Args: + views (List[dict]): List of dictionaries containing the input views' images and instance information. + Each dictionary should contain the following keys: + "img" (tensor): Image tensor of shape (B, C, H, W). Input images must be normalized based on the data norm type of image encoder. + "data_norm_type" (list): [model.encoder.data_norm_type] + Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs: + "ray_directions_cam" (tensor): Ray directions in the local camera frame. Tensor of shape (B, H, W, 3). + "depth_along_ray" (tensor): Depth along the ray. Tensor of shape (B, H, W, 1). + "camera_pose_quats" (tensor): Camera pose quaternions. Tensor of shape (B, 4). Camera pose is opencv (RDF) cam2world transformation. + "camera_pose_trans" (tensor): Camera pose translations. Tensor of shape (B, 3). Camera pose is opencv (RDF) cam2world transformation. + "is_metric_scale" (tensor): Boolean tensor indicating whether the geometric inputs are in metric scale or not. Tensor of shape (B, 1). + memory_efficient_inference (bool): Whether to use memory efficient inference or not. This runs the dense prediction head (the memory bottleneck) in a memory efficient manner. Default is False. + + Returns: + List[dict]: A list containing the final outputs for all N views. + """ + # Get input shape of the images, number of views, and batch size per view + batch_size_per_view, _, height, width = views[0]["img"].shape + img_shape = (int(height), int(width)) + num_views = len(views) + + # Run the image encoder on all the input views + all_encoder_features_across_views = self._encode_n_views(views) + + # Encode the optional geometric inputs and fuse with the encoded features from the N input views + # Use high precision to prevent NaN values after layer norm in dense representation encoder (due to high variance in last dim of features) + with torch.autocast("cuda", enabled=False): + all_encoder_features_across_views = ( + self._encode_and_fuse_optional_geometric_inputs( + views, all_encoder_features_across_views + ) + ) + + # Expand the scale token to match the batch size + input_scale_token = ( + self.scale_token.unsqueeze(0) + .unsqueeze(-1) + .repeat(batch_size_per_view, 1, 1) + ) # (B, C, 1) + + # Combine all images into view-centric representation + # Output is a list containing the encoded features for all N views after information sharing. 
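+ # (Shape sketch, assuming a ViT patch size p: each per-view feature map is
+ # (B, C, H/p, W/p) and the scale token passed as additional_input_tokens is
+ # (B, C, 1); the transformer returns updated per-view features plus the
+ # updated scale token as additional_token_features.)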
+ info_sharing_input = MultiViewTransformerInput( + features=all_encoder_features_across_views, + additional_input_tokens=input_scale_token, + ) + if self.info_sharing_return_type == "no_intermediate_features": + final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input) + elif self.info_sharing_return_type == "intermediate_features": + ( + final_info_sharing_multi_view_feat, + intermediate_info_sharing_multi_view_feat, + ) = self.info_sharing(info_sharing_input) + + if self.pred_head_type == "linear": + # Stack the features for all views + dense_head_inputs = torch.cat( + final_info_sharing_multi_view_feat.features, dim=0 + ) + elif self.pred_head_type in ["dpt", "dpt+pose"]: + # Get the list of features for all views + dense_head_inputs_list = [] + if self.use_encoder_features_for_dpt: + # Stack all the image encoder features for all views + stacked_encoder_features = torch.cat( + all_encoder_features_across_views, dim=0 + ) + dense_head_inputs_list.append(stacked_encoder_features) + # Stack the first intermediate features for all views + stacked_intermediate_features_1 = torch.cat( + intermediate_info_sharing_multi_view_feat[0].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_1) + # Stack the second intermediate features for all views + stacked_intermediate_features_2 = torch.cat( + intermediate_info_sharing_multi_view_feat[1].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_2) + # Stack the last layer features for all views + stacked_final_features = torch.cat( + final_info_sharing_multi_view_feat.features, dim=0 + ) + dense_head_inputs_list.append(stacked_final_features) + else: + # Stack the first intermediate features for all views + stacked_intermediate_features_1 = torch.cat( + intermediate_info_sharing_multi_view_feat[0].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_1) + # Stack the second intermediate features for all views + stacked_intermediate_features_2 = torch.cat( + intermediate_info_sharing_multi_view_feat[1].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_2) + # Stack the third intermediate features for all views + stacked_intermediate_features_3 = torch.cat( + intermediate_info_sharing_multi_view_feat[2].features, dim=0 + ) + dense_head_inputs_list.append(stacked_intermediate_features_3) + # Stack the last layer + stacked_final_features = torch.cat( + final_info_sharing_multi_view_feat.features, dim=0 + ) + dense_head_inputs_list.append(stacked_final_features) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. 
Valid options: ['linear', 'dpt', 'dpt+pose']" + ) + + with torch.autocast("cuda", enabled=False): + # Prepare inputs for the downstream heads + if self.pred_head_type == "linear": + dense_head_inputs = dense_head_inputs + elif self.pred_head_type in ["dpt", "dpt+pose"]: + dense_head_inputs = dense_head_inputs_list + scale_head_inputs = ( + final_info_sharing_multi_view_feat.additional_token_features + ) + + # Run the downstream heads + dense_final_outputs, pose_final_outputs, scale_final_output = ( + self.downstream_head( + dense_head_inputs=dense_head_inputs, + scale_head_inputs=scale_head_inputs, + img_shape=img_shape, + memory_efficient_inference=memory_efficient_inference, + ) + ) + + # Prepare the final scene representation for all views + if self.scene_rep_type in [ + "pointmap", + "pointmap+confidence", + "pointmap+mask", + "pointmap+confidence+mask", + ]: + output_pts3d = dense_final_outputs.value + # Reshape final scene representation to (B * V, H, W, C) + output_pts3d = output_pts3d.permute(0, 2, 3, 1).contiguous() + # Split the predicted pointmaps back to their respective views + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "metric_scaling_factor": scale_final_output, + } + ) + elif self.scene_rep_type in [ + "raymap+depth", + "raymap+depth+confidence", + "raymap+depth+mask", + "raymap+depth+confidence+mask", + ]: + # Reshape final scene representation to (B * V, H, W, C) + output_scene_rep = dense_final_outputs.value.permute( + 0, 2, 3, 1 + ).contiguous() + # Get the predicted ray origins, directions, and depths along rays + output_ray_origins, output_ray_directions, output_depth_along_ray = ( + output_scene_rep.split([3, 3, 1], dim=-1) + ) + # Get the predicted pointmaps + output_pts3d = ( + output_ray_origins + output_ray_directions * output_depth_along_ray + ) + # Split the predicted quantities back to their respective views + output_ray_origins_per_view = output_ray_origins.chunk(num_views, dim=0) + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "ray_origins": output_ray_origins_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "metric_scaling_factor": scale_final_output, + } + ) + elif self.scene_rep_type in [ + "raydirs+depth+pose", + "raydirs+depth+pose+confidence", + "raydirs+depth+pose+mask", + "raydirs+depth+pose+confidence+mask", + ]: + # Reshape output dense rep to (B * V, H, W, C) + output_dense_rep = dense_final_outputs.value.permute( + 0, 2, 3, 1 + ).contiguous() + # Get the predicted ray directions and depths along rays + output_ray_directions, output_depth_along_ray = output_dense_rep.split( + [3, 1], dim=-1 + ) + # Get the predicted camera translations and quaternions + output_cam_translations, output_cam_quats = ( + pose_final_outputs.value.split([3, 4], dim=-1) + ) + # Get the predicted pointmaps in 
world frame and camera frame + output_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + output_ray_directions, + output_depth_along_ray, + output_cam_translations, + output_cam_quats, + ) + ) + output_pts3d_cam = output_ray_directions * output_depth_along_ray + # Split the predicted quantities back to their respective views + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_cam_translations_per_view = output_cam_translations.chunk( + num_views, dim=0 + ) + output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "pts3d_cam": output_pts3d_cam_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "cam_trans": output_cam_translations_per_view[i] + * scale_final_output, + "cam_quats": output_cam_quats_per_view[i], + "metric_scaling_factor": scale_final_output, + } + ) + elif self.scene_rep_type in [ + "campointmap+pose", + "campointmap+pose+confidence", + "campointmap+pose+mask", + "campointmap+pose+confidence+mask", + ]: + # Get the predicted camera frame pointmaps + output_pts3d_cam = dense_final_outputs.value + # Reshape final scene representation to (B * V, H, W, C) + output_pts3d_cam = output_pts3d_cam.permute(0, 2, 3, 1).contiguous() + # Get the predicted camera translations and quaternions + output_cam_translations, output_cam_quats = ( + pose_final_outputs.value.split([3, 4], dim=-1) + ) + # Get the ray directions and depths along rays + output_depth_along_ray = torch.norm( + output_pts3d_cam, dim=-1, keepdim=True + ) + output_ray_directions = output_pts3d_cam / output_depth_along_ray + # Get the predicted pointmaps in world frame + output_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + output_ray_directions, + output_depth_along_ray, + output_cam_translations, + output_cam_quats, + ) + ) + # Split the predicted quantities back to their respective views + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_cam_translations_per_view = output_cam_translations.chunk( + num_views, dim=0 + ) + output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "pts3d_cam": output_pts3d_cam_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "cam_trans": output_cam_translations_per_view[i] + * scale_final_output, + "cam_quats": output_cam_quats_per_view[i], + 
"metric_scaling_factor": scale_final_output, + } + ) + elif self.scene_rep_type in [ + "pointmap+raydirs+depth+pose", + "pointmap+raydirs+depth+pose+confidence", + "pointmap+raydirs+depth+pose+mask", + "pointmap+raydirs+depth+pose+confidence+mask", + ]: + # Reshape final scene representation to (B * V, H, W, C) + output_dense_rep = dense_final_outputs.value.permute( + 0, 2, 3, 1 + ).contiguous() + # Get the predicted pointmaps, ray directions and depths along rays + output_pts3d, output_ray_directions, output_depth_along_ray = ( + output_dense_rep.split([3, 3, 1], dim=-1) + ) + # Get the predicted camera translations and quaternions + output_cam_translations, output_cam_quats = ( + pose_final_outputs.value.split([3, 4], dim=-1) + ) + # Get the predicted pointmaps in camera frame + output_pts3d_cam = output_ray_directions * output_depth_along_ray + # Replace the predicted world-frame pointmaps if required + if self.pred_head_config["adaptor_config"][ + "use_factored_predictions_for_global_pointmaps" + ]: + output_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + output_ray_directions, + output_depth_along_ray, + output_cam_translations, + output_cam_quats, + ) + ) + # Split the predicted quantities back to their respective views + output_ray_directions_per_view = output_ray_directions.chunk( + num_views, dim=0 + ) + output_depth_along_ray_per_view = output_depth_along_ray.chunk( + num_views, dim=0 + ) + output_cam_translations_per_view = output_cam_translations.chunk( + num_views, dim=0 + ) + output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0) + output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0) + output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0) + # Pack the output as a list of dictionaries + res = [] + for i in range(num_views): + res.append( + { + "pts3d": output_pts3d_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "pts3d_cam": output_pts3d_cam_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "ray_directions": output_ray_directions_per_view[i], + "depth_along_ray": output_depth_along_ray_per_view[i] + * scale_final_output.unsqueeze(-1).unsqueeze(-1), + "cam_trans": output_cam_translations_per_view[i] + * scale_final_output, + "cam_quats": output_cam_quats_per_view[i], + "metric_scaling_factor": scale_final_output, + } + ) + else: + raise ValueError( + f"Invalid scene_rep_type: {self.scene_rep_type}. 
\
+                Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose', \
+                'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence', \
+                'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask', \
+                'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']"
+            )
+
+        # Get the output confidences for all views (if available) and add them to the result
+        if "confidence" in self.scene_rep_type:
+            output_confidences = dense_final_outputs.confidence
+            # Reshape confidences to (B * V, H, W)
+            output_confidences = (
+                output_confidences.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+            )
+            # Split the predicted confidences back to their respective views
+            output_confidences_per_view = output_confidences.chunk(num_views, dim=0)
+            # Add the confidences to the result
+            for i in range(num_views):
+                res[i]["conf"] = output_confidences_per_view[i]
+
+        # Get the output masks (and logits) for all views (if available) and add them to the result
+        if "mask" in self.scene_rep_type:
+            # Get the output masks
+            output_masks = dense_final_outputs.mask
+            # Reshape masks to (B * V, H, W)
+            output_masks = output_masks.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+            # Threshold the masks at 0.5 to get binary masks (0: ambiguous, 1: non-ambiguous)
+            output_masks = output_masks > 0.5
+            # Split the predicted masks back to their respective views
+            output_masks_per_view = output_masks.chunk(num_views, dim=0)
+            # Get the output mask logits (for loss)
+            output_mask_logits = dense_final_outputs.logits
+            # Reshape mask logits to (B * V, H, W)
+            output_mask_logits = (
+                output_mask_logits.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+            )
+            # Split the predicted mask logits back to their respective views
+            output_mask_logits_per_view = output_mask_logits.chunk(num_views, dim=0)
+            # Add the masks and logits to the result
+            for i in range(num_views):
+                res[i]["non_ambiguous_mask"] = output_masks_per_view[i]
+                res[i]["non_ambiguous_mask_logits"] = output_mask_logits_per_view[i]
+
+        return res
+
+    def _configure_geometric_input_config(
+        self,
+        use_calibration: bool,
+        use_depth: bool,
+        use_pose: bool,
+        use_depth_scale: bool,
+        use_pose_scale: bool,
+    ):
+        """
+        Configure the geometric input configuration
+        """
+        # Store original config for restoration
+        if not hasattr(self, "_original_geometric_config"):
+            self._original_geometric_config = dict(self.geometric_input_config)
+
+        # Set the geometric input configuration
+        if not (use_calibration or use_depth or use_pose):
+            # No geometric inputs (images-only mode)
+            self.geometric_input_config.update(
+                {
+                    "overall_prob": 0.0,
+                    "dropout_prob": 1.0,
+                    "ray_dirs_prob": 0.0,
+                    "depth_prob": 0.0,
+                    "cam_prob": 0.0,
+                    "sparse_depth_prob": 0.0,
+                    "depth_scale_norm_all_prob": 0.0,
+                    "pose_scale_norm_all_prob": 0.0,
+                }
+            )
+        else:
+            # Enable geometric inputs with deterministic behavior
+            self.geometric_input_config.update(
+                {
+                    "overall_prob": 1.0,
+                    "dropout_prob": 0.0,
+                    "ray_dirs_prob": 1.0 if use_calibration else 0.0,
+                    "depth_prob": 1.0 if use_depth else 0.0,
+                    "cam_prob": 1.0 if use_pose else 0.0,
+                    "sparse_depth_prob": 0.0,
+                    "depth_scale_norm_all_prob": 0.0 if use_depth_scale else 1.0,
+                    "pose_scale_norm_all_prob": 0.0 if use_pose_scale else 1.0,
+                }
) + + def _restore_original_geometric_input_config(self): + """ + Restore original geometric input configuration + """ + if hasattr(self, "_original_geometric_config"): + self.geometric_input_config.update(self._original_geometric_config) + + @torch.inference_mode() + def infer( + self, + views: List[Dict[str, Any]], + memory_efficient_inference: bool = False, + use_amp: bool = True, + amp_dtype: str = "bf16", + apply_mask: bool = True, + mask_edges: bool = True, + edge_normal_threshold: float = 5.0, + edge_depth_threshold: float = 0.03, + apply_confidence_mask: bool = False, + confidence_percentile: float = 10, + ignore_calibration_inputs: bool = False, + ignore_depth_inputs: bool = False, + ignore_pose_inputs: bool = False, + ignore_depth_scale_inputs: bool = False, + ignore_pose_scale_inputs: bool = False, + ) -> List[Dict[str, torch.Tensor]]: + """ + User-friendly inference with strict input validation and automatic conversion. + + Args: + views: List of view dictionaries. Each dict can contain: + Required: + - 'img': torch.Tensor of shape (B, 3, H, W) - normalized RGB images + - 'data_norm_type': str - normalization type used to normalize the images (must be equal to self.model.encoder.data_norm_type) + + Optional Geometric Inputs (only one of intrinsics OR ray_directions): + - 'intrinsics': torch.Tensor of shape (B, 3, 3) - will be converted to ray directions + - 'ray_directions': torch.Tensor of shape (B, H, W, 3) - ray directions in camera frame + - 'depth_z': torch.Tensor of shape (B, H, W, 1) - Z depth in camera frame (intrinsics or ray_directions must be provided) + - 'camera_poses': torch.Tensor of shape (B, 4, 4) or tuple of (quats - (B, 4), trans - (B, 3)) - can be any world frame + - 'is_metric_scale': bool or torch.Tensor of shape (B,) - if not provided, defaults to True + + Optional Additional Info: + - 'instance': List[str] where length of list is B - instance info for each view + - 'idx': List[int] where length of list is B - index info for each view + - 'true_shape': List[tuple] where length of list is B - true shape info (H, W) for each view + + memory_efficient_inference: Whether to use memory-efficient inference for dense prediction heads (trades off speed). Defaults to False. + use_amp: Whether to use automatic mixed precision for faster inference. Defaults to True. + amp_dtype: The dtype to use for mixed precision. Defaults to "bf16" (bfloat16). Options: "fp16", "bf16", "fp32". + apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True. + mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True. + edge_normal_threshold: Tolerance threshold for normals-based edge detection. Defaults to 5.0. + edge_depth_threshold: Relative tolerance threshold for depth-based edge detection. Defaults to 0.03. + apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False. + confidence_percentile: The percentile to use for the confidence threshold. Defaults to 10. + ignore_calibration_inputs: Whether to ignore the calibration inputs (intrinsics and ray_directions). Defaults to False. + ignore_depth_inputs: Whether to ignore the depth inputs. Defaults to False. + ignore_pose_inputs: Whether to ignore the pose inputs. Defaults to False. + ignore_depth_scale_inputs: Whether to ignore the depth scale inputs. Defaults to False. + ignore_pose_scale_inputs: Whether to ignore the pose scale inputs. Defaults to False. 
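+
+        Example (minimal sketch; assumes `model` is a constructed model instance and
+        `img0` / `img1` are (B, 3, H, W) tensors already normalized for the image encoder):
+            views = [
+                {"img": img0, "data_norm_type": model.encoder.data_norm_type},
+                {"img": img1, "data_norm_type": model.encoder.data_norm_type},
+            ]
+            preds = model.infer(views, memory_efficient_inference=True)
+            pts3d_view0 = preds[0]["pts3d"]  # (B, H, W, 3) points in world frame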
+
+        IMPORTANT CONSTRAINTS:
+        - Cannot provide both 'intrinsics' and 'ray_directions' (they represent the same information)
+        - If 'depth_z' is provided, then 'intrinsics' or 'ray_directions' must also be provided
+        - If ANY view has 'camera_poses', then view 0 (first view) MUST also have 'camera_poses'
+
+        Returns:
+            List of prediction dictionaries, one per view. Each dict contains:
+            - 'img_no_norm': torch.Tensor of shape (B, H, W, 3) - denormalized RGB images
+            - 'pts3d': torch.Tensor of shape (B, H, W, 3) - predicted points in world frame
+            - 'pts3d_cam': torch.Tensor of shape (B, H, W, 3) - predicted points in camera frame
+            - 'ray_directions': torch.Tensor of shape (B, H, W, 3) - ray directions in camera frame
+            - 'intrinsics': torch.Tensor of shape (B, 3, 3) - pinhole camera intrinsics recovered from ray directions
+            - 'depth_along_ray': torch.Tensor of shape (B, H, W, 1) - depth along ray in camera frame
+            - 'depth_z': torch.Tensor of shape (B, H, W, 1) - Z depth in camera frame
+            - 'cam_trans': torch.Tensor of shape (B, 3) - camera translation in world frame
+            - 'cam_quats': torch.Tensor of shape (B, 4) - camera quaternion in world frame
+            - 'camera_poses': torch.Tensor of shape (B, 4, 4) - camera pose in world frame
+            - 'metric_scaling_factor': torch.Tensor of shape (B,) - applied metric scaling factor
+            - 'mask': torch.Tensor of shape (B, H, W, 1) - combination of the non-ambiguous mask, edge mask, and confidence-based mask (where used)
+            - 'non_ambiguous_mask': torch.Tensor of shape (B, H, W) - non-ambiguous mask
+            - 'non_ambiguous_mask_logits': torch.Tensor of shape (B, H, W) - non-ambiguous mask logits
+            - 'conf': torch.Tensor of shape (B, H, W) - confidence
+
+        Raises:
+            ValueError: For invalid inputs, missing required keys, conflicting modalities, or constraint violations
+        """
+        # Determine the mixed precision floating point type
+        if use_amp:
+            if amp_dtype == "fp16":
+                amp_dtype = torch.float16
+            elif amp_dtype == "bf16":
+                if torch.cuda.is_bf16_supported():
+                    amp_dtype = torch.bfloat16
+                else:
+                    warnings.warn(
+                        "bf16 is not supported on this device. Using fp16 instead."
+ ) + amp_dtype = torch.float16 + elif amp_dtype == "fp32": + amp_dtype = torch.float32 + else: + amp_dtype = torch.float32 + + # Validate the input views + validated_views = validate_input_views_for_inference(views) + + # Transfer the views to the same device as the model + ignore_keys = set( + [ + "instance", + "idx", + "true_shape", + "data_norm_type", + ] + ) + for view in validated_views: + for name in view.keys(): + if name in ignore_keys: + continue + view[name] = view[name].to(self.device, non_blocking=True) + + # Pre-process the input views + processed_views = preprocess_input_views_for_inference(validated_views) + + # Set the model input probabilities based on input args for ignoring inputs + self._configure_geometric_input_config( + use_calibration=not ignore_calibration_inputs, + use_depth=not ignore_depth_inputs, + use_pose=not ignore_pose_inputs, + use_depth_scale=not ignore_depth_scale_inputs, + use_pose_scale=not ignore_pose_scale_inputs, + ) + + # Run the model + with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype): + preds = self.forward( + processed_views, memory_efficient_inference=memory_efficient_inference + ) + + # Post-process the model outputs + preds = postprocess_model_outputs_for_inference( + raw_outputs=preds, + input_views=processed_views, + apply_mask=apply_mask, + mask_edges=mask_edges, + edge_normal_threshold=edge_normal_threshold, + edge_depth_threshold=edge_depth_threshold, + apply_confidence_mask=apply_confidence_mask, + confidence_percentile=confidence_percentile, + ) + + # Restore the original configuration + self._restore_original_geometric_input_config() + + return preds diff --git a/mapanything/models/mapanything/modular_dust3r.py b/mapanything/models/mapanything/modular_dust3r.py new file mode 100644 index 0000000000000000000000000000000000000000..672ac1b9368545c2b714c94abbbe1e35824cb6f7 --- /dev/null +++ b/mapanything/models/mapanything/modular_dust3r.py @@ -0,0 +1,475 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Modular DUSt3R class defined using UniCeption modules. 
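+
+It follows the two-view DUSt3R formulation: siamese encoders, a two-view
+information-sharing transformer, and per-view prediction heads, with both
+outputs expressed in view 1's frame.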
+""" + +from typing import Callable, Dict + +import torch +import torch.nn as nn + +from uniception.models.encoders import encoder_factory, ViTEncoderInput +from uniception.models.info_sharing.alternating_attention_transformer import ( + MultiViewAlternatingAttentionTransformer, + MultiViewAlternatingAttentionTransformerIFR, +) +from uniception.models.info_sharing.base import MultiViewTransformerInput +from uniception.models.info_sharing.cross_attention_transformer import ( + MultiViewCrossAttentionTransformer, + MultiViewCrossAttentionTransformerIFR, +) +from uniception.models.info_sharing.global_attention_transformer import ( + MultiViewGlobalAttentionTransformer, + MultiViewGlobalAttentionTransformerIFR, +) +from uniception.models.libs.croco.pos_embed import RoPE2D +from uniception.models.prediction_heads.adaptors import PointMapWithConfidenceAdaptor +from uniception.models.prediction_heads.base import ( + AdaptorInput, + PredictionHeadInput, + PredictionHeadLayeredInput, +) +from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor +from uniception.models.prediction_heads.linear import LinearFeature + +# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12) +if hasattr(torch.backends.cuda, "matmul") and hasattr( + torch.backends.cuda.matmul, "allow_tf32" +): + torch.backends.cuda.matmul.allow_tf32 = True + + +class ModularDUSt3R(nn.Module): + "Modular DUSt3R model class." + + def __init__( + self, + name: str, + encoder_config: Dict, + info_sharing_config: Dict, + pred_head_config: Dict, + pretrained_checkpoint_path: str = None, + load_specific_pretrained_submodules: bool = False, + specific_pretrained_submodules: list = [], + torch_hub_force_reload: bool = False, + *args, + **kwargs, + ): + """ + Two-view model containing siamese encoders followed by a two-view attention transformer and respective downstream heads. + The goal is to output scene representation directly, both outputs in view1's frame (hence the asymmetry). + + Args: + name (str): Name of the model. + encoder_config (Dict): Configuration for the encoder. + info_sharing_config (Dict): Configuration for the two-view attention transformer. + pred_head_config (Dict): Configuration for the prediction heads. + pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None) + load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False) + specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: []) + torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. 
(default: False) + """ + super().__init__(*args, **kwargs) + + # Initialize the attributes + self.name = name + self.encoder_config = encoder_config + self.info_sharing_config = info_sharing_config + self.pred_head_config = pred_head_config + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.load_specific_pretrained_submodules = load_specific_pretrained_submodules + self.specific_pretrained_submodules = specific_pretrained_submodules + self.torch_hub_force_reload = torch_hub_force_reload + self.class_init_args = { + "name": self.name, + "encoder_config": self.encoder_config, + "info_sharing_config": self.info_sharing_config, + "pred_head_config": self.pred_head_config, + "pretrained_checkpoint_path": self.pretrained_checkpoint_path, + "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules, + "specific_pretrained_submodules": self.specific_pretrained_submodules, + "torch_hub_force_reload": self.torch_hub_force_reload, + } + + # Get relevant parameters from the configs + custom_positional_encoding = info_sharing_config["custom_positional_encoding"] + self.info_sharing_type = info_sharing_config["model_type"] + self.info_sharing_return_type = info_sharing_config["model_return_type"] + self.pred_head_type = pred_head_config["type"] + + # Initialize Encoder + if self.encoder_config["uses_torch_hub"]: + self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload + # Create a copy of the config before deleting the key to preserve it for serialization + encoder_config_copy = self.encoder_config.copy() + del encoder_config_copy["uses_torch_hub"] + self.encoder = encoder_factory(**encoder_config_copy) + + # Initialize Custom Positional Encoding if required + if custom_positional_encoding is not None: + if isinstance(custom_positional_encoding, str): + print( + f"Using custom positional encoding for multi-view cross attention transformer: {custom_positional_encoding}" + ) + if custom_positional_encoding.startswith("RoPE"): + rope_freq = float(custom_positional_encoding[len("RoPE") :]) + print(f"RoPE frequency: {rope_freq}") + self.custom_positional_encoding = RoPE2D(freq=rope_freq) + else: + raise ValueError( + f"Invalid custom_positional_encoding: {custom_positional_encoding}." + ) + elif isinstance(custom_positional_encoding, Callable): + print( + "Using callable function as custom positional encoding for multi-view cross attention transformer." 
+                )
+                self.custom_positional_encoding = custom_positional_encoding
+        else:
+            self.custom_positional_encoding = None
+
+        # Add dependencies to info_sharing_config
+        info_sharing_config["module_args"]["input_embed_dim"] = (
+            self.encoder.enc_embed_dim
+        )
+        info_sharing_config["module_args"]["custom_positional_encoding"] = (
+            self.custom_positional_encoding
+        )
+
+        # Initialize Multi-View Transformer
+        if self.info_sharing_return_type == "no_intermediate_features":
+            # Returns only normalized last layer features
+            # Initialize multi-view transformer based on type
+            if self.info_sharing_type == "cross_attention":
+                self.info_sharing = MultiViewCrossAttentionTransformer(
+                    **info_sharing_config["module_args"]
+                )
+            elif self.info_sharing_type == "global_attention":
+                self.info_sharing = MultiViewGlobalAttentionTransformer(
+                    **info_sharing_config["module_args"]
+                )
+            elif self.info_sharing_type == "alternating_attention":
+                self.info_sharing = MultiViewAlternatingAttentionTransformer(
+                    **info_sharing_config["module_args"]
+                )
+            else:
+                raise ValueError(
+                    f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+                )
+        elif self.info_sharing_return_type == "intermediate_features":
+            # Returns intermediate features and normalized last layer features
+            # Initialize multi-view transformer based on type
+            if self.info_sharing_type == "cross_attention":
+                self.info_sharing = MultiViewCrossAttentionTransformerIFR(
+                    **info_sharing_config["module_args"]
+                )
+            elif self.info_sharing_type == "global_attention":
+                self.info_sharing = MultiViewGlobalAttentionTransformerIFR(
+                    **info_sharing_config["module_args"]
+                )
+            elif self.info_sharing_type == "alternating_attention":
+                self.info_sharing = MultiViewAlternatingAttentionTransformerIFR(
+                    **info_sharing_config["module_args"]
+                )
+            else:
+                raise ValueError(
+                    f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+                )
+            # Assess if the DPT needs to use encoder features
+            if len(self.info_sharing.indices) == 2:
+                self.use_encoder_features_for_dpt = True
+            elif len(self.info_sharing.indices) == 3:
+                self.use_encoder_features_for_dpt = False
+            else:
+                raise ValueError(
+                    "Invalid number of indices provided for info sharing feature returner. Please provide 2 or 3 indices."
+                )
+        else:
+            raise ValueError(
+                f"Invalid info_sharing_return_type: {self.info_sharing_return_type}. Valid options: ['no_intermediate_features', 'intermediate_features']"
+            )
+
+        # Add dependencies to prediction head config
+        pred_head_config["feature_head"]["patch_size"] = self.encoder.patch_size
+        if self.pred_head_type == "linear":
+            pred_head_config["feature_head"]["input_feature_dim"] = (
+                self.info_sharing.dim
+            )
+        elif self.pred_head_type == "dpt":
+            if self.use_encoder_features_for_dpt:
+                pred_head_config["feature_head"]["input_feature_dims"] = [
+                    self.encoder.enc_embed_dim
+                ] + [self.info_sharing.dim] * 3
+            else:
+                pred_head_config["feature_head"]["input_feature_dims"] = [
+                    self.info_sharing.dim
+                ] * 4
+            pred_head_config["regressor_head"]["input_feature_dim"] = pred_head_config[
+                "feature_head"
+            ]["feature_dim"]
+        else:
+            raise ValueError(
+                f"Invalid pred_head_type: {self.pred_head_type}. 
Valid options: ['linear', 'dpt']" + ) + + # Initialize Prediction Heads + if self.pred_head_type == "linear": + # Initialize Prediction Head 1 + self.head1 = LinearFeature(**pred_head_config["feature_head"]) + # Initialize Prediction Head 2 + self.head2 = LinearFeature(**pred_head_config["feature_head"]) + elif self.pred_head_type == "dpt": + # Initialize Prediction Head 1 + self.dpt_feature_head1 = DPTFeature(**pred_head_config["feature_head"]) + self.dpt_regressor_head1 = DPTRegressionProcessor( + **pred_head_config["regressor_head"] + ) + self.head1 = nn.Sequential(self.dpt_feature_head1, self.dpt_regressor_head1) + # Initialize Prediction Head 2 + self.dpt_feature_head2 = DPTFeature(**pred_head_config["feature_head"]) + self.dpt_regressor_head2 = DPTRegressionProcessor( + **pred_head_config["regressor_head"] + ) + self.head2 = nn.Sequential(self.dpt_feature_head2, self.dpt_regressor_head2) + else: + raise ValueError( + f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt']" + ) + + # Initialize Final Output Adaptor + if pred_head_config["adaptor_type"] == "pointmap+confidence": + self.adaptor = PointMapWithConfidenceAdaptor(**pred_head_config["adaptor"]) + self.scene_rep_type = "pointmap" + else: + raise ValueError( + f"Invalid adaptor_type: {pred_head_config['adaptor_type']}. Valid options: ['pointmap+confidence']" + ) + + # Load pretrained weights + if self.pretrained_checkpoint_path is not None: + if not self.load_specific_pretrained_submodules: + print( + f"Loading pretrained weights from {self.pretrained_checkpoint_path} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + else: + print( + f"Loading pretrained weights from {self.pretrained_checkpoint_path} for specific submodules: {specific_pretrained_submodules} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + filtered_ckpt = {} + for ckpt_key, ckpt_value in ckpt["model"].items(): + for submodule in specific_pretrained_submodules: + if ckpt_key.startswith(submodule): + filtered_ckpt[ckpt_key] = ckpt_value + print(self.load_state_dict(filtered_ckpt, strict=False)) + + def _encode_image_pairs(self, img1, img2, data_norm_type): + "Encode two different batches of images (each batch can have different image shape)" + if img1.shape[-2:] == img2.shape[-2:]: + encoder_input = ViTEncoderInput( + image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type + ) + encoder_output = self.encoder(encoder_input) + out, out2 = encoder_output.features.chunk(2, dim=0) + else: + encoder_input = ViTEncoderInput(image=img1, data_norm_type=data_norm_type) + out = self.encoder(encoder_input) + out = out.features + encoder_input2 = ViTEncoderInput(image=img2, data_norm_type=data_norm_type) + out2 = self.encoder(encoder_input2) + out2 = out2.features + + return out, out2 + + def _encode_symmetrized(self, view1, view2): + "Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input" + img1 = view1["img"] + img2 = view2["img"] + if isinstance(view1["data_norm_type"], list): + assert all( + [x == view1["data_norm_type"][0] for x in view1["data_norm_type"]] + ), "All data_norm_type values should be the same in the list." + data_norm_type = view1["data_norm_type"][0] + elif isinstance(view1["data_norm_type"], str): + data_norm_type = view1["data_norm_type"] + else: + raise ValueError( + f"Invalid data_norm_type: {view1['data_norm_type']}. 
Should be either a list with all same values or a string." + ) + feat1, feat2 = self._encode_image_pairs( + img1, img2, data_norm_type=data_norm_type + ) + + return feat1, feat2 + + def _downstream_head(self, head_num, decout, img_shape): + "Run the respective prediction heads" + head = getattr(self, f"head{head_num}") + if self.pred_head_type == "linear": + head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"]) + elif self.pred_head_type == "dpt": + head_input = PredictionHeadLayeredInput( + list_features=decout[f"{head_num}"], target_output_shape=img_shape + ) + + return head(head_input) + + def forward(self, views): + """ + Forward pass performing the following operations: + 1. Encodes the two input views (images). + 2. Combines the encoded features using a two-view attention transformer. + 3. Passes the combined features through the respective prediction heads. + 4. Returns the processed final outputs for both views. + + Args: + views (List(dict)): A list of size two whose elements are: + view1 (dict): Dictionary containing the first view's images and instance information. + "img" is a required key and value is a tensor of shape (B, C, H, W). + view2 (dict): Dictionary containing the second view's images and instance information. + "img" is a required key and value is a tensor of shape (B, C, H, W). + + Returns: + List[dict, dict]: A list containing the final outputs for both views. + """ + # Get input shapes + view1 = views[0] + view2 = views[1] + _, _, height1, width1 = view1["img"].shape + _, _, height2, width2 = view2["img"].shape + shape1 = (int(height1), int(width1)) + shape2 = (int(height2), int(width2)) + + if "img_encoder_feats" in view1 and "img_encoder_feats" in view2: + # Reuse the pre-computed image features for the two views + feat1 = view1["img_encoder_feats"] + feat2 = view2["img_encoder_feats"] + else: + # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width) + feat1, feat2 = self._encode_symmetrized(view1, view2) + + # Combine all images into view-centric representation + info_sharing_input = MultiViewTransformerInput(features=[feat1, feat2]) + if self.info_sharing_return_type == "no_intermediate_features": + final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input) + elif self.info_sharing_return_type == "intermediate_features": + ( + final_info_sharing_multi_view_feat, + intermediate_info_sharing_multi_view_feat, + ) = self.info_sharing(info_sharing_input) + + if self.pred_head_type == "linear": + # Define feature dictionary for linear head + info_sharing_outputs = { + "1": final_info_sharing_multi_view_feat.features[0].float(), + "2": final_info_sharing_multi_view_feat.features[1].float(), + } + elif self.pred_head_type == "dpt": + # Define feature dictionary for DPT head + if self.use_encoder_features_for_dpt: + info_sharing_outputs = { + "1": [ + feat1.float(), + intermediate_info_sharing_multi_view_feat[0] + .features[0] + .float(), + intermediate_info_sharing_multi_view_feat[1] + .features[0] + .float(), + final_info_sharing_multi_view_feat.features[0].float(), + ], + "2": [ + feat2.float(), + intermediate_info_sharing_multi_view_feat[0] + .features[1] + .float(), + intermediate_info_sharing_multi_view_feat[1] + .features[1] + .float(), + final_info_sharing_multi_view_feat.features[1].float(), + ], + } + else: + info_sharing_outputs = { + "1": [ + intermediate_info_sharing_multi_view_feat[0] + .features[0] + .float(), + intermediate_info_sharing_multi_view_feat[1] + 
+                    .features[0]
+                    .float(),
+                    intermediate_info_sharing_multi_view_feat[2]
+                    .features[0]
+                    .float(),
+                    final_info_sharing_multi_view_feat.features[0].float(),
+                ],
+                "2": [
+                    intermediate_info_sharing_multi_view_feat[0]
+                    .features[1]
+                    .float(),
+                    intermediate_info_sharing_multi_view_feat[1]
+                    .features[1]
+                    .float(),
+                    intermediate_info_sharing_multi_view_feat[2]
+                    .features[1]
+                    .float(),
+                    final_info_sharing_multi_view_feat.features[1].float(),
+                ],
+            }
+
+        # Downstream task prediction
+        with torch.autocast("cuda", enabled=False):
+            # Prediction heads
+            head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)
+            head_output2 = self._downstream_head(2, info_sharing_outputs, shape2)
+
+            # Post-process outputs
+            final_output1 = self.adaptor(
+                AdaptorInput(
+                    adaptor_feature=head_output1.decoded_channels,
+                    output_shape_hw=shape1,
+                )
+            )
+            final_output2 = self.adaptor(
+                AdaptorInput(
+                    adaptor_feature=head_output2.decoded_channels,
+                    output_shape_hw=shape2,
+                )
+            )
+
+            # Reshape final scene representation to (B, H, W, C)
+            final_scene_rep1 = final_output1.value.permute(0, 2, 3, 1).contiguous()
+            final_scene_rep2 = final_output2.value.permute(0, 2, 3, 1).contiguous()
+
+        # Convert output scene representation to pointmaps
+        if self.scene_rep_type == "pointmap":
+            output_pts3d1 = final_scene_rep1
+            output_pts3d2 = final_scene_rep2
+        else:
+            raise ValueError(f"Invalid scene_rep_type: {self.scene_rep_type}.")
+
+        # Reshape confidence to (B, H, W)
+        output_conf1 = (
+            final_output1.confidence.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+        )
+        output_conf2 = (
+            final_output2.confidence.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+        )
+
+        # Convert outputs to dictionary
+        res1 = {
+            "pts3d": output_pts3d1,
+            "conf": output_conf1,
+        }
+        res2 = {
+            "pts3d": output_pts3d2,
+            "conf": output_conf2,
+        }
+        res = [res1, res2]
+
+        return res
diff --git a/mapanything/third_party/README.md b/mapanything/third_party/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6c1bbc96bbd47274068555bef2fa31335a88468d
--- /dev/null
+++ b/mapanything/third_party/README.md
@@ -0,0 +1,3 @@
+# Third Party Code
+
+This folder contains third party code from VGGSfM & VGGT to support the COLMAP demo.
diff --git a/mapanything/third_party/__init__.py b/mapanything/third_party/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/third_party/distortion.py b/mapanything/third_party/distortion.py
new file mode 100644
index 0000000000000000000000000000000000000000..98d42e6674e351ae055959f8a0931ef1806b621f
--- /dev/null
+++ b/mapanything/third_party/distortion.py
@@ -0,0 +1,225 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# +# Modified from https://github.com/facebookresearch/vggt + +from typing import Union + +import numpy as np +import torch + +ArrayLike = Union[np.ndarray, torch.Tensor] + + +def _is_numpy(x: ArrayLike) -> bool: + return isinstance(x, np.ndarray) + + +def _is_torch(x: ArrayLike) -> bool: + return isinstance(x, torch.Tensor) + + +def _ensure_torch(x: ArrayLike) -> torch.Tensor: + """Convert input to torch tensor if it's not already one.""" + if _is_numpy(x): + return torch.from_numpy(x) + elif _is_torch(x): + return x + else: + return torch.tensor(x) + + +def single_undistortion(params, tracks_normalized): + """ + Apply undistortion to the normalized tracks using the given distortion parameters once. + + Args: + params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. + tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. + + Returns: + torch.Tensor: Undistorted normalized tracks tensor. + """ + params = _ensure_torch(params) + tracks_normalized = _ensure_torch(tracks_normalized) + + u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() + u_undist, v_undist = apply_distortion(params, u, v) + return torch.stack([u_undist, v_undist], dim=-1) + + +def iterative_undistortion( + params, + tracks_normalized, + max_iterations=100, + max_step_norm=1e-10, + rel_step_size=1e-6, +): + """ + Iteratively undistort the normalized tracks using the given distortion parameters. + + Args: + params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. + tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. + max_iterations (int): Maximum number of iterations for the undistortion process. + max_step_norm (float): Maximum step norm for convergence. + rel_step_size (float): Relative step size for numerical differentiation. + + Returns: + torch.Tensor: Undistorted normalized tracks tensor. + """ + params = _ensure_torch(params) + tracks_normalized = _ensure_torch(tracks_normalized) + + B, N, _ = tracks_normalized.shape + u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() + original_u, original_v = u.clone(), v.clone() + + eps = torch.finfo(u.dtype).eps + for idx in range(max_iterations): + u_undist, v_undist = apply_distortion(params, u, v) + dx = original_u - u_undist + dy = original_v - v_undist + + step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps) + step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps) + + J_00 = ( + apply_distortion(params, u + step_u, v)[0] + - apply_distortion(params, u - step_u, v)[0] + ) / (2 * step_u) + J_01 = ( + apply_distortion(params, u, v + step_v)[0] + - apply_distortion(params, u, v - step_v)[0] + ) / (2 * step_v) + J_10 = ( + apply_distortion(params, u + step_u, v)[1] + - apply_distortion(params, u - step_u, v)[1] + ) / (2 * step_u) + J_11 = ( + apply_distortion(params, u, v + step_v)[1] + - apply_distortion(params, u, v - step_v)[1] + ) / (2 * step_v) + + J = torch.stack( + [ + torch.stack([J_00 + 1, J_01], dim=-1), + torch.stack([J_10, J_11 + 1], dim=-1), + ], + dim=-2, + ) + + delta = torch.linalg.solve(J, torch.stack([dx, dy], dim=-1)) + + u += delta[..., 0] + v += delta[..., 1] + + if torch.max((delta**2).sum(dim=-1)) < max_step_norm: + break + + return torch.stack([u, v], dim=-1) + + +def apply_distortion(extra_params, u, v): + """ + Applies radial or OpenCV distortion to the given 2D points. 
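+
+    For the 1- and 2-parameter radial models this computes, with r^2 = u^2 + v^2:
+        u_d = u * (1 + k1 * r^2 + k2 * r^4)
+        v_d = v * (1 + k1 * r^2 + k2 * r^4)
+    (k2 = 0 for the 1-parameter model). The 4-parameter OpenCV model additionally
+    adds the tangential terms involving p1 and p2 shown in the branches below.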
+
+    Args:
+        extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
+        u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks.
+        v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks.
+
+    Returns:
+        tuple: (u, v) distorted coordinates, each of shape Bxnum_tracks.
+    """
+    extra_params = _ensure_torch(extra_params)
+    u = _ensure_torch(u)
+    v = _ensure_torch(v)
+
+    num_params = extra_params.shape[1]
+
+    if num_params == 1:
+        # Simple radial distortion
+        k = extra_params[:, 0]
+        u2 = u * u
+        v2 = v * v
+        r2 = u2 + v2
+        radial = k[:, None] * r2
+        du = u * radial
+        dv = v * radial
+
+    elif num_params == 2:
+        # RadialCameraModel distortion
+        k1, k2 = extra_params[:, 0], extra_params[:, 1]
+        u2 = u * u
+        v2 = v * v
+        r2 = u2 + v2
+        radial = k1[:, None] * r2 + k2[:, None] * r2 * r2
+        du = u * radial
+        dv = v * radial
+
+    elif num_params == 4:
+        # OpenCVCameraModel distortion
+        k1, k2, p1, p2 = (
+            extra_params[:, 0],
+            extra_params[:, 1],
+            extra_params[:, 2],
+            extra_params[:, 3],
+        )
+        u2 = u * u
+        v2 = v * v
+        uv = u * v
+        r2 = u2 + v2
+        radial = k1[:, None] * r2 + k2[:, None] * r2 * r2
+        du = u * radial + 2 * p1[:, None] * uv + p2[:, None] * (r2 + 2 * u2)
+        dv = v * radial + 2 * p2[:, None] * uv + p1[:, None] * (r2 + 2 * v2)
+    else:
+        raise ValueError("Unsupported number of distortion parameters")
+
+    u = u.clone() + du
+    v = v.clone() + dv
+
+    return u, v
+
+
+if __name__ == "__main__":
+    import random
+
+    import pycolmap
+
+    max_diff = 0
+    for i in range(1000):
+        # Define distortion parameters (assuming 1 parameter for simplicity)
+        B = random.randint(1, 500)
+        track_num = random.randint(100, 1000)
+        params = torch.rand((B, 1), dtype=torch.float32)  # B cameras, 1 radial parameter each
+        tracks_normalized = torch.rand(
+            (B, track_num, 2), dtype=torch.float32
+        )  # B batches of track_num normalized 2D points
+
+        # Undistort the tracks
+        undistorted_tracks = iterative_undistortion(params, tracks_normalized)
+
+        for b in range(B):
+            pycolmap_intri = np.array([1, 0, 0, params[b].item()])
+            pycam = pycolmap.Camera(
+                model="SIMPLE_RADIAL",
+                width=1,
+                height=1,
+                params=pycolmap_intri,
+                camera_id=0,
+            )
+
+            undistorted_tracks_pycolmap = pycam.cam_from_img(
+                tracks_normalized[b].numpy()
+            )
+            diff = (undistorted_tracks[b] - undistorted_tracks_pycolmap).abs().median()
+            max_diff = max(max_diff, diff)
+            print(f"diff: {diff}, max_diff: {max_diff}")
+
+    import pdb
+
+    pdb.set_trace()
diff --git a/mapanything/third_party/np_to_pycolmap.py b/mapanything/third_party/np_to_pycolmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..5349f965f6437e245a0c911ce7334e493fd6aa8c
--- /dev/null
+++ b/mapanything/third_party/np_to_pycolmap.py
@@ -0,0 +1,357 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# +# Modified from https://github.com/facebookresearch/vggt + +import numpy as np +import pycolmap + +from mapanything.third_party.projection import project_3D_points_np + + +def batch_np_matrix_to_pycolmap( + points3d, + extrinsics, + intrinsics, + tracks, + image_size, + masks=None, + max_reproj_error=None, + max_points3D_val=3000, + shared_camera=False, + camera_type="SIMPLE_PINHOLE", + extra_params=None, + min_inlier_per_frame=64, + points_rgb=None, +): + """ + Convert Batched NumPy Arrays to PyCOLMAP + + Check https://github.com/colmap/pycolmap for more details about its format + + NOTE that colmap expects images/cameras/points3D to be 1-indexed + so there is a +1 offset between colmap index and batch index + + + NOTE: different from VGGSfM, this function: + 1. Use np instead of torch + 2. Frame index and camera id starts from 1 rather than 0 (to fit the format of PyCOLMAP) + """ + # points3d: Px3 + # extrinsics: Nx3x4 + # intrinsics: Nx3x3 + # tracks: NxPx2 + # masks: NxP + # image_size: 2, assume all the frames have been padded to the same size + # where N is the number of frames and P is the number of tracks + + N, P, _ = tracks.shape + assert len(extrinsics) == N + assert len(intrinsics) == N + assert len(points3d) == P + assert image_size.shape[0] == 2 + + reproj_mask = None + + if max_reproj_error is not None: + projected_points_2d, projected_points_cam = project_3D_points_np( + points3d, extrinsics, intrinsics + ) + projected_diff = np.linalg.norm(projected_points_2d - tracks, axis=-1) + projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6 + reproj_mask = projected_diff < max_reproj_error + + if masks is not None and reproj_mask is not None: + masks = np.logical_and(masks, reproj_mask) + elif masks is not None: + masks = masks + else: + masks = reproj_mask + + assert masks is not None + + if masks.sum(1).min() < min_inlier_per_frame: + print("Not enough inliers per frame, skip BA.") + return None, None + + # Reconstruction object, following the format of PyCOLMAP/COLMAP + reconstruction = pycolmap.Reconstruction() + + inlier_num = masks.sum(0) + valid_mask = inlier_num >= 2 # a track is invalid if without two inliers + valid_idx = np.nonzero(valid_mask)[0] + + # Only add 3D points that have sufficient 2D points + for vidx in valid_idx: + # Use RGB colors if provided, otherwise use zeros + rgb = points_rgb[vidx] if points_rgb is not None else np.zeros(3) + reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), rgb) + + num_points3D = len(valid_idx) + camera = None + # frame idx + for fidx in range(N): + # set camera + if camera is None or (not shared_camera): + pycolmap_intri = _build_pycolmap_intri( + fidx, intrinsics, camera_type, extra_params + ) + + camera = pycolmap.Camera( + model=camera_type, + width=image_size[0], + height=image_size[1], + params=pycolmap_intri, + camera_id=fidx + 1, + ) + + # add camera + reconstruction.add_camera(camera) + + # set image + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] + ) # Rot and Trans + + image = pycolmap.Image( + id=fidx + 1, + name=f"image_{fidx + 1}", + camera_id=camera.camera_id, + cam_from_world=cam_from_world, + ) + + points2D_list = [] + + point2D_idx = 0 + + # NOTE point3D_id start by 1 + for point3D_id in range(1, num_points3D + 1): + original_track_idx = valid_idx[point3D_id - 1] + + if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all(): + if masks[fidx][original_track_idx]: + # It seems we don't need +0.5 for BA + point2D_xy = 
tracks[fidx][original_track_idx] + # Please note when adding the Point2D object + # It not only requires the 2D xy location, but also the id to 3D point + points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) + + # add element + track = reconstruction.points3D[point3D_id].track + track.add_element(fidx + 1, point2D_idx) + point2D_idx += 1 + + assert point2D_idx == len(points2D_list) + + try: + image.points2D = pycolmap.ListPoint2D(points2D_list) + image.registered = True + except: # noqa + print(f"frame {fidx + 1} is out of BA") + image.registered = False + + # add image + reconstruction.add_image(image) + + return reconstruction, valid_mask + + +def pycolmap_to_batch_np_matrix( + reconstruction, device="cpu", camera_type="SIMPLE_PINHOLE" +): + """ + Convert a PyCOLMAP Reconstruction Object to batched NumPy arrays. + + Args: + reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP. + device (str): Ignored in NumPy version (kept for API compatibility). + camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE"). + + Returns: + tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params. + """ + + num_images = len(reconstruction.images) + max_points3D_id = max(reconstruction.point3D_ids()) + points3D = np.zeros((max_points3D_id, 3)) + + for point3D_id in reconstruction.points3D: + points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz + + extrinsics = [] + intrinsics = [] + + extra_params = [] if camera_type == "SIMPLE_RADIAL" else None + + for i in range(num_images): + # Extract and append extrinsics + pyimg = reconstruction.images[i + 1] + pycam = reconstruction.cameras[pyimg.camera_id] + matrix = pyimg.cam_from_world.matrix() + extrinsics.append(matrix) + + # Extract and append intrinsics + calibration_matrix = pycam.calibration_matrix() + intrinsics.append(calibration_matrix) + + if camera_type == "SIMPLE_RADIAL": + extra_params.append(pycam.params[-1]) + + # Convert lists to NumPy arrays instead of torch tensors + extrinsics = np.stack(extrinsics) + intrinsics = np.stack(intrinsics) + + if camera_type == "SIMPLE_RADIAL": + extra_params = np.stack(extra_params) + extra_params = extra_params[:, None] + + return points3D, extrinsics, intrinsics, extra_params + + +######################################################## + + +def batch_np_matrix_to_pycolmap_wo_track( + points3d, + points_xyf, + points_rgb, + extrinsics, + intrinsics, + image_size, + shared_camera=False, + camera_type="SIMPLE_PINHOLE", +): + """ + Convert Batched NumPy Arrays to PyCOLMAP + + Different from batch_np_matrix_to_pycolmap, this function does not use tracks. + + It saves points3d to colmap reconstruction format only to serve as init for Gaussians or other nvs methods. + + Do NOT use this for BA. 
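+
+    Minimal usage sketch (illustrative; array shapes follow the comments below,
+    `W`, `H`, and `output_dir` are placeholders, and image_size is [width, height]):
+        rec = batch_np_matrix_to_pycolmap_wo_track(
+            points3d, points_xyf, points_rgb, extrinsics, intrinsics,
+            np.array([W, H]),
+        )
+        rec.write(output_dir)  # standard pycolmap serialization to a directory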
+ """ + # points3d: Px3 + # points_xyf: Px3, with x, y coordinates and frame indices + # points_rgb: Px3, rgb colors + # extrinsics: Nx3x4 + # intrinsics: Nx3x3 + # image_size: 2, assume all the frames have been padded to the same size + # where N is the number of frames and P is the number of tracks + + N = len(extrinsics) + P = len(points3d) + + # Reconstruction object, following the format of PyCOLMAP/COLMAP + reconstruction = pycolmap.Reconstruction() + + for vidx in range(P): + reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), points_rgb[vidx]) + + camera = None + # frame idx + for fidx in range(N): + # set camera + if camera is None or (not shared_camera): + pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type) + + camera = pycolmap.Camera( + model=camera_type, + width=image_size[0], + height=image_size[1], + params=pycolmap_intri, + camera_id=fidx + 1, + ) + + # add camera + reconstruction.add_camera(camera) + + # set image + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] + ) # Rot and Trans + + image = pycolmap.Image( + id=fidx + 1, + name=f"image_{fidx + 1}", + camera_id=camera.camera_id, + cam_from_world=cam_from_world, + ) + + points2D_list = [] + + point2D_idx = 0 + + points_belong_to_fidx = points_xyf[:, 2].astype(np.int32) == fidx + points_belong_to_fidx = np.nonzero(points_belong_to_fidx)[0] + + for point3D_batch_idx in points_belong_to_fidx: + point3D_id = point3D_batch_idx + 1 + point2D_xyf = points_xyf[point3D_batch_idx] + point2D_xy = point2D_xyf[:2] + points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) + + # add element + track = reconstruction.points3D[point3D_id].track + track.add_element(fidx + 1, point2D_idx) + point2D_idx += 1 + + assert point2D_idx == len(points2D_list) + + try: + image.points2D = pycolmap.ListPoint2D(points2D_list) + image.registered = True + except: # noqa + print(f"frame {fidx + 1} does not have any points") + image.registered = False + + # add image + reconstruction.add_image(image) + + return reconstruction + + +def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=None): + """ + Helper function to get camera parameters based on camera type. 
+ + Args: + fidx: Frame index + intrinsics: Camera intrinsic parameters + camera_type: Type of camera model + extra_params: Additional parameters for certain camera types + + Returns: + pycolmap_intri: NumPy array of camera parameters + """ + if camera_type == "PINHOLE": + pycolmap_intri = np.array( + [ + intrinsics[fidx][0, 0], + intrinsics[fidx][1, 1], + intrinsics[fidx][0, 2], + intrinsics[fidx][1, 2], + ] + ) + elif camera_type == "SIMPLE_PINHOLE": + focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 + pycolmap_intri = np.array( + [focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]] + ) + elif camera_type == "SIMPLE_RADIAL": + raise NotImplementedError("SIMPLE_RADIAL is not supported yet") + focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 + pycolmap_intri = np.array( + [ + focal, + intrinsics[fidx][0, 2], + intrinsics[fidx][1, 2], + extra_params[fidx][0], + ] + ) + else: + raise ValueError(f"Camera type {camera_type} is not supported yet") + + return pycolmap_intri diff --git a/mapanything/third_party/projection.py b/mapanything/third_party/projection.py new file mode 100644 index 0000000000000000000000000000000000000000..dc7a6657205505c51d9824487cc3070e95c665f7 --- /dev/null +++ b/mapanything/third_party/projection.py @@ -0,0 +1,250 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Modified from https://github.com/facebookresearch/vggt + +import numpy as np +import torch + +from .distortion import apply_distortion + + +def img_from_cam_np( + intrinsics: np.ndarray, + points_cam: np.ndarray, + extra_params: np.ndarray | None = None, + default: float = 0.0, +) -> np.ndarray: + """ + Apply intrinsics (and optional radial distortion) to camera-space points. + + Args + ---- + intrinsics : (B,3,3) camera matrix K. + points_cam : (B,3,N) homogeneous camera coords (x, y, z)ᵀ. + extra_params: (B, N) or (B, k) distortion params (k = 1,2,4) or None. + default : value used for np.nan replacement. + + Returns + ------- + points2D : (B,N,2) pixel coordinates. + """ + # 1. perspective divide ─────────────────────────────────────── + z = points_cam[:, 2:3, :] # (B,1,N) + points_cam_norm = points_cam / z # (B,3,N) + uv = points_cam_norm[:, :2, :] # (B,2,N) + + # 2. optional distortion ────────────────────────────────────── + if extra_params is not None: + uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) + uv = np.stack([uu, vv], axis=1) # (B,2,N) + + # 3. homogeneous coords then K multiplication ───────────────── + ones = np.ones_like(uv[:, :1, :]) # (B,1,N) + points_cam_h = np.concatenate([uv, ones], axis=1) # (B,3,N) + + # batched mat-mul: K · [u v 1]ᵀ + points2D_h = np.einsum("bij,bjk->bik", intrinsics, points_cam_h) # (B,3,N) + points2D = np.nan_to_num(points2D_h[:, :2, :], nan=default) # (B,2,N) + + return points2D.transpose(0, 2, 1) # (B,N,2) + + +def project_3D_points_np( + points3D: np.ndarray, + extrinsics: np.ndarray, + intrinsics: np.ndarray | None = None, + extra_params: np.ndarray | None = None, + *, + default: float = 0.0, + only_points_cam: bool = False, +): + """ + NumPy clone of ``project_3D_points``. + + Parameters + ---------- + points3D : (N,3) world-space points. + extrinsics : (B,3,4) [R|t] matrix for each of B cameras. + intrinsics : (B,3,3) K matrix (optional if you only need cam-space). + extra_params : (B,k) or (B,N) distortion parameters (k ∈ {1,2,4}) or None. 
+ default : value used to replace NaNs. + only_points_cam : if True, skip the projection and return points_cam with points2D as None. + + Returns + ------- + (points2D, points_cam) : A tuple where points2D is (B,N,2) pixel coords or None if only_points_cam=True, + and points_cam is (B,3,N) camera-space coordinates. + """ + # ----- 0. prep sizes ----------------------------------------------------- + N = points3D.shape[0] # #points + B = extrinsics.shape[0] # #cameras + + # ----- 1. world → homogeneous ------------------------------------------- + w_h = np.ones((N, 1), dtype=points3D.dtype) + points3D_h = np.concatenate([points3D, w_h], axis=1) # (N,4) + + # broadcast to every camera (no actual copying with np.broadcast_to) ------ + points3D_h_B = np.broadcast_to(points3D_h, (B, N, 4)) # (B,N,4) + + # ----- 2. apply extrinsics (camera frame) ------------------------------ + # X_cam = E · X_hom + # einsum: E_(b i j) · X_(b n j) → (b n i) + points_cam = np.einsum("bij,bnj->bni", extrinsics, points3D_h_B) # (B,N,3) + points_cam = points_cam.transpose(0, 2, 1) # (B,3,N) + + if only_points_cam: + return None, points_cam + + # ----- 3. intrinsics + distortion --------------------------------------- + if intrinsics is None: + raise ValueError("`intrinsics` must be provided unless only_points_cam=True") + + points2D = img_from_cam_np( + intrinsics, points_cam, extra_params=extra_params, default=default + ) + + return points2D, points_cam + + +def project_3D_points( + points3D, + extrinsics, + intrinsics=None, + extra_params=None, + default=0, + only_points_cam=False, +): + """ + Transforms 3D points to 2D using extrinsic and intrinsic parameters. + Args: + points3D (torch.Tensor): 3D points of shape Px3. + extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. + intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. + extra_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion. + default (float): Default value to replace NaNs. + only_points_cam (bool): If True, skip the projection and return points2D as None. + + Returns: + tuple: (points2D, points_cam) where points2D is of shape BxNx2 or None if only_points_cam=True, + and points_cam is of shape Bx3xN. + """ + with torch.cuda.amp.autocast(dtype=torch.double): + B = extrinsics.shape[0] # Batch size, i.e., number of cameras + points3D_homogeneous = torch.cat( + [points3D, torch.ones_like(points3D[..., 0:1])], dim=1 + ) # Nx4 + # Reshape for batch processing + points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand( + B, -1, -1 + ) # BxNx4 + + # Step 1: Apply extrinsic parameters + # Transform 3D points to camera coordinate system for all cameras + points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2)) + + if only_points_cam: + return None, points_cam + + # Step 2: Apply intrinsic parameters and (optional) distortion + points2D = img_from_cam(intrinsics, points_cam, extra_params, default) + + return points2D, points_cam + + +def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0): + """ + Applies intrinsic parameters and optional distortion to the given 3D points. + + Args: + intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. + points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. + extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + default (float, optional): Default value to replace NaNs in the output. 
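+
+    Per point, the mapping applied here is u = fx * (x / z) + cx and
+    v = fy * (y / z) + cy, with the optional distortion applied to (x / z, y / z)
+    before the multiplication by the intrinsic matrix.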
+ + Returns: + points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. + """ + + # Normalize by the third coordinate (homogeneous division) + points_cam = points_cam / points_cam[:, 2:3, :] + # Extract uv + uv = points_cam[:, :2, :] + + # Apply distortion if extra_params are provided + if extra_params is not None: + uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) + uv = torch.stack([uu, vv], dim=1) + + # Prepare points_cam for batch matrix multiplication + points_cam_homo = torch.cat((uv, torch.ones_like(uv[:, :1, :])), dim=1) # Bx3xN + # Apply intrinsic parameters using batch matrix multiplication + points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN + + # Extract x and y coordinates + points2D = points2D_homo[:, :2, :] # Bx2xN + + # Replace NaNs with default value + points2D = torch.nan_to_num(points2D, nan=default) + + return points2D.transpose(1, 2) # BxNx2 + + +if __name__ == "__main__": + # Set up example input + B, N = 24, 10240 + + for _ in range(100): + points3D = np.random.rand(N, 3).astype(np.float64) + extrinsics = np.random.rand(B, 3, 4).astype(np.float64) + intrinsics = np.random.rand(B, 3, 3).astype(np.float64) + + # Convert to torch tensors + points3D_torch = torch.tensor(points3D) + extrinsics_torch = torch.tensor(extrinsics) + intrinsics_torch = torch.tensor(intrinsics) + + # Run NumPy implementation + points2D_np, points_cam_np = project_3D_points_np( + points3D, extrinsics, intrinsics + ) + + # Run torch implementation + points2D_torch, points_cam_torch = project_3D_points( + points3D_torch, extrinsics_torch, intrinsics_torch + ) + + # Convert torch output to numpy + points2D_torch_np = points2D_torch.detach().numpy() + points_cam_torch_np = points_cam_torch.detach().numpy() + + # Compute difference + diff = np.abs(points2D_np - points2D_torch_np) + print("Difference between NumPy and PyTorch implementations:") + print(diff) + + # Check max error + max_diff = np.max(diff) + print(f"Maximum difference: {max_diff}") + + if np.allclose(points2D_np, points2D_torch_np, atol=1e-6): + print("Implementations match closely.") + else: + print("Significant differences detected.") + + if points_cam_np is not None: + points_cam_diff = np.abs(points_cam_np - points_cam_torch_np) + print("Difference between NumPy and PyTorch camera-space coordinates:") + print(points_cam_diff) + + # Check max error + max_cam_diff = np.max(points_cam_diff) + print(f"Maximum camera-space coordinate difference: {max_cam_diff}") + + if np.allclose(points_cam_np, points_cam_torch_np, atol=1e-6): + print("Camera-space coordinates match closely.") + else: + print("Significant differences detected in camera-space coordinates.") diff --git a/mapanything/third_party/track_modules/__init__.py b/mapanything/third_party/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/third_party/track_modules/base_track_predictor.py b/mapanything/third_party/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..4133fc702175d0c4cbcbf996744a54c1ae8c21c5 --- /dev/null +++ b/mapanything/third_party/track_modules/base_track_predictor.py @@ -0,0 +1,212 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +# Modified from https://github.com/facebookresearch/vggt + +import torch +import torch.nn as nn +from einops import rearrange + +from .blocks import CorrBlock, EfficientUpdateFormer +from .utils import get_2d_embedding, get_2d_sincos_pos_embed, sample_features4d + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=4, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + fine=False, + ): + super(BaseTrackerPredictor, self).__init__() + """ + The base template to create a track predictor + + Modified from https://github.com/facebookresearch/co-tracker/ + """ + + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.fine = fine + + self.flows_emb_dim = latent_dim // 2 + self.transformer_dim = ( + self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2 + ) + + if self.fine: + # TODO this is the old dummy code, will remove this when we train next model + self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5 + else: + self.transformer_dim += (4 - self.transformer_dim % 4) % 4 + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential( + nn.Linear(self.latent_dim, self.latent_dim), nn.GELU() + ) + + if not self.fine: + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward( + self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1 + ): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. 
+ note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2 + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + # Construct the correlation block + + fcorr_fn = CorrBlock( + fmaps, num_levels=self.corr_levels, radius=self.corr_radius + ) + + coord_preds = [] + + # Iterative Refinement + for itr in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + # Compute the correlation (check the implementation of CorrBlock) + + fcorr_fn.corr(track_feats) + fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim + + corrdim = fcorrs.shape[3] + + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = torch.cat([flows_emb, flows], dim=-1) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape( + B * N, S, self.latent_dim + ) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + if transformer_input.shape[2] < self.transformer_dim: + # pad the features to match the dimension + pad_dim = self.transformer_dim - transformer_input.shape[2] + pad = torch.zeros_like(flows_emb[..., 0:pad_dim]) + transformer_input = torch.cat([transformer_input, pad], dim=2) + + # 2D positional embed + # TODO: this can be much simplified + pos_embed = get_2d_sincos_pos_embed( + self.transformer_dim, grid_size=(HH, WW) + ).to(query_points.device) + sampled_pos_emb = sample_features4d( + pos_embed.expand(B, -1, -1, -1), coords[:, 0] + ) + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze( + 1 + ) + + x = transformer_input + sampled_pos_emb + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta = self.updateformer(x) + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_ + track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute( + 0, 2, 1, 3 + ) # BxSxNxC + + # B x S x N x 2 + coords = coords + 
delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + if not self.fine: + vis_e = self.vis_predictor( + track_feats.reshape(B * S * N, self.latent_dim) + ).reshape(B, S, N) + vis_e = torch.sigmoid(vis_e) + else: + vis_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat + else: + return coord_preds, vis_e diff --git a/mapanything/third_party/track_modules/blocks.py b/mapanything/third_party/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..72a87809db07a5228b0837377ba8d991625ebcf5 --- /dev/null +++ b/mapanything/third_party/track_modules/blocks.py @@ -0,0 +1,389 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Modified from https://github.com/facebookresearch/vggt + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modules import AttnBlock, CrossAttnBlock, ResidualBlock +from .utils import bilinear_sampler + + +class BasicEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=128, stride=4): + super(BasicEncoder, self).__init__() + + self.stride = stride + self.norm_fn = "instance" + self.in_planes = output_dim // 2 + + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + + self.conv1 = nn.Conv2d( + input_dim, + self.in_planes, + kernel_size=7, + stride=2, + padding=3, + padding_mode="zeros", + ) + self.relu1 = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(output_dim // 2, stride=1) + self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2) + self.layer3 = self._make_layer(output_dim, stride=2) + self.layer4 = self._make_layer(output_dim, stride=2) + + self.conv2 = nn.Conv2d( + output_dim * 3 + output_dim // 4, + output_dim * 2, + kernel_size=3, + padding=1, + padding_mode="zeros", + ) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.InstanceNorm2d)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + a = self.layer1(x) + b = self.layer2(a) + c = self.layer3(b) + d = self.layer4(c) + + a = _bilinear_intepolate(a, self.stride, H, W) + b = _bilinear_intepolate(b, self.stride, H, W) + c = _bilinear_intepolate(c, self.stride, H, W) + d = _bilinear_intepolate(d, self.stride, H, W) + + x = self.conv2(torch.cat([a, b, c, d], dim=1)) + x = self.norm2(x) + x = self.relu2(x) + x = self.conv3(x) + return x + + +class ShallowEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=32, stride=1, 
norm_fn="instance"): + super(ShallowEncoder, self).__init__() + self.stride = stride + self.norm_fn = norm_fn + self.in_planes = output_dim + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) + self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(self.in_planes) + self.norm2 = nn.BatchNorm2d(output_dim * 2) + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d( + input_dim, + self.in_planes, + kernel_size=3, + stride=2, + padding=1, + padding_mode="zeros", + ) + self.relu1 = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(output_dim, stride=2) + + self.layer2 = self._make_layer(output_dim, stride=2) + self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + self.in_planes = dim + + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + return layer1 + + def forward(self, x): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + tmp = self.layer1(x) + x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) + tmp = self.layer2(tmp) + x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) + tmp = None + x = self.conv2(x) + x + + x = F.interpolate( + x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True + ) + + return x + + +def _bilinear_intepolate(x, stride, H, W): + return F.interpolate( + x, (H // stride, W // stride), mode="bilinear", align_corners=True + ) + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. 
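+
+    Input tokens are expected as (B, N, T, input_dim). Temporal attention runs
+    over views of shape (B*N, T, C); when add_space_attn is enabled, spatial
+    attention additionally runs over views of shape (B*T, N, C) that mix the
+    real tracks with learned virtual tracks. The head predicts output_dim
+    values per token (in BaseTrackerPredictor this is latent_dim + 2, i.e.
+    coordinate deltas plus track-feature deltas).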
+ """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter( + torch.randn(1, num_virtual_tracks, 1, hidden_size) + ) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_class=nn.MultiheadAttention, + ) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_class=nn.MultiheadAttention, + ) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [ + CrossAttnBlock( + hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio + ) + for _ in range(space_depth) + ] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [ + CrossAttnBlock( + hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio + ) + for _ in range(space_depth) + ] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + torch.nn.init.trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward(self, input_tensor, mask=None): + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and ( + i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0 + ): + space_tokens = ( + tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) + ) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j]( + virtual_tokens, point_tokens, mask=mask + ) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j]( + point_tokens, virtual_tokens, mask=mask + ) + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute( + 0, 2, 1, 3 + ) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = tokens + init_tokens + + flow = self.flow_head(tokens) + return flow + + 
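+def _single_level_correlation_sketch(fmap, track_feat, coord, radius=4):
+    """
+    Minimal illustrative sketch (not used by this module) of the idea behind
+    CorrBlock below: correlate one track feature against one feature map and
+    read out a (2*radius + 1)^2 window around the track's current location.
+    The helper name and the plain integer cropping (instead of bilinear
+    sampling over a multi-level pyramid) are simplifications assumed here.
+
+    Args:
+        fmap: (C, H, W) feature map of a single frame.
+        track_feat: (C,) feature vector of a single track.
+        coord: (2,) current (x, y) location of the track in feature-map pixels.
+    """
+    C, _, _ = fmap.shape
+    # Dot-product similarity against every location, scaled by sqrt(C) as in corr()
+    corr = torch.einsum("c,chw->hw", track_feat, fmap) / C**0.5
+    x, y = int(coord[0].round()), int(coord[1].round())
+    y0, x0 = max(y - radius, 0), max(x - radius, 0)
+    # Crop the local correlation window; CorrBlock instead samples it bilinearly
+    return corr[y0 : y0 + 2 * radius + 1, x0 : x0 + 2 * radius + 1]
+
+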
+class CorrBlock: + def __init__( + self, + fmaps, + num_levels=4, + radius=4, + multiple_track_feats=False, + padding_mode="zeros", + ): + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.padding_mode = padding_mode + self.num_levels = num_levels + self.radius = radius + self.fmaps_pyramid = [] + self.multiple_track_feats = multiple_track_feats + + self.fmaps_pyramid.append(fmaps) + for i in range(self.num_levels - 1): + fmaps_ = fmaps.reshape(B * S, C, H, W) + fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) + _, _, H, W = fmaps_.shape + fmaps = fmaps_.reshape(B, S, C, H, W) + self.fmaps_pyramid.append(fmaps) + + def sample(self, coords): + r = self.radius + B, S, N, D = coords.shape + assert D == 2 + + H, W = self.H, self.W + out_pyramid = [] + for i in range(self.num_levels): + corrs = self.corrs_pyramid[i] # B, S, N, H, W + *_, H, W = corrs.shape + + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( + coords.device + ) + + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corrs = bilinear_sampler( + corrs.reshape(B * S * N, 1, H, W), + coords_lvl, + padding_mode=self.padding_mode, + ) + corrs = corrs.view(B, S, N, -1) + + out_pyramid.append(corrs) + + out = torch.cat(out_pyramid, dim=-1).contiguous() # B, S, N, LRR*2 + return out + + def corr(self, targets): + B, S, N, C = targets.shape + if self.multiple_track_feats: + targets_split = targets.split(C // self.num_levels, dim=-1) + B, S, N, C = targets_split[0].shape + + assert C == self.C + assert S == self.S + + fmap1 = targets + + self.corrs_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + *_, H, W = fmaps.shape + fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W) + if self.multiple_track_feats: + fmap1 = targets_split[i] + corrs = torch.matmul(fmap1, fmap2s) + corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W + corrs = corrs / torch.sqrt(torch.tensor(C).float()) + self.corrs_pyramid.append(corrs) diff --git a/mapanything/third_party/track_modules/modules.py b/mapanything/third_party/track_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..ee04b9695c407ed915c20c34f0e9480cb0aa0a1e --- /dev/null +++ b/mapanything/third_party/track_modules/modules.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +# Modified from https://github.com/facebookresearch/vggt + + +import collections +from functools import partial +from itertools import repeat +from typing import Callable + +import torch.nn as nn + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, + planes, + kernel_size=kernel_size, + padding=1, + stride=stride, + padding_mode="zeros", + ) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros" + ) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs, + ): + """ + Self attention block + """ + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, 
eps=1e-6) + + self.attn = attn_class( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # Prepare the mask for PyTorch's attention (it expects a different format) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__( + self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs + ): + """ + Cross attention block + """ + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git a/mapanything/third_party/track_modules/track_refine.py b/mapanything/third_party/track_modules/track_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..fab54edb622030ee31314e9a8c9017790b6dd177 --- /dev/null +++ b/mapanything/third_party/track_modules/track_refine.py @@ -0,0 +1,486 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Modified from https://github.com/facebookresearch/vggt + +from typing import Tuple + +import torch +from einops import rearrange + + +def refine_track( + images, + fine_fnet, + fine_tracker, + coarse_pred, + compute_score=False, + pradius=15, + sradius=2, + fine_iters=6, + chunk=40960, +): + """ + Refines the tracking of images using a fine track predictor and a fine feature network. + Check https://arxiv.org/abs/2312.04563 for more details. + + Args: + images (torch.Tensor): The images to be tracked. + fine_fnet (nn.Module): The fine feature network. + fine_tracker (nn.Module): The fine track predictor. + coarse_pred (torch.Tensor): The coarse predictions of tracks. + compute_score (bool, optional): Whether to compute the score. Defaults to False. + pradius (int, optional): The radius of a patch. Defaults to 15. + sradius (int, optional): The search radius. Defaults to 2. + + Returns: + torch.Tensor: The refined tracks. + torch.Tensor, optional: The score. 
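+
+    Shape sketch (illustrative): with images of shape (B, S, 3, H, W) and
+    coarse_pred of shape (B, S, N, 2), one (2*pradius + 1) x (2*pradius + 1)
+    patch is unfolded per track per frame, the fine tracker is re-run on
+    patch-local coordinates, and the refined tracks are returned as
+    (B, S, N, 2) in image coordinates (with frame 0 reset to the query points).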
+ """ + + # coarse_pred shape: BxSxNx2, + # where B is the batch, S is the video/images length, and N is the number of tracks + # now we are going to extract patches with the center at coarse_pred + # Please note that the last dimension indicates x and y, and hence has a dim number of 2 + B, S, N, _ = coarse_pred.shape + _, _, _, H, W = images.shape + + # Given the raidus of a patch, compute the patch size + psize = pradius * 2 + 1 + + # Note that we assume the first frame is the query frame + # so the 2D locations of the first frame are the query points + query_points = coarse_pred[:, 0] + + # Given 2D positions, we can use grid_sample to extract patches + # but it takes too much memory. + # Instead, we use the floored track xy to sample patches. + + # For example, if the query point xy is (128.16, 252.78), + # and the patch size is (31, 31), + # our goal is to extract the content of a rectangle + # with left top: (113.16, 237.78) + # and right bottom: (143.16, 267.78). + # However, we record the floored left top: (113, 237) + # and the offset (0.16, 0.78) + # Then what we need is just unfolding the images like in CNN, + # picking the content at [(113, 237), (143, 267)]. + # Such operations are highly optimized at pytorch + # (well if you really want to use interpolation, check the function extract_glimpse() below) + + with torch.no_grad(): + content_to_extract = images.reshape(B * S, 3, H, W) + C_in = content_to_extract.shape[1] + + # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html + # for the detailed explanation of unfold() + # Here it runs sliding windows (psize x psize) to build patches + # The shape changes from + # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize + # where Psize is the size of patch + content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) + + # Floor the coarse predictions to get integers and save the fractional/decimal + track_int = coarse_pred.floor().int() + track_frac = coarse_pred - track_int + + # Note the points represent the center of patches + # now we get the location of the top left corner of patches + # because the ouput of pytorch unfold are indexed by top left corner + topleft = track_int - pradius + topleft_BSN = topleft.clone() + + # clamp the values so that we will not go out of indexes + # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). 
+ # You need to seperately clamp x and y if H!=W + topleft = topleft.clamp(0, H - psize) + + # Reshape from BxSxNx2 -> (B*S)xNx2 + topleft = topleft.reshape(B * S, N, 2) + + # Prepare batches for indexing, shape: (B*S)xN + batch_indices = ( + torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) + ) + + # extracted_patches: (B*S) x N x C_in x Psize x Psize + extracted_patches = content_to_extract[ + batch_indices, :, topleft[..., 1], topleft[..., 0] + ] + + if chunk < 0: + # Extract image patches based on top left corners + # Feed patches to fine fent for features + patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) + else: + patches = extracted_patches.reshape(B * S * N, C_in, psize, psize) + + patch_feat_list = [] + for p in torch.split(patches, chunk): + patch_feat_list += [fine_fnet(p)] + patch_feat = torch.cat(patch_feat_list, 0) + + C_out = patch_feat.shape[1] + + # Refine the coarse tracks by fine_tracker + # reshape back to B x S x N x C_out x Psize x Psize + patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) + patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") + + # Prepare for the query points for fine tracker + # They are relative to the patch left top corner, + # instead of the image top left corner now + # patch_query_points: N x 1 x 2 + # only 1 here because for each patch we only have 1 query point + patch_query_points = track_frac[:, 0] + pradius + patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) + + # Feed the PATCH query points and tracks into fine tracker + fine_pred_track_lists, _, _, query_point_feat = fine_tracker( + query_points=patch_query_points, + fmaps=patch_feat, + iters=fine_iters, + return_feat=True, + ) + + # relative the patch top left + fine_pred_track = fine_pred_track_lists[-1].clone() + + # From (relative to the patch top left) to (relative to the image top left) + for idx in range(len(fine_pred_track_lists)): + fine_level = rearrange( + fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N + ) + fine_level = fine_level.squeeze(-2) + fine_level = fine_level + topleft_BSN + fine_pred_track_lists[idx] = fine_level + + # relative to the image top left + refined_tracks = fine_pred_track_lists[-1].clone() + refined_tracks[:, 0] = query_points + + score = None + + if compute_score: + score = compute_score_fn( + query_point_feat, + patch_feat, + fine_pred_track, + sradius, + psize, + B, + N, + S, + C_out, + ) + + return refined_tracks, score + + +def refine_track_v0( + images, + fine_fnet, + fine_tracker, + coarse_pred, + compute_score=False, + pradius=15, + sradius=2, + fine_iters=6, +): + """ + COPIED FROM VGGSfM + + Refines the tracking of images using a fine track predictor and a fine feature network. + Check https://arxiv.org/abs/2312.04563 for more details. + + Args: + images (torch.Tensor): The images to be tracked. + fine_fnet (nn.Module): The fine feature network. + fine_tracker (nn.Module): The fine track predictor. + coarse_pred (torch.Tensor): The coarse predictions of tracks. + compute_score (bool, optional): Whether to compute the score. Defaults to False. + pradius (int, optional): The radius of a patch. Defaults to 15. + sradius (int, optional): The search radius. Defaults to 2. + + Returns: + torch.Tensor: The refined tracks. + torch.Tensor, optional: The score. 
+ """ + + # coarse_pred shape: BxSxNx2, + # where B is the batch, S is the video/images length, and N is the number of tracks + # now we are going to extract patches with the center at coarse_pred + # Please note that the last dimension indicates x and y, and hence has a dim number of 2 + B, S, N, _ = coarse_pred.shape + _, _, _, H, W = images.shape + + # Given the raidus of a patch, compute the patch size + psize = pradius * 2 + 1 + + # Note that we assume the first frame is the query frame + # so the 2D locations of the first frame are the query points + query_points = coarse_pred[:, 0] + + # Given 2D positions, we can use grid_sample to extract patches + # but it takes too much memory. + # Instead, we use the floored track xy to sample patches. + + # For example, if the query point xy is (128.16, 252.78), + # and the patch size is (31, 31), + # our goal is to extract the content of a rectangle + # with left top: (113.16, 237.78) + # and right bottom: (143.16, 267.78). + # However, we record the floored left top: (113, 237) + # and the offset (0.16, 0.78) + # Then what we need is just unfolding the images like in CNN, + # picking the content at [(113, 237), (143, 267)]. + # Such operations are highly optimized at pytorch + # (well if you really want to use interpolation, check the function extract_glimpse() below) + + with torch.no_grad(): + content_to_extract = images.reshape(B * S, 3, H, W) + C_in = content_to_extract.shape[1] + + # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html + # for the detailed explanation of unfold() + # Here it runs sliding windows (psize x psize) to build patches + # The shape changes from + # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize + # where Psize is the size of patch + content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) + + # Floor the coarse predictions to get integers and save the fractional/decimal + track_int = coarse_pred.floor().int() + track_frac = coarse_pred - track_int + + # Note the points represent the center of patches + # now we get the location of the top left corner of patches + # because the ouput of pytorch unfold are indexed by top left corner + topleft = track_int - pradius + topleft_BSN = topleft.clone() + + # clamp the values so that we will not go out of indexes + # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). 
+ # You need to seperately clamp x and y if H!=W + topleft = topleft.clamp(0, H - psize) + + # Reshape from BxSxNx2 -> (B*S)xNx2 + topleft = topleft.reshape(B * S, N, 2) + + # Prepare batches for indexing, shape: (B*S)xN + batch_indices = ( + torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) + ) + + # Extract image patches based on top left corners + # extracted_patches: (B*S) x N x C_in x Psize x Psize + extracted_patches = content_to_extract[ + batch_indices, :, topleft[..., 1], topleft[..., 0] + ] + + # Feed patches to fine fent for features + patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) + + C_out = patch_feat.shape[1] + + # Refine the coarse tracks by fine_tracker + + # reshape back to B x S x N x C_out x Psize x Psize + patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) + patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") + + # Prepare for the query points for fine tracker + # They are relative to the patch left top corner, + # instead of the image top left corner now + # patch_query_points: N x 1 x 2 + # only 1 here because for each patch we only have 1 query point + patch_query_points = track_frac[:, 0] + pradius + patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) + + # Feed the PATCH query points and tracks into fine tracker + fine_pred_track_lists, _, _, query_point_feat = fine_tracker( + query_points=patch_query_points, + fmaps=patch_feat, + iters=fine_iters, + return_feat=True, + ) + + # relative the patch top left + fine_pred_track = fine_pred_track_lists[-1].clone() + + # From (relative to the patch top left) to (relative to the image top left) + for idx in range(len(fine_pred_track_lists)): + fine_level = rearrange( + fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N + ) + fine_level = fine_level.squeeze(-2) + fine_level = fine_level + topleft_BSN + fine_pred_track_lists[idx] = fine_level + + # relative to the image top left + refined_tracks = fine_pred_track_lists[-1].clone() + refined_tracks[:, 0] = query_points + + score = None + + if compute_score: + score = compute_score_fn( + query_point_feat, + patch_feat, + fine_pred_track, + sradius, + psize, + B, + N, + S, + C_out, + ) + + return refined_tracks, score + + +################################## NOTE: NOT USED ################################## + + +def compute_score_fn( + query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out +): + """ + Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps, + given the query point features and reference frame feature maps + """ + + from kornia.geometry.subpix import dsnt + from kornia.utils.grid import create_meshgrid + + # query_point_feat initial shape: B x N x C_out, + # query_point_feat indicates the feat at the coorponsing query points + # Therefore we don't have S dimension here + query_point_feat = query_point_feat.reshape(B, N, C_out) + # reshape and expand to B x (S-1) x N x C_out + query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1) + # and reshape to (B*(S-1)*N) x C_out + query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out) + + # Radius and size for computing the score + ssize = sradius * 2 + 1 + + # Reshape, you know it, so many reshaping operations + patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N) + + # Again, we unfold the patches to smaller patches + # so that we can then focus on smaller patches + # patch_feat_unfold shape: + # B x S x N 
x C_out x (psize - 2*sradius) x (psize - 2*sradius) x ssize x ssize + # well a bit scary, but actually not + patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1) + + # Do the same stuffs above, i.e., the same as extracting patches + fine_prediction_floor = fine_pred_track.floor().int() + fine_level_floor_topleft = fine_prediction_floor - sradius + + # Clamp to ensure the smaller patch is valid + fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize) + fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2) + + # Prepare the batch indices and xy locations + + batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN + batch_indices_score = batch_indices_score.reshape(-1).to( + patch_feat_unfold.device + ) # B*S*N + y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices + x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices + + reference_frame_feat = patch_feat_unfold.reshape( + B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize + ) + + # Note again, according to pytorch convention + # x_indices cooresponds to [..., 1] and y_indices cooresponds to [..., 0] + reference_frame_feat = reference_frame_feat[ + batch_indices_score, :, x_indices, y_indices + ] + reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize) + # pick the frames other than the first one, so we have S-1 frames here + reference_frame_feat = reference_frame_feat[:, 1:].reshape( + B * (S - 1) * N, C_out, ssize * ssize + ) + + # Compute similarity + sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat) + softmax_temp = 1.0 / C_out**0.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1) + # 2D heatmaps + heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize + + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] + grid_normalized = create_meshgrid( + ssize, ssize, normalized_coordinates=True, device=heatmap.device + ).reshape(1, -1, 2) + + var = ( + torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1) + - coords_normalized**2 + ) + std = torch.sum( + torch.sqrt(torch.clamp(var, min=1e-10)), -1 + ) # clamp needed for numerical stability + + score = std.reshape(B, S - 1, N) + # set score as 1 for the query frame + score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1) + + return score + + +def extract_glimpse( + tensor: torch.Tensor, + size: Tuple[int, int], + offsets, + mode="bilinear", + padding_mode="zeros", + debug=False, + orib=None, +): + B, C, W, H = tensor.shape + + h, w = size + xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0 + ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0 + + vy, vx = torch.meshgrid(ys, xs) + grid = torch.stack([vx, vy], dim=-1) # h, w, 2 + grid = grid[None] + + B, N, _ = offsets.shape + + offsets = offsets.reshape((B * N), 1, 1, 2) + offsets_grid = offsets + grid + + # normalised grid to [-1, 1] + offsets_grid = ( + offsets_grid - offsets_grid.new_tensor([W / 2, H / 2]) + ) / offsets_grid.new_tensor([W / 2, H / 2]) + + # BxCxHxW -> Bx1xCxHxW + tensor = tensor[:, None] + + # Bx1xCxHxW -> BxNxCxHxW + tensor = tensor.expand(-1, N, -1, -1, -1) + + # BxNxCxHxW -> (B*N)xCxHxW + tensor = tensor.reshape((B * N), C, W, H) + + sampled = torch.nn.functional.grid_sample( + tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode + ) + + # NOTE: I am not sure it 
should be h, w or w, h here + # but okay for sqaures + sampled = sampled.reshape(B, N, C, h, w) + + return sampled diff --git a/mapanything/third_party/track_modules/utils.py b/mapanything/third_party/track_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a69bdd27fabb9279c26ee6bd95fa2b6aa1137390 --- /dev/null +++ b/mapanything/third_party/track_modules/utils.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Modified from https://github.com/facebookresearch/vggt + + +from typing import Tuple, Union + +import torch +import torch.nn.functional as F + + +def get_2d_sincos_pos_embed( + embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False +) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return ( + pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), + grid, + ) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: torch.Tensor +) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: torch.Tensor +) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. 
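+
+    Note: x and y are each embedded into C channels (interleaved sin/cos over
+    C/2 frequencies) and then concatenated, so the returned embedding has
+    2 * C channels (plus the raw xy columns when cat_coords is True).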
+ + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. + """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = ( + torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) + ).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. 
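+
+    Example (illustrative shapes only):
+
+        feats = torch.randn(2, 64, 32, 32)        # (B, C, H, W)
+        pts = torch.rand(2, 100, 1, 2) * 31.0     # (B, H_o, W_o, 2), (x, y) pixels
+        out = bilinear_sampler(feats, pts)        # -> (2, 64, 100, 1)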
+ """ + coords = coords.detach().clone() + ############################################################ + # IMPORTANT: + coords = coords.to(input.device).to(input.dtype) + ############################################################ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + scale = torch.tensor( + [2 / max(size - 1, 1) for size in reversed(sizes)], + device=coords.device, + dtype=coords.dtype, + ) + else: + scale = torch.tensor( + [2 / size for size in reversed(sizes)], + device=coords.device, + dtype=coords.dtype, + ) + + coords.mul_(scale) # coords = coords * scale + coords.sub_(1) # coords = coords - 1 + + return F.grid_sample( + input, coords, align_corners=align_corners, padding_mode=padding_mode + ) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view( + B, -1, feats.shape[1] * feats.shape[3] + ) # B C R 1 -> B R C diff --git a/mapanything/third_party/track_predict.py b/mapanything/third_party/track_predict.py new file mode 100644 index 0000000000000000000000000000000000000000..0054ec0a143452877af498e543966840241d7c58 --- /dev/null +++ b/mapanything/third_party/track_predict.py @@ -0,0 +1,353 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Modified from https://github.com/facebookresearch/vggt + +import numpy as np +import torch + +from .vggsfm_utils import ( + build_vggsfm_tracker, + calculate_index_mappings, + extract_keypoints, + generate_rank_by_dino, + initialize_feature_extractors, + predict_tracks_in_chunks, + switch_tensor_order, +) + + +def predict_tracks( + images, + conf=None, + points_3d=None, + max_query_pts=2048, + query_frame_num=5, + keypoint_extractor="aliked+sp", + max_points_num=163840, + fine_tracking=True, + complete_non_vis=True, +): + """ + Predict tracks for the given images and masks. + + TODO: support non-square images + TODO: support masks + + + This function predicts the tracks for the given images and masks using the specified query method + and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames. + + Args: + images: Tensor of shape [S, 3, H, W] containing the input images. + conf: Tensor of shape [S, 1, H, W] containing the confidence scores. Default is None. + points_3d: Tensor containing 3D points. Default is None. + max_query_pts: Maximum number of query points. Default is 2048. + query_frame_num: Number of query frames to use. Default is 5. 
+ keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp". + max_points_num: Maximum number of points to process at once. Default is 163840. + fine_tracking: Whether to use fine tracking. Default is True. + complete_non_vis: Whether to augment non-visible frames. Default is True. + + Returns: + pred_tracks: Numpy array containing the predicted tracks. + pred_vis_scores: Numpy array containing the visibility scores for the tracks. + pred_confs: Numpy array containing the confidence scores for the tracks. + pred_points_3d: Numpy array containing the 3D points for the tracks. + pred_colors: Numpy array containing the point colors for the tracks. (0, 255) + """ + + device = images.device + dtype = images.dtype + tracker = build_vggsfm_tracker().to(device, dtype) + + # Find query frames + query_frame_indexes = generate_rank_by_dino( + images, query_frame_num=query_frame_num, device=device + ) + + # Add the first image to the front if not already present + if 0 in query_frame_indexes: + query_frame_indexes.remove(0) + query_frame_indexes = [0, *query_frame_indexes] + + # TODO: add the functionality to handle the masks + keypoint_extractors = initialize_feature_extractors( + max_query_pts, extractor_method=keypoint_extractor, device=device + ) + + pred_tracks = [] + pred_vis_scores = [] + pred_confs = [] + pred_points_3d = [] + pred_colors = [] + + fmaps_for_tracker = tracker.process_images_to_fmaps(images) + + if fine_tracking: + print("For faster inference, consider disabling fine_tracking") + + for query_index in query_frame_indexes: + print(f"Predicting tracks for query frame {query_index}") + pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + device, + ) + + pred_tracks.append(pred_track) + pred_vis_scores.append(pred_vis) + pred_confs.append(pred_conf) + pred_points_3d.append(pred_point_3d) + pred_colors.append(pred_color) + + if complete_non_vis: + pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors = ( + _augment_non_visible_frames( + pred_tracks, + pred_vis_scores, + pred_confs, + pred_points_3d, + pred_colors, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + min_vis=500, + non_vis_thresh=0.1, + device=device, + ) + ) + + pred_tracks = np.concatenate(pred_tracks, axis=1) + pred_vis_scores = np.concatenate(pred_vis_scores, axis=1) + pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs else None + pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d else None + pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None + + # from vggt.utils.visual_track import visualize_tracks_on_images + # visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals") + + return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors + + +def _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + device, +): + """ + Process a single query frame for track prediction. 
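+
+    The query frame is temporarily swapped with index 0 (see calculate_index_mappings)
+    so that the tracker treats it as the reference frame; the predicted tracks and
+    visibilities are swapped back to the original frame order before returning.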
+ + Args: + query_index: Index of the query frame + images: Tensor of shape [S, 3, H, W] containing the input images + conf: Confidence tensor + points_3d: 3D points tensor + fmaps_for_tracker: Feature maps for the tracker + keypoint_extractors: Initialized feature extractors + tracker: VGG-SFM tracker + max_points_num: Maximum number of points to process at once + fine_tracking: Whether to use fine tracking + device: Device to use for computation + + Returns: + pred_track: Predicted tracks + pred_vis: Visibility scores for the tracks + pred_conf: Confidence scores for the tracks + pred_point_3d: 3D points for the tracks + pred_color: Point colors for the tracks (0, 255) + """ + frame_num, _, height, width = images.shape + + query_image = images[query_index] + query_points = extract_keypoints( + query_image, keypoint_extractors, round_keypoints=False + ) + query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)] + + # Extract the color at the keypoint locations + query_points_long = query_points.squeeze(0).round().long() + pred_color = images[query_index][ + :, query_points_long[:, 1], query_points_long[:, 0] + ] + pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8) + + # Query the confidence and points_3d at the keypoint locations + if (conf is not None) and (points_3d is not None): + assert height == width + assert conf.shape[-2] == conf.shape[-1] + assert conf.shape[:3] == points_3d.shape[:3] + scale = conf.shape[-1] / width + + query_points_scaled = (query_points.squeeze(0) * scale).round().long() + query_points_scaled = query_points_scaled.cpu().numpy() + + pred_conf = conf[query_index][ + query_points_scaled[:, 1], query_points_scaled[:, 0] + ] + pred_point_3d = points_3d[query_index][ + query_points_scaled[:, 1], query_points_scaled[:, 0] + ] + + # heuristic to remove low confidence points + # should I export this as an input parameter? 
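+        # NOTE: the 1.2 cutoff below is an empirical choice; the filtering is
+        # skipped when it would leave 512 or fewer keypoints.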
+ valid_mask = pred_conf > 1.2 + if valid_mask.sum() > 512: + query_points = query_points[:, valid_mask] # Make sure shape is compatible + pred_conf = pred_conf[valid_mask] + pred_point_3d = pred_point_3d[valid_mask] + pred_color = pred_color[valid_mask] + else: + pred_conf = None + pred_point_3d = None + + reorder_index = calculate_index_mappings(query_index, frame_num, device=device) + + images_feed, fmaps_feed = switch_tensor_order( + [images, fmaps_for_tracker], reorder_index, dim=0 + ) + images_feed = images_feed[None] # add batch dimension + fmaps_feed = fmaps_feed[None] # add batch dimension + + all_points_num = images_feed.shape[1] * query_points.shape[1] + + # Don't need to be scared, this is just chunking to make GPU happy + if all_points_num > max_points_num: + num_splits = (all_points_num + max_points_num - 1) // max_points_num + query_points = torch.chunk(query_points, num_splits, dim=1) + else: + query_points = [query_points] + + pred_track, pred_vis, _ = predict_tracks_in_chunks( + tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking + ) + + pred_track, pred_vis = switch_tensor_order( + [pred_track, pred_vis], reorder_index, dim=1 + ) + + pred_track = pred_track.squeeze(0).float().cpu().numpy() + pred_vis = pred_vis.squeeze(0).float().cpu().numpy() + + return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color + + +def _augment_non_visible_frames( + pred_tracks: list, # ← running list of np.ndarrays + pred_vis_scores: list, # ← running list of np.ndarrays + pred_confs: list, # ← running list of np.ndarrays for confidence scores + pred_points_3d: list, # ← running list of np.ndarrays for 3D points + pred_colors: list, # ← running list of np.ndarrays for colors + images: torch.Tensor, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num: int, + fine_tracking: bool, + *, + min_vis: int = 500, + non_vis_thresh: float = 0.1, + device: torch.device = None, +): + """ + Augment tracking for frames with insufficient visibility. + + Args: + pred_tracks: List of numpy arrays containing predicted tracks. + pred_vis_scores: List of numpy arrays containing visibility scores. + pred_confs: List of numpy arrays containing confidence scores. + pred_points_3d: List of numpy arrays containing 3D points. + pred_colors: List of numpy arrays containing point colors. + images: Tensor of shape [S, 3, H, W] containing the input images. + conf: Tensor of shape [S, 1, H, W] containing confidence scores + points_3d: Tensor containing 3D points + fmaps_for_tracker: Feature maps for the tracker + keypoint_extractors: Initialized feature extractors + tracker: VGG-SFM tracker + max_points_num: Maximum number of points to process at once + fine_tracking: Whether to use fine tracking + min_vis: Minimum visibility threshold + non_vis_thresh: Non-visibility threshold + device: Device to use for computation + + Returns: + Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists. 
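+
+    Note:
+        Frames are re-tracked one at a time. If the same frame remains below the
+        visibility threshold after a retry, a final attempt is made on all
+        remaining non-visible frames at once using the combined "sp+sift+aliked"
+        extractors, after which the loop terminates.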
+    """
+    last_query = -1
+    final_trial = False
+    cur_extractors = keypoint_extractors  # may be replaced on the final trial
+
+    while True:
+        # Visibility per frame
+        vis_array = np.concatenate(pred_vis_scores, axis=1)
+
+        # Count frames with sufficient visibility using numpy
+        sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1)
+        non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist()
+
+        if len(non_vis_frames) == 0:
+            break
+
+        print("Processing non-visible frames:", non_vis_frames)
+
+        # Decide the frames & extractor for this round
+        if non_vis_frames[0] == last_query:
+            # Same frame failed twice - final "all-in" attempt
+            final_trial = True
+            cur_extractors = initialize_feature_extractors(
+                2048, extractor_method="sp+sift+aliked", device=device
+            )
+            query_frame_list = non_vis_frames  # process them all at once
+        else:
+            query_frame_list = [non_vis_frames[0]]  # Process one at a time
+
+        last_query = non_vis_frames[0]
+
+        # Run the tracker for every selected frame
+        for query_index in query_frame_list:
+            new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query(
+                query_index,
+                images,
+                conf,
+                points_3d,
+                fmaps_for_tracker,
+                cur_extractors,
+                tracker,
+                max_points_num,
+                fine_tracking,
+                device,
+            )
+            pred_tracks.append(new_track)
+            pred_vis_scores.append(new_vis)
+            pred_confs.append(new_conf)
+            pred_points_3d.append(new_point_3d)
+            pred_colors.append(new_color)
+
+        if final_trial:
+            break  # Stop after final attempt
+
+    return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors
diff --git a/mapanything/third_party/vggsfm_tracker.py b/mapanything/third_party/vggsfm_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..360aec9a62a1ba44c8b7ddbcc8e8e78ea0a05e97
--- /dev/null
+++ b/mapanything/third_party/vggsfm_tracker.py
@@ -0,0 +1,141 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Modified from https://github.com/facebookresearch/vggt

+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .track_modules.base_track_predictor import BaseTrackerPredictor
+from .track_modules.blocks import BasicEncoder, ShallowEncoder
+from .track_modules.track_refine import refine_track
+
+
+class TrackerPredictor(nn.Module):
+    def __init__(self, **extra_args):
+        """
+        Initializes the tracker predictor.
+
+        Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor;
+        see track_modules/base_track_predictor.py.
+
+        Both coarse_fnet and fine_fnet are constructed as 2D CNN networks;
+        see track_modules/blocks.py for BasicEncoder and ShallowEncoder.
+        """
+        super().__init__()
+        # Define coarse predictor configuration
+        coarse_stride = 4
+        self.coarse_down_ratio = 2
+
+        # Create networks directly instead of using instantiate
+        self.coarse_fnet = BasicEncoder(stride=coarse_stride)
+        self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride)
+
+        # Create fine predictor with stride = 1
+        self.fine_fnet = ShallowEncoder(stride=1)
+        self.fine_predictor = BaseTrackerPredictor(
+            stride=1,
+            depth=4,
+            corr_levels=3,
+            corr_radius=3,
+            latent_dim=32,
+            hidden_size=256,
+            fine=True,
+            use_spaceatt=False,
+        )
+
+    def forward(
+        self,
+        images,
+        query_points,
+        fmaps=None,
+        coarse_iters=6,
+        inference=True,
+        fine_tracking=True,
+        fine_chunk=40960,
+    ):
+        """
+        Args:
+            images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W.
+            query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2.
+            fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None.
+            coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6.
+            inference (bool, optional): Whether to perform inference. Defaults to True.
+            fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True.
+            fine_chunk (int, optional): Chunk size used during fine tracking. Defaults to 40960.
+
+        Returns:
+            tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score.
+        """
+
+        if fmaps is None:
+            batch_num, frame_num, image_dim, height, width = images.shape
+            reshaped_image = images.reshape(
+                batch_num * frame_num, image_dim, height, width
+            )
+            fmaps = self.process_images_to_fmaps(reshaped_image)
+            fmaps = fmaps.reshape(
+                batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1]
+            )
+
+        if inference:
+            torch.cuda.empty_cache()
+
+        # Coarse prediction
+        coarse_pred_track_lists, pred_vis = self.coarse_predictor(
+            query_points=query_points,
+            fmaps=fmaps,
+            iters=coarse_iters,
+            down_ratio=self.coarse_down_ratio,
+        )
+        coarse_pred_track = coarse_pred_track_lists[-1]
+
+        if inference:
+            torch.cuda.empty_cache()
+
+        if fine_tracking:
+            # Refine the coarse prediction
+            fine_pred_track, pred_score = refine_track(
+                images,
+                self.fine_fnet,
+                self.fine_predictor,
+                coarse_pred_track,
+                compute_score=False,
+                chunk=fine_chunk,
+            )
+
+            if inference:
+                torch.cuda.empty_cache()
+        else:
+            fine_pred_track = coarse_pred_track
+            pred_score = torch.ones_like(pred_vis)
+
+        return fine_pred_track, coarse_pred_track, pred_vis, pred_score
+
+    def process_images_to_fmaps(self, images):
+        """
+        Processes images into coarse feature maps for the tracker.
+
+        Args:
+            images (torch.Tensor): The images to be processed with shape S x 3 x H x W.
+
+        Returns:
+            torch.Tensor: The processed feature maps.
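+
+        Note:
+            When `coarse_down_ratio` > 1, the images are bilinearly downsampled by
+            that factor before the coarse feature network to save memory.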
+        """
+        if self.coarse_down_ratio > 1:
+            # Optionally downscale the input images before feature extraction to save memory
+            fmaps = self.coarse_fnet(
+                F.interpolate(
+                    images,
+                    scale_factor=1 / self.coarse_down_ratio,
+                    mode="bilinear",
+                    align_corners=True,
+                )
+            )
+        else:
+            fmaps = self.coarse_fnet(images)
+
+        return fmaps
diff --git a/mapanything/third_party/vggsfm_utils.py b/mapanything/third_party/vggsfm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c598ae40c41fb8ad10e14e401352d6342dbfd06
--- /dev/null
+++ b/mapanything/third_party/vggsfm_utils.py
@@ -0,0 +1,340 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Modified from https://github.com/facebookresearch/vggt
+
+import logging
+import warnings
+
+import torch
+import torch.nn.functional as F
+from lightglue import ALIKED, SIFT, SuperPoint
+
+from .vggsfm_tracker import TrackerPredictor
+
+# Suppress verbose logging from dependencies
+logging.getLogger("dinov2").setLevel(logging.WARNING)
+warnings.filterwarnings("ignore", message="xFormers is available")
+warnings.filterwarnings("ignore", message="dinov2")
+
+# Constants
+_RESNET_MEAN = [0.485, 0.456, 0.406]
+_RESNET_STD = [0.229, 0.224, 0.225]
+
+
+def build_vggsfm_tracker(model_path=None):
+    """
+    Build and initialize the VGGSfM tracker.
+
+    Args:
+        model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace.
+
+    Returns:
+        Initialized tracker model in eval mode.
+    """
+    tracker = TrackerPredictor()
+
+    if model_path is None:
+        default_url = (
+            "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt"
+        )
+        tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url))
+    else:
+        tracker.load_state_dict(torch.load(model_path))
+
+    tracker.eval()
+    return tracker
+
+
+def generate_rank_by_dino(
+    images,
+    query_frame_num,
+    image_size=336,
+    model_name="dinov2_vitb14_reg",
+    device="cuda",
+    spatial_similarity=False,
+):
+    """
+    Generate a ranking of frames using DINO ViT features.
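+
+    The most "central" frame (highest aggregate DINO feature similarity to all
+    other frames) is selected first; the remaining query frames are then chosen
+    by farthest point sampling over the pairwise feature distance matrix.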
+ + Args: + images: Tensor of shape (S, 3, H, W) with values in range [0, 1] + query_frame_num: Number of frames to select + image_size: Size to resize images to before processing + model_name: Name of the DINO model to use + device: Device to run the model on + spatial_similarity: Whether to use spatial token similarity or CLS token similarity + + Returns: + List of frame indices ranked by their representativeness + """ + # Resize images to the target size + images = F.interpolate( + images, (image_size, image_size), mode="bilinear", align_corners=False + ) + + # Load DINO model + dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name) + dino_v2_model.eval() + dino_v2_model = dino_v2_model.to(device) + + # Normalize images using ResNet normalization + resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1) + resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1) + images_resnet_norm = (images - resnet_mean) / resnet_std + + with torch.no_grad(): + frame_feat = dino_v2_model(images_resnet_norm, is_training=True) + + # Process features based on similarity type + if spatial_similarity: + frame_feat = frame_feat["x_norm_patchtokens"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + + # Compute the similarity matrix + frame_feat_norm = frame_feat_norm.permute(1, 0, 2) + similarity_matrix = torch.bmm( + frame_feat_norm, frame_feat_norm.transpose(-1, -2) + ) + similarity_matrix = similarity_matrix.mean(dim=0) + else: + frame_feat = frame_feat["x_norm_clstoken"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) + + distance_matrix = 100 - similarity_matrix.clone() + + # Ignore self-pairing + similarity_matrix.fill_diagonal_(-100) + similarity_sum = similarity_matrix.sum(dim=1) + + # Find the most common frame + most_common_frame_index = torch.argmax(similarity_sum).item() + + # Conduct FPS sampling starting from the most common frame + fps_idx = farthest_point_sampling( + distance_matrix, query_frame_num, most_common_frame_index + ) + + # Clean up all tensors and models to free memory + del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix + del dino_v2_model + torch.cuda.empty_cache() + + return fps_idx + + +def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0): + """ + Farthest point sampling algorithm to select diverse frames. 
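+
+    Note that this greedy variant measures distance to the most recently selected
+    frame only (zeroing out already selected frames), rather than the minimum
+    distance to the full selected set as in classical farthest point sampling.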
+ + Args: + distance_matrix: Matrix of distances between frames + num_samples: Number of frames to select + most_common_frame_index: Index of the first frame to select + + Returns: + List of selected frame indices + """ + distance_matrix = distance_matrix.clamp(min=0) + N = distance_matrix.size(0) + + # Initialize with the most common frame + selected_indices = [most_common_frame_index] + check_distances = distance_matrix[selected_indices] + + while len(selected_indices) < num_samples: + # Find the farthest point from the current set of selected points + farthest_point = torch.argmax(check_distances) + selected_indices.append(farthest_point.item()) + + check_distances = distance_matrix[farthest_point] + # Mark already selected points to avoid selecting them again + check_distances[selected_indices] = 0 + + # Break if all points have been selected + if len(selected_indices) == N: + break + + return selected_indices + + +def calculate_index_mappings(query_index, S, device=None): + """ + Construct an order that switches [query_index] and [0] + so that the content of query_index would be placed at [0]. + + Args: + query_index: Index to swap with 0 + S: Total number of elements + device: Device to place the tensor on + + Returns: + Tensor of indices with the swapped order + """ + new_order = torch.arange(S) + new_order[0] = query_index + new_order[query_index] = 0 + if device is not None: + new_order = new_order.to(device) + return new_order + + +def switch_tensor_order(tensors, order, dim=1): + """ + Reorder tensors along a specific dimension according to the given order. + + Args: + tensors: List of tensors to reorder + order: Tensor of indices specifying the new order + dim: Dimension along which to reorder + + Returns: + List of reordered tensors + """ + return [ + torch.index_select(tensor, dim, order) if tensor is not None else None + for tensor in tensors + ] + + +def initialize_feature_extractors( + max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda" +): + """ + Initialize feature extractors that can be reused based on a method string. + + Args: + max_query_num: Maximum number of keypoints to extract + det_thres: Detection threshold for keypoint extraction + extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift") + device: Device to run extraction on + + Returns: + Dictionary of initialized extractors + """ + extractors = {} + methods = extractor_method.lower().split("+") + + for method in methods: + method = method.strip() + if method == "aliked": + aliked_extractor = ALIKED( + max_num_keypoints=max_query_num, detection_threshold=det_thres + ) + extractors["aliked"] = aliked_extractor.to(device).eval() + elif method == "sp": + sp_extractor = SuperPoint( + max_num_keypoints=max_query_num, detection_threshold=det_thres + ) + extractors["sp"] = sp_extractor.to(device).eval() + elif method == "sift": + sift_extractor = SIFT(max_num_keypoints=max_query_num) + extractors["sift"] = sift_extractor.to(device).eval() + else: + print(f"Warning: Unknown feature extractor '{method}', ignoring.") + + if not extractors: + print( + f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default." 
+        )
+        aliked_extractor = ALIKED(
+            max_num_keypoints=max_query_num, detection_threshold=det_thres
+        )
+        extractors["aliked"] = aliked_extractor.to(device).eval()
+
+    return extractors
+
+
+def extract_keypoints(query_image, extractors, round_keypoints=True):
+    """
+    Extract keypoints using pre-initialized feature extractors.
+
+    Args:
+        query_image: Input image tensor (3xHxW, range [0, 1])
+        extractors: Dictionary of initialized extractors
+        round_keypoints: Whether to round the keypoint coordinates to integer pixel locations. Default: True.
+
+    Returns:
+        Tensor of keypoint coordinates (1xNx2)
+    """
+    query_points = None
+
+    with torch.no_grad():
+        for extractor_name, extractor in extractors.items():
+            query_points_data = extractor.extract(query_image, invalid_mask=None)
+            extractor_points = query_points_data["keypoints"]
+            if round_keypoints:
+                extractor_points = extractor_points.round()
+
+            if query_points is not None:
+                query_points = torch.cat([query_points, extractor_points], dim=1)
+            else:
+                query_points = extractor_points
+
+    return query_points
+
+
+def predict_tracks_in_chunks(
+    track_predictor,
+    images_feed,
+    query_points_list,
+    fmaps_feed,
+    fine_tracking,
+    num_splits=None,
+    fine_chunk=40960,
+):
+    """
+    Process a list of query points to avoid memory issues.
+
+    Args:
+        track_predictor (object): The track predictor object used for predicting tracks.
+        images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images.
+        query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points.
+        fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker.
+        fine_tracking (bool): Whether to perform fine tracking.
+        num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility.
+        fine_chunk (int, optional): Chunk size used during fine tracking. Defaults to 40960.
+
+    Returns:
+        tuple: A tuple containing the concatenated predicted tracks, visibility, and scores.
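+
+    Example (illustrative; `track_predictor`, `images_feed`, and `fmaps_feed` as
+    prepared in `predict_tracks` above)::
+
+        chunks = list(torch.chunk(query_points, 4, dim=1))
+        tracks, vis, scores = predict_tracks_in_chunks(
+            track_predictor, images_feed, chunks, fmaps_feed, fine_tracking=False
+        )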
+ """ + # If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility + if not isinstance(query_points_list, (list, tuple)): + query_points = query_points_list + if num_splits is None: + num_splits = 1 + query_points_list = torch.chunk(query_points, num_splits, dim=1) + + # Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple) + if isinstance(query_points_list, tuple): + query_points_list = list(query_points_list) + + fine_pred_track_list = [] + pred_vis_list = [] + pred_score_list = [] + + for split_points in query_points_list: + # Feed into track predictor for each split + fine_pred_track, _, pred_vis, pred_score = track_predictor( + images_feed, + split_points, + fmaps=fmaps_feed, + fine_tracking=fine_tracking, + fine_chunk=fine_chunk, + ) + fine_pred_track_list.append(fine_pred_track) + pred_vis_list.append(pred_vis) + pred_score_list.append(pred_score) + + # Concatenate the results from all splits + fine_pred_track = torch.cat(fine_pred_track_list, dim=2) + pred_vis = torch.cat(pred_vis_list, dim=2) + + if pred_score is not None: + pred_score = torch.cat(pred_score_list, dim=2) + else: + pred_score = None + + return fine_pred_track, pred_vis, pred_score diff --git a/mapanything/train/__init__.py b/mapanything/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/train/losses.py b/mapanything/train/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..b5203bc7531e8348b4de2fbffc7286c04b643861 --- /dev/null +++ b/mapanything/train/losses.py @@ -0,0 +1,5068 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Multi-view geometric losses for training 3D reconstruction models. + +References: DUSt3R & MASt3R +""" + +import math +from copy import copy, deepcopy + +import einops as ein +import torch +import torch.nn as nn + +from mapanything.utils.geometry import ( + angle_diff_vec3, + apply_log_to_norm, + closed_form_pose_inverse, + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, + geotrf, + normalize_multiple_pointclouds, + quaternion_inverse, + quaternion_multiply, + quaternion_to_rotation_matrix, + transform_pose_using_quats_and_trans_2_to_1, +) + + +def get_loss_terms_and_details( + losses_dict, valid_masks, self_name, n_views, flatten_across_image_only +): + """ + Helper function to generate loss terms and details for different loss types. + + Args: + losses_dict (dict): Dictionary mapping loss types to their values. + Format: { + 'loss_type': { + 'values': list_of_loss_tensors or single_tensor, + 'use_mask': bool, + 'is_multi_view': bool + } + } + valid_masks (list): List of valid masks for each view. + self_name (str): Name of the loss class. + n_views (int): Number of views. + flatten_across_image_only (bool): Whether flattening was done across image only. + + Returns: + tuple: (loss_terms, details) where loss_terms is a list of tuples (loss, mask, type) + and details is a dictionary of loss details. 
+ """ + loss_terms = [] + details = {} + + for loss_type, loss_info in losses_dict.items(): + values = loss_info["values"] + use_mask = loss_info["use_mask"] + is_multi_view = loss_info["is_multi_view"] + if is_multi_view: + # Handle multi-view losses (list of tensors) + view_loss_details = [] + for i in range(n_views): + mask = valid_masks[i] if use_mask else None + loss_terms.append((values[i], mask, loss_type)) + + # Add details for individual view + if not flatten_across_image_only or not use_mask: + values_after_masking = values[i] + else: + values_after_masking = values[i][mask] + + if values_after_masking.numel() > 0: + view_loss_detail = float(values_after_masking.mean()) + if view_loss_detail > 0: + details[f"{self_name}_{loss_type}_view{i + 1}"] = ( + view_loss_detail + ) + view_loss_details.append(view_loss_detail) + # Add average across views + if len(view_loss_details) > 0: + details[f"{self_name}_{loss_type}_avg"] = sum(view_loss_details) / len( + view_loss_details + ) + else: + # Handle single tensor losses + if values is not None: + loss_terms.append((values, None, loss_type)) + if values.numel() > 0: + loss_detail = float(values.mean()) + if loss_detail > 0: + details[f"{self_name}_{loss_type}"] = loss_detail + + return loss_terms, details + + +def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor: + if beta == 0: + return err + else: + return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta) + + +def compute_normal_loss(points, gt_points, mask): + """ + Compute the normal loss between the predicted and ground truth points. + References: + https://github.com/microsoft/MoGe/blob/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/moge/train/losses.py#L205 + + Args: + points (torch.Tensor): Predicted points. Shape: (..., H, W, 3). + gt_points (torch.Tensor): Ground truth points. Shape: (..., H, W, 3). + mask (torch.Tensor): Mask indicating valid points. Shape: (..., H, W). + + Returns: + torch.Tensor: Normal loss. 
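+
+    Note:
+        Normals are estimated via cross products over each 2x2 pixel quad (four
+        edge pairings per quad). The angular error w.r.t. the corresponding GT
+        normals is clamped to [1, 90] degrees and penalized with a smooth-L1 of
+        beta = 3 degrees.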
+ """ + height, width = points.shape[-3:-1] + + leftup, rightup, leftdown, rightdown = ( + points[..., :-1, :-1, :], + points[..., :-1, 1:, :], + points[..., 1:, :-1, :], + points[..., 1:, 1:, :], + ) + upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1) + leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1) + downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1) + rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1) + + gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = ( + gt_points[..., :-1, :-1, :], + gt_points[..., :-1, 1:, :], + gt_points[..., 1:, :-1, :], + gt_points[..., 1:, 1:, :], + ) + gt_upxleft = torch.cross( + gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1 + ) + gt_leftxdown = torch.cross( + gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1 + ) + gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1) + gt_rightxup = torch.cross( + gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1 + ) + + mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = ( + mask[..., :-1, :-1], + mask[..., :-1, 1:], + mask[..., 1:, :-1], + mask[..., 1:, 1:], + ) + mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown + mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup + mask_downxright = mask_leftdown & mask_rightup & mask_leftup + mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown + + MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3) + + loss = ( + mask_upxleft + * _smooth( + angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE), + beta=BETA_RAD, + ) + + mask_leftxdown + * _smooth( + angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE), + beta=BETA_RAD, + ) + + mask_downxright + * _smooth( + angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE), + beta=BETA_RAD, + ) + + mask_rightxup + * _smooth( + angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE), + beta=BETA_RAD, + ) + ) + + total_valid_mask = mask_upxleft | mask_leftxdown | mask_downxright | mask_rightxup + valid_count = total_valid_mask.sum() + if valid_count > 0: + loss = loss.sum() / (valid_count * (4 * max(points.shape[-3:-1]))) + else: + loss = 0 * loss.sum() + + return loss + + +def compute_gradient_loss(prediction, gt_target, mask): + """ + Compute the gradient loss between the prediction and GT target at valid points. + References: + https://docs.nerf.studio/_modules/nerfstudio/model_components/losses.html#GradientLoss + https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py + + Args: + prediction (torch.Tensor): Predicted scene representation. Shape: (B, H, W, C). + gt_target (torch.Tensor): Ground truth scene representation. Shape: (B, H, W, C). + mask (torch.Tensor): Mask indicating valid points. Shape: (B, H, W). 
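+
+    Returns:
+        torch.Tensor: Scalar gradient loss, normalized by the total number of valid entries.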
+    """
+    # Expand mask to match number of channels in prediction
+    mask = mask[..., None].expand(-1, -1, -1, prediction.shape[-1])
+    summed_mask = torch.sum(mask, (1, 2, 3))
+
+    # Compute the gradient of the prediction and GT target
+    diff = prediction - gt_target
+    diff = torch.mul(mask, diff)
+
+    # Gradient in x direction
+    grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
+    mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
+    grad_x = torch.mul(mask_x, grad_x)
+
+    # Gradient in y direction
+    grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
+    mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
+    grad_y = torch.mul(mask_y, grad_y)
+
+    # Clamp the outlier gradients
+    grad_x = grad_x.clamp(max=100)
+    grad_y = grad_y.clamp(max=100)
+
+    # Compute the total loss
+    image_loss = torch.sum(grad_x, (1, 2, 3)) + torch.sum(grad_y, (1, 2, 3))
+    num_valid_pixels = torch.sum(summed_mask)
+    if num_valid_pixels > 0:
+        image_loss = torch.sum(image_loss) / num_valid_pixels
+    else:
+        image_loss = 0 * torch.sum(image_loss)
+
+    return image_loss
+
+
+def compute_gradient_matching_loss(prediction, gt_target, mask, scales=4):
+    """
+    Compute the multi-scale gradient matching loss between the prediction and GT target at valid points.
+    This loss biases discontinuities to be sharp and to coincide with discontinuities in the ground truth.
+    More info in MiDaS: https://arxiv.org/pdf/1907.01341.pdf; Equation 11
+    References:
+        https://docs.nerf.studio/_modules/nerfstudio/model_components/losses.html#GradientLoss
+        https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py
+
+    Args:
+        prediction (torch.Tensor): Predicted scene representation. Shape: (B, H, W, C).
+        gt_target (torch.Tensor): Ground truth scene representation. Shape: (B, H, W, C).
+        mask (torch.Tensor): Mask indicating valid points. Shape: (B, H, W).
+        scales (int): Number of scales to compute the loss at. Default: 4.
+
+    Returns:
+        torch.Tensor: Scalar multi-scale gradient matching loss.
+    """
+    # Define total loss
+    total_loss = 0.0
+
+    # Compute the gradient loss at different scales
+    for scale in range(scales):
+        step = pow(2, scale)
+        grad_loss = compute_gradient_loss(
+            prediction[:, ::step, ::step],
+            gt_target[:, ::step, ::step],
+            mask[:, ::step, ::step],
+        )
+        total_loss += grad_loss
+
+    return total_loss
+
+
+def Sum(*losses_and_masks):
+    """
+    Aggregates multiple losses into a single loss value or returns the original losses.
+
+    Args:
+        *losses_and_masks: Variable number of tuples, each containing (loss, mask, rep_type)
+            - loss: Tensor containing loss values
+            - mask: Mask indicating valid pixels/regions
+            - rep_type: String indicating the type of representation (e.g., 'pts3d', 'depth')
+
+    Returns:
+        If the first loss has dimensions > 0:
+            Returns the original list of (loss, mask, rep_type) tuples
+        Otherwise:
+            Returns a scalar tensor that is the sum of all loss values
+    """
+    loss, mask, rep_type = losses_and_masks[0]
+    if loss.ndim > 0:
+        # we are actually returning the per-pixel losses
+        return losses_and_masks
+    else:
+        # we are returning the global loss
+        for loss2, mask2, rep_type2 in losses_and_masks[1:]:
+            loss = loss + loss2
+        return loss
+
+
+class BaseCriterion(nn.Module):
+    "Base Criterion to support different reduction methods"
+
+    def __init__(self, reduction="mean"):
+        super().__init__()
+        self.reduction = reduction
+
+
+class LLoss(BaseCriterion):
+    "L-norm loss"
+
+    def forward(self, a, b, **kwargs):
+        assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 4, (
+            f"Bad shape = {a.shape}"
+        )
+        dist = self.distance(a, b, **kwargs)
+        assert dist.ndim == a.ndim - 1  # one dimension less
+        if self.reduction == "none":
+            return dist
+        if self.reduction == "sum":
+            return dist.sum()
+        if self.reduction == "mean":
+            return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
+        raise ValueError(f"bad {self.reduction=} mode")
+
+    def distance(self, a, b, **kwargs):
+        raise NotImplementedError()
+
+
+class L1Loss(LLoss):
+    "L1 distance"
+
+    def distance(self, a, b, **kwargs):
+        return torch.abs(a - b).sum(dim=-1)
+
+
+class L2Loss(LLoss):
+    "Euclidean (L2 Norm) distance"
+
+    def distance(self, a, b, **kwargs):
+        return torch.norm(a - b, dim=-1)
+
+
+class GenericLLoss(LLoss):
+    "Criterion that supports different L-norms"
+
+    def distance(self, a, b, loss_type, **kwargs):
+        if loss_type == "l1":
+            # L1 distance
+            return torch.abs(a - b).sum(dim=-1)
+        elif loss_type == "l2":
+            # Euclidean (L2 norm) distance
+            return torch.norm(a - b, dim=-1)
+        else:
+            raise ValueError(
+                f"Unsupported loss type: {loss_type}. Supported types are 'l1' and 'l2'."
+            )
+
+
+class FactoredLLoss(LLoss):
+    "Criterion that supports different L-norms for the factored loss functions"
+
+    def __init__(
+        self,
+        reduction="mean",
+        points_loss_type="l2",
+        depth_loss_type="l1",
+        ray_directions_loss_type="l1",
+        pose_quats_loss_type="l1",
+        pose_trans_loss_type="l1",
+        scale_loss_type="l1",
+    ):
+        super().__init__(reduction)
+        self.points_loss_type = points_loss_type
+        self.depth_loss_type = depth_loss_type
+        self.ray_directions_loss_type = ray_directions_loss_type
+        self.pose_quats_loss_type = pose_quats_loss_type
+        self.pose_trans_loss_type = pose_trans_loss_type
+        self.scale_loss_type = scale_loss_type
+
+    def _distance(self, a, b, loss_type):
+        if loss_type == "l1":
+            # L1 distance
+            return torch.abs(a - b).sum(dim=-1)
+        elif loss_type == "l2":
+            # Euclidean (L2 norm) distance
+            return torch.norm(a - b, dim=-1)
+        else:
+            raise ValueError(f"Unsupported loss type: {loss_type}.")
+
+    def distance(self, a, b, factor, **kwargs):
+        if factor == "points":
+            return self._distance(a, b, self.points_loss_type)
+        elif factor == "depth":
+            return self._distance(a, b, self.depth_loss_type)
+        elif factor == "ray_directions":
+            return self._distance(a, b, self.ray_directions_loss_type)
+        elif factor == "pose_quats":
+            return self._distance(a, b, self.pose_quats_loss_type)
+        elif factor == "pose_trans":
+            return self._distance(a, b, self.pose_trans_loss_type)
+        elif factor == "scale":
+            return self._distance(a, b, self.scale_loss_type)
+        else:
+            raise ValueError(f"Unsupported factor type: {factor}.")
+
+
+class RobustRegressionLoss(LLoss):
+    """
+    Generalized Robust Loss introduced in https://arxiv.org/abs/1701.03077.
+    """
+
+    def __init__(self, alpha=0.5, scaling_c=0.25, reduction="mean"):
+        """
+        Initialize the Robust Regression Loss.
+
+        Args:
+            alpha (float): Shape parameter controlling the robustness of the loss.
+                Lower values make the loss more robust to outliers. Default: 0.5.
+            scaling_c (float): Scale parameter controlling the transition between
+                quadratic and robust behavior. Default: 0.25.
+            reduction (str): Specifies the reduction to apply to the output:
+                'none' | 'mean' | 'sum'. Default: 'mean'.
+        """
+        super().__init__(reduction)
+        self.alpha = alpha
+        self.scaling_c = scaling_c
+
+    def distance(self, a, b, **kwargs):
+        error_scaled = torch.sum(((a - b) / self.scaling_c) ** 2, dim=-1)
+        robust_loss = (abs(self.alpha - 2) / self.alpha) * (
+            torch.pow((error_scaled / abs(self.alpha - 2)) + 1, self.alpha / 2) - 1
+        )
+        return robust_loss
+
+
+class BCELoss(BaseCriterion):
+    """Binary Cross Entropy loss"""
+
+    def forward(self, predicted_logits, reference_mask):
+        """
+        Args:
+            predicted_logits: (B, H, W) tensor of predicted logits for the mask
+            reference_mask: (B, H, W) tensor of reference mask
+
+        Returns:
+            loss: scalar tensor of the BCE loss
+        """
+        bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+            predicted_logits, reference_mask.float()
+        )
+
+        return bce_loss
+
+
+class Criterion(nn.Module):
+    """
+    Base class for all criterion modules that wrap a BaseCriterion.
+
+    This class serves as a wrapper around BaseCriterion objects, providing
+    additional functionality like naming and reduction mode control.
+
+    Args:
+        criterion (BaseCriterion): The base criterion to wrap.
+    """
+
+    def __init__(self, criterion=None):
+        super().__init__()
+        assert isinstance(criterion, BaseCriterion), (
+            f"{criterion} is not a proper criterion!"
+ ) + self.criterion = copy(criterion) + + def get_name(self): + """ + Returns a string representation of this criterion. + + Returns: + str: A string containing the class name and the wrapped criterion. + """ + return f"{type(self).__name__}({self.criterion})" + + def with_reduction(self, mode="none"): + """ + Creates a deep copy of this criterion with the specified reduction mode. + + This method recursively sets the reduction mode for this criterion and + any chained MultiLoss criteria. + + Args: + mode (str): The reduction mode to set. Default: "none". + + Returns: + Criterion: A new criterion with the specified reduction mode. + """ + res = loss = deepcopy(self) + while loss is not None: + assert isinstance(loss, Criterion) + loss.criterion.reduction = mode # make it return the loss for each sample + loss = loss._loss2 # we assume loss is a Multiloss + return res + + +class MultiLoss(nn.Module): + """ + Base class for combinable loss functions with automatic tracking of individual loss values. + + This class enables easy combination of multiple loss functions through arithmetic operations: + loss = MyLoss1() + 0.1*MyLoss2() + + The combined loss functions maintain their individual weights and the forward pass + automatically computes and aggregates all losses while tracking individual loss values. + + Usage: + Inherit from this class and override get_name() and compute_loss() methods. + + Attributes: + _alpha (float): Weight multiplier for this loss component. + _loss2 (MultiLoss): Reference to the next loss in the chain, if any. + """ + + def __init__(self): + """Initialize the MultiLoss with default weight of 1 and no chained loss.""" + super().__init__() + self._alpha = 1 + self._loss2 = None + + def compute_loss(self, *args, **kwargs): + """ + Compute the loss value for this specific loss component. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + torch.Tensor or tuple: Either the loss tensor or a tuple of (loss, details_dict). + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError() + + def get_name(self): + """ + Get the name of this loss component. + + Returns: + str: The name of the loss. + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError() + + def __mul__(self, alpha): + """ + Multiply the loss by a scalar weight. + + Args: + alpha (int or float): The weight to multiply the loss by. + + Returns: + MultiLoss: A new loss object with the updated weight. + + Raises: + AssertionError: If alpha is not a number. + """ + assert isinstance(alpha, (int, float)) + res = copy(self) + res._alpha = alpha + return res + + __rmul__ = __mul__ # Support both loss*alpha and alpha*loss + + def __add__(self, loss2): + """ + Add another loss to this loss, creating a chain of losses. + + Args: + loss2 (MultiLoss): Another loss to add to this one. + + Returns: + MultiLoss: A new loss object representing the combined losses. + + Raises: + AssertionError: If loss2 is not a MultiLoss. + """ + assert isinstance(loss2, MultiLoss) + res = cur = copy(self) + # Find the end of the chain + while cur._loss2 is not None: + cur = cur._loss2 + cur._loss2 = loss2 + return res + + def __repr__(self): + """ + Create a string representation of the loss, including weights and chained losses. + + Returns: + str: String representation of the loss. 
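+
+        Example (illustrative; assumes two MultiLoss subclasses whose get_name()
+        returns their class name)::
+
+            repr(MyLoss1() + 0.1 * MyLoss2())  # "MyLoss1 + 0.1*MyLoss2"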
+        """
+        name = self.get_name()
+        if self._alpha != 1:
+            name = f"{self._alpha:g}*{name}"
+        if self._loss2:
+            name = f"{name} + {self._loss2}"
+        return name
+
+    def forward(self, *args, **kwargs):
+        """
+        Compute the weighted loss and aggregate with any chained losses.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            tuple: A tuple containing:
+                - torch.Tensor: The total weighted loss.
+                - dict: Details about individual loss components.
+        """
+        loss = self.compute_loss(*args, **kwargs)
+        if isinstance(loss, tuple):
+            loss, details = loss
+        elif loss.ndim == 0:
+            details = {self.get_name(): float(loss)}
+        else:
+            details = {}
+        loss = loss * self._alpha
+
+        if self._loss2:
+            loss2, details2 = self._loss2(*args, **kwargs)
+            loss = loss + loss2
+            details |= details2
+
+        return loss, details
+
+
+class NonAmbiguousMaskLoss(Criterion, MultiLoss):
+    """
+    Loss on non-ambiguous mask prediction logits.
+    """
+
+    def __init__(self, criterion):
+        super().__init__(criterion)
+
+    def compute_loss(self, batch, preds, **kw):
+        """
+        Args:
+            batch: list of dicts with the gt data
+            preds: list of dicts with the predictions
+
+        Returns:
+            loss: Sum class of the losses for N views and the loss details
+        """
+        # Init loss list to keep track of individual losses for each view
+        loss_list = []
+        mask_loss_details = {}
+        mask_loss_total = 0
+        self_name = type(self).__name__
+
+        # Loop over the views
+        for view_idx, (gt, pred) in enumerate(zip(batch, preds)):
+            # Get the GT non-ambiguous masks
+            gt_non_ambiguous_mask = gt["non_ambiguous_mask"]
+
+            # Get the predicted non-ambiguous mask logits
+            pred_non_ambiguous_mask_logits = pred["non_ambiguous_mask_logits"]
+
+            # Compute the loss for the current view
+            loss = self.criterion(pred_non_ambiguous_mask_logits, gt_non_ambiguous_mask)
+
+            # Add the loss to the list
+            loss_list.append((loss, None, "non_ambiguous_mask"))
+
+            # Add the loss details to the dictionary
+            mask_loss_details[f"{self_name}_mask_view{view_idx + 1}"] = float(loss)
+            mask_loss_total += float(loss)
+
+        # Compute the average loss across all views
+        mask_loss_details[f"{self_name}_mask_avg"] = mask_loss_total / len(batch)
+
+        return Sum(*loss_list), mask_loss_details
+
+
+class ConfLoss(MultiLoss):
+    """
+    Applies confidence-weighted regression loss using model-predicted confidence values.
+
+    The confidence-weighted loss has the form:
+        conf_loss = raw_loss * conf - alpha * log(conf)
+
+    Where:
+    - raw_loss is the original per-pixel loss
+    - conf is the predicted confidence (higher values = higher confidence)
+    - alpha is a hyperparameter controlling the regularization strength
+
+    This loss can be selectively applied to specific loss components in factored and multi-view settings.
+    """
+
+    def __init__(self, pixel_loss, alpha=1, loss_set_indices=None):
+        """
+        Args:
+            pixel_loss (MultiLoss): The pixel-level regression loss to be used.
+            alpha (float): Hyperparameter controlling the confidence regularization strength.
+            loss_set_indices (list or None): Indices of the loss sets to apply confidence weighting to.
+                Each index selects a specific loss set across all views (with the same rep_type).
+                If None, defaults to [0] which applies to the first loss set only.
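+
+        Note:
+            For a pixel with raw loss r, the per-pixel optimum of
+            r * conf - alpha * log(conf) is attained at conf = alpha / r, so the
+            model is rewarded for assigning high confidence exactly where the raw
+            loss is low.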
+ """ + super().__init__() + assert alpha > 0 + self.alpha = alpha + self.pixel_loss = pixel_loss.with_reduction("none") + self.loss_set_indices = [0] if loss_set_indices is None else loss_set_indices + + def get_name(self): + return f"ConfLoss({self.pixel_loss})" + + def get_conf_log(self, x): + return x, torch.log(x) + + def compute_loss(self, batch, preds, **kw): + # Init loss list and details + total_loss = 0 + conf_loss_details = {} + running_avg_dict = {} + self_name = type(self.pixel_loss).__name__ + n_views = len(batch) + + # Compute per-pixel loss for each view + losses, pixel_loss_details = self.pixel_loss(batch, preds, **kw) + + # Select specific loss sets based on indices + selected_losses = [] + processed_indices = set() + for idx in self.loss_set_indices: + start_idx = idx * n_views + end_idx = min((idx + 1) * n_views, len(losses)) + selected_losses.extend(losses[start_idx:end_idx]) + processed_indices.update(range(start_idx, end_idx)) + + # Process selected losses with confidence weighting + for loss_idx, (loss, msk, rep_type) in enumerate(selected_losses): + view_idx = loss_idx % n_views # Map to corresponding view index + + if loss.numel() == 0: + # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True) + continue + + # Get the confidence and log confidence + if ( + hasattr(self.pixel_loss, "flatten_across_image_only") + and self.pixel_loss.flatten_across_image_only + ): + # Reshape confidence to match the flattened dimensions + conf_reshaped = preds[view_idx]["conf"].view( + preds[view_idx]["conf"].shape[0], -1 + ) + conf, log_conf = self.get_conf_log(conf_reshaped[msk]) + loss = loss[msk] + else: + conf, log_conf = self.get_conf_log(preds[view_idx]["conf"][msk]) + + # Weight the loss by the confidence + conf_loss = loss * conf - self.alpha * log_conf + + # Only add to total loss and store details if there are valid elements + if conf_loss.numel() > 0: + conf_loss = conf_loss.mean() + total_loss = total_loss + conf_loss + + # Store details + conf_loss_details[ + f"{self_name}_{rep_type}_conf_loss_view{view_idx + 1}" + ] = float(conf_loss) + + # Initialize or update running average directly + avg_key = f"{self_name}_{rep_type}_conf_loss_avg" + if avg_key not in conf_loss_details: + conf_loss_details[avg_key] = float(conf_loss) + running_avg_dict[ + f"{self_name}_{rep_type}_conf_loss_valid_views" + ] = 1 + else: + valid_views = ( + running_avg_dict[ + f"{self_name}_{rep_type}_conf_loss_valid_views" + ] + + 1 + ) + running_avg_dict[ + f"{self_name}_{rep_type}_conf_loss_valid_views" + ] = valid_views + conf_loss_details[avg_key] += ( + float(conf_loss) - conf_loss_details[avg_key] + ) / valid_views + + # Add unmodified losses for sets not in selected_losses + for idx, (loss, msk, rep_type) in enumerate(losses): + if idx not in processed_indices: + if msk is not None: + loss_after_masking = loss[msk] + else: + loss_after_masking = loss + if loss_after_masking.numel() > 0: + loss_mean = loss_after_masking.mean() + else: + # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True) + loss_mean = 0 + total_loss = total_loss + loss_mean + + return total_loss, dict(**conf_loss_details, **pixel_loss_details) + + +class ExcludeTopNPercentPixelLoss(MultiLoss): + """ + Pixel-level regression loss where for each instance in a batch the top N% of per-pixel loss values are ignored + for the mean loss computation. + Allows selecting which pixel-level regression loss sets to apply the exclusion to. 
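+
+    This is intended to make training robust to annotation noise (e.g., depth
+    sensor outliers): by default the exclusion is applied to real-world data only,
+    since synthetic ground truth is treated as noise-free.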
+    """
+
+    def __init__(
+        self,
+        pixel_loss,
+        top_n_percent=5,
+        apply_to_real_data_only=True,
+        loss_set_indices=None,
+    ):
+        """
+        Args:
+            pixel_loss (MultiLoss): The pixel-level regression loss to be used.
+            top_n_percent (float): The percentage of top per-pixel loss values to ignore. Range: [0, 100]. Default: 5.
+            apply_to_real_data_only (bool): Whether to apply the loss only to real world data. Default: True.
+            loss_set_indices (list or None): Indices of the loss sets to apply the exclusion to.
+                Each index selects a specific loss set across all views (with the same rep_type).
+                If None, defaults to [0] which applies to the first loss set only.
+        """
+        super().__init__()
+        self.pixel_loss = pixel_loss.with_reduction("none")
+        self.top_n_percent = top_n_percent
+        self.bottom_n_percent = 100 - top_n_percent
+        self.apply_to_real_data_only = apply_to_real_data_only
+        self.loss_set_indices = [0] if loss_set_indices is None else loss_set_indices
+
+    def get_name(self):
+        return f"ExcludeTopNPercentPixelLoss({self.pixel_loss})"
+
+    def keep_bottom_n_percent(self, tensor, mask, bottom_n_percent):
+        """
+        Keep the bottom n percent of valid per-pixel loss values.
+
+        Args:
+            tensor (torch.Tensor): The tensor containing the per-pixel loss values.
+                Shape: (B, N) where B is the batch size and N is the number of total pixels.
+            mask (torch.Tensor): The mask indicating valid pixels. Shape: (B, N).
+            bottom_n_percent (float): The percentage of lowest valid loss values to keep. Range: [0, 100].
+
+        Returns:
+            torch.Tensor: Flattened tensor containing the bottom n percent of per-pixel loss values.
+        """
+        B, N = tensor.shape
+
+        # Calculate the number of valid elements (where mask is True)
+        num_valid = mask.sum(dim=1)
+
+        # Calculate the number of elements to keep (n% of valid elements)
+        num_keep = (num_valid * bottom_n_percent / 100).long()
+
+        # Create a mask for the bottom n% elements
+        keep_mask = torch.arange(N, device=tensor.device).unsqueeze(
+            0
+        ) < num_keep.unsqueeze(1)
+
+        # Create a tensor with inf where mask is False
+        masked_tensor = torch.where(
+            mask, tensor, torch.tensor(float("inf"), device=tensor.device)
+        )
+
+        # Sort the masked tensor along the N dimension
+        sorted_tensor, _ = torch.sort(masked_tensor, dim=1, descending=False)
+
+        # Get the bottom n% elements
+        bottom_n_percent_elements = sorted_tensor[keep_mask]
+
+        return bottom_n_percent_elements
+
+    def compute_loss(self, batch, preds, **kw):
+        # Compute per-pixel loss
+        losses, details = self.pixel_loss(batch, preds, **kw)
+        n_views = len(batch)
+
+        # Select specific loss sets based on indices
+        selected_losses = []
+        processed_indices = set()
+        for idx in self.loss_set_indices:
+            start_idx = idx * n_views
+            end_idx = min((idx + 1) * n_views, len(losses))
+            selected_losses.extend(losses[start_idx:end_idx])
+            processed_indices.update(range(start_idx, end_idx))
+
+        # Initialize total loss
+        total_loss = 0.0
+        loss_details = {}
+        running_avg_dict = {}
+        self_name = type(self.pixel_loss).__name__
+
+        # Process selected losses with top N percent exclusion
+        for loss_idx, (loss, msk, rep_type) in enumerate(selected_losses):
+            view_idx = loss_idx % n_views  # Map to corresponding view index
+
+            if loss.numel() == 0:
+                # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
+                continue
+
+            # Create empty list for current view's aggregated tensors
+            aggregated_losses = []
+
+            if self.apply_to_real_data_only:
+                # Get the synthetic and real world data mask
+                synthetic_mask = batch[view_idx]["is_synthetic"]
+                real_data_mask = 
~batch[view_idx]["is_synthetic"] + else: + # Apply the filtering to all data + synthetic_mask = torch.zeros_like(batch[view_idx]["is_synthetic"]) + real_data_mask = torch.ones_like(batch[view_idx]["is_synthetic"]) + + # Process synthetic data + if synthetic_mask.any(): + synthetic_loss = loss[synthetic_mask] + synthetic_msk = msk[synthetic_mask] + aggregated_losses.append(synthetic_loss[synthetic_msk]) + + # Process real data + if real_data_mask.any(): + real_loss = loss[real_data_mask] + real_msk = msk[real_data_mask] + real_bottom_n_percent_loss = self.keep_bottom_n_percent( + real_loss, real_msk, self.bottom_n_percent + ) + aggregated_losses.append(real_bottom_n_percent_loss) + + # Compute view loss + view_loss = torch.cat(aggregated_losses, dim=0) + + # Only add to total loss and store details if there are valid elements + if view_loss.numel() > 0: + view_loss = view_loss.mean() + total_loss = total_loss + view_loss + + # Store details + loss_details[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_view{view_idx + 1}" + ] = float(view_loss) + + # Initialize or update running average directly + avg_key = f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_avg" + if avg_key not in loss_details: + loss_details[avg_key] = float(view_loss) + running_avg_dict[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views" + ] = 1 + else: + valid_views = ( + running_avg_dict[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views" + ] + + 1 + ) + running_avg_dict[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views" + ] = valid_views + loss_details[avg_key] += ( + float(view_loss) - loss_details[avg_key] + ) / valid_views + + # Add unmodified losses for sets not in selected_losses + for idx, (loss, msk, rep_type) in enumerate(losses): + if idx not in processed_indices: + if msk is not None: + loss_after_masking = loss[msk] + else: + loss_after_masking = loss + if loss_after_masking.numel() > 0: + loss_mean = loss_after_masking.mean() + else: + # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True) + loss_mean = 0 + total_loss = total_loss + loss_mean + + return total_loss, dict(**loss_details, **details) + + +class ConfAndExcludeTopNPercentPixelLoss(MultiLoss): + """ + Combined loss that applies ConfLoss to one set of pixel-level regression losses + and ExcludeTopNPercentPixelLoss to another set of pixel-level regression losses. + """ + + def __init__( + self, + pixel_loss, + conf_alpha=1, + top_n_percent=5, + apply_to_real_data_only=True, + conf_loss_set_indices=None, + exclude_loss_set_indices=None, + ): + """ + Args: + pixel_loss (MultiLoss): The pixel-level regression loss to be used. + conf_alpha (float): Alpha parameter for ConfLoss. Default: 1. + top_n_percent (float): Percentage of top per-pixel loss values to ignore. Range: [0, 100]. Default: 5. + apply_to_real_data_only (bool): Whether to apply the exclude loss only to real world data. Default: True. + conf_loss_set_indices (list or None): Indices of the loss sets to apply confidence weighting to. + Each index selects a specific loss set across all views (with the same rep_type). + If None, defaults to [0] which applies to the first loss set only. + exclude_loss_set_indices (list or None): Indices of the loss sets to apply top N percent exclusion to. + Each index selects a specific loss set across all views (with the same rep_type). + If None, defaults to [1] which applies to the second loss set only. 
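+
+        Note:
+            The confidence-weighted and exclusion loss sets are processed
+            independently, so the two index lists are normally expected to be
+            disjoint.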
+ """ + super().__init__() + self.pixel_loss = pixel_loss.with_reduction("none") + assert conf_alpha > 0 + self.conf_alpha = conf_alpha + self.top_n_percent = top_n_percent + self.bottom_n_percent = 100 - top_n_percent + self.apply_to_real_data_only = apply_to_real_data_only + self.conf_loss_set_indices = ( + [0] if conf_loss_set_indices is None else conf_loss_set_indices + ) + self.exclude_loss_set_indices = ( + [1] if exclude_loss_set_indices is None else exclude_loss_set_indices + ) + + def get_name(self): + return f"ConfAndExcludeTopNPercentPixelLoss({self.pixel_loss})" + + def get_conf_log(self, x): + return x, torch.log(x) + + def keep_bottom_n_percent(self, tensor, mask, bottom_n_percent): + """ + Function to compute the mask for keeping the bottom n percent of per-pixel loss values. + """ + B, N = tensor.shape + + # Calculate the number of valid elements (where mask is True) + num_valid = mask.sum(dim=1) + + # Calculate the number of elements to keep (n% of valid elements) + num_keep = (num_valid * bottom_n_percent / 100).long() + + # Create a mask for the bottom n% elements + keep_mask = torch.arange(N, device=tensor.device).unsqueeze( + 0 + ) < num_keep.unsqueeze(1) + + # Create a tensor with inf where mask is False + masked_tensor = torch.where( + mask, tensor, torch.tensor(float("inf"), device=tensor.device) + ) + + # Sort the masked tensor along the N dimension + sorted_tensor, _ = torch.sort(masked_tensor, dim=1, descending=False) + + # Get the bottom n% elements + bottom_n_percent_elements = sorted_tensor[keep_mask] + + return bottom_n_percent_elements + + def compute_loss(self, batch, preds, **kw): + # Compute per-pixel loss + losses, pixel_loss_details = self.pixel_loss(batch, preds, **kw) + n_views = len(batch) + + # Select specific loss sets for confidence weighting + conf_selected_losses = [] + conf_processed_indices = set() + for idx in self.conf_loss_set_indices: + start_idx = idx * n_views + end_idx = min((idx + 1) * n_views, len(losses)) + conf_selected_losses.extend(losses[start_idx:end_idx]) + conf_processed_indices.update(range(start_idx, end_idx)) + + # Select specific loss sets for top N percent exclusion + exclude_selected_losses = [] + exclude_processed_indices = set() + for idx in self.exclude_loss_set_indices: + start_idx = idx * n_views + end_idx = min((idx + 1) * n_views, len(losses)) + exclude_selected_losses.extend(losses[start_idx:end_idx]) + exclude_processed_indices.update(range(start_idx, end_idx)) + + # Initialize total loss and details + total_loss = 0 + loss_details = {} + running_avg_dict = {} + self_name = type(self.pixel_loss).__name__ + + # Process selected losses with confidence weighting + for loss_idx, (loss, msk, rep_type) in enumerate(conf_selected_losses): + view_idx = loss_idx % n_views # Map to corresponding view index + + if loss.numel() == 0: + # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views}) for conf loss", force=True) + continue + + # Get the confidence and log confidence + if ( + hasattr(self.pixel_loss, "flatten_across_image_only") + and self.pixel_loss.flatten_across_image_only + ): + # Reshape confidence to match the flattened dimensions + conf_reshaped = preds[view_idx]["conf"].view( + preds[view_idx]["conf"].shape[0], -1 + ) + conf, log_conf = self.get_conf_log(conf_reshaped[msk]) + loss = loss[msk] + else: + conf, log_conf = self.get_conf_log(preds[view_idx]["conf"][msk]) + + # Weight the loss by the confidence + conf_loss = loss * conf - self.conf_alpha * log_conf + + # Only add 
to total loss and store details if there are valid elements + if conf_loss.numel() > 0: + conf_loss = conf_loss.mean() + total_loss = total_loss + conf_loss + + # Store details + loss_details[f"{self_name}_{rep_type}_conf_loss_view{view_idx + 1}"] = ( + float(conf_loss) + ) + + # Initialize or update running average directly + avg_key = f"{self_name}_{rep_type}_conf_loss_avg" + if avg_key not in loss_details: + loss_details[avg_key] = float(conf_loss) + running_avg_dict[ + f"{self_name}_{rep_type}_conf_loss_valid_views" + ] = 1 + else: + valid_views = ( + running_avg_dict[ + f"{self_name}_{rep_type}_conf_loss_valid_views" + ] + + 1 + ) + running_avg_dict[ + f"{self_name}_{rep_type}_conf_loss_valid_views" + ] = valid_views + loss_details[avg_key] += ( + float(conf_loss) - loss_details[avg_key] + ) / valid_views + + # Process selected losses with top N percent exclusion + for loss_idx, (loss, msk, rep_type) in enumerate(exclude_selected_losses): + view_idx = loss_idx % n_views # Map to corresponding view index + + if loss.numel() == 0: + # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views}) for exclude loss", force=True) + continue + + # Create empty list for current view's aggregated tensors + aggregated_losses = [] + + if self.apply_to_real_data_only: + # Get the synthetic and real world data mask + synthetic_mask = batch[view_idx]["is_synthetic"] + real_data_mask = ~batch[view_idx]["is_synthetic"] + else: + # Apply the filtering to all data + synthetic_mask = torch.zeros_like(batch[view_idx]["is_synthetic"]) + real_data_mask = torch.ones_like(batch[view_idx]["is_synthetic"]) + + # Process synthetic data + if synthetic_mask.any(): + synthetic_loss = loss[synthetic_mask] + synthetic_msk = msk[synthetic_mask] + aggregated_losses.append(synthetic_loss[synthetic_msk]) + + # Process real data + if real_data_mask.any(): + real_loss = loss[real_data_mask] + real_msk = msk[real_data_mask] + real_bottom_n_percent_loss = self.keep_bottom_n_percent( + real_loss, real_msk, self.bottom_n_percent + ) + aggregated_losses.append(real_bottom_n_percent_loss) + + # Compute view loss + view_loss = torch.cat(aggregated_losses, dim=0) + + # Only add to total loss and store details if there are valid elements + if view_loss.numel() > 0: + view_loss = view_loss.mean() + total_loss = total_loss + view_loss + + # Store details + loss_details[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_view{view_idx + 1}" + ] = float(view_loss) + + # Initialize or update running average directly + avg_key = f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_avg" + if avg_key not in loss_details: + loss_details[avg_key] = float(view_loss) + running_avg_dict[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views" + ] = 1 + else: + valid_views = ( + running_avg_dict[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views" + ] + + 1 + ) + running_avg_dict[ + f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views" + ] = valid_views + loss_details[avg_key] += ( + float(view_loss) - loss_details[avg_key] + ) / valid_views + + # Add unmodified losses for sets not processed with either confidence or exclusion + all_processed_indices = conf_processed_indices.union(exclude_processed_indices) + for idx, (loss, msk, rep_type) in enumerate(losses): + if idx not in all_processed_indices: + if msk is not None: + loss_after_masking = loss[msk] + else: + loss_after_masking = loss + if loss_after_masking.numel() > 0: + loss_mean = 
loss_after_masking.mean()
+                else:
+                    # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
+                    loss_mean = 0
+                total_loss = total_loss + loss_mean
+
+        return total_loss, dict(**loss_details, **pixel_loss_details)
+
+
+class Regr3D(Criterion, MultiLoss):
+    """
+    Regression Loss for World Frame Pointmaps.
+    Asymmetric loss where view 1 is supposed to be the anchor.
+
+    For each view i:
+        P_i = RT_i @ D_i
+        loss_i = pred_P_i - (RT_1^-1 @ RT_i @ D_i)
+    where RT_1 is the anchor view camera pose, D_i is the local (camera frame)
+    pointmap of view i, and pred_P_i is the predicted pointmap for view i,
+    which is already expressed in the anchor view's frame.
+    """
+
+    def __init__(
+        self,
+        criterion,
+        norm_mode="?avg_dis",
+        gt_scale=False,
+        ambiguous_loss_value=0,
+        max_metric_scale=False,
+        loss_in_log=True,
+        flatten_across_image_only=False,
+    ):
+        """
+        Initialize the loss criterion for World Frame Pointmaps.
+
+        Args:
+            criterion (BaseCriterion): The base criterion to use for computing the loss.
+            norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis".
+                If prefixed with "?", normalization is only applied to non-metric scale data.
+            gt_scale (bool): If True, enforce predictions to have the same scale as ground truth.
+                If False, both GT and predictions are normalized independently. Default: False.
+            ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+                If 0, ambiguous pixels are ignored. Default: 0.
+            max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this
+                value, it will be treated as non-metric. Default: False (no limit).
+            loss_in_log (bool): If True, apply logarithmic transformation to input before
+                computing the loss for pointmaps. Default: True.
+            flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+                the loss. If False, flatten across batch and spatial dimensions. Default: False.
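+
+            Example (illustrative sketch; `my_criterion` stands in for any BaseCriterion
+            with reduction "none", it is not a class defined here):
+                loss = Regr3D(my_criterion, norm_mode="?avg_dis", loss_in_log=True)
+                # "?avg_dis": metric-scale samples keep their scale, all other samples
+                # are normalized (average distance of valid points for "avg_dis")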
+ """ + super().__init__(criterion) + if norm_mode.startswith("?"): + # Do no norm pts from metric scale datasets + self.norm_all = False + self.norm_mode = norm_mode[1:] + else: + self.norm_all = True + self.norm_mode = norm_mode + self.gt_scale = gt_scale + self.ambiguous_loss_value = ambiguous_loss_value + self.max_metric_scale = max_metric_scale + self.loss_in_log = loss_in_log + self.flatten_across_image_only = flatten_across_image_only + + def get_all_info(self, batch, preds, dist_clip=None): + n_views = len(batch) + in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) + + # Initialize lists to store points and masks + no_norm_gt_pts = [] + valid_masks = [] + + # Process ground truth points and valid masks + for view_idx in range(n_views): + no_norm_gt_pts.append( + geotrf(in_camera0, batch[view_idx]["pts3d"]) + ) # B,H,W,3 + valid_masks.append(batch[view_idx]["valid_mask"].clone()) + + if dist_clip is not None: + # Points that are too far-away == invalid + for view_idx in range(n_views): + dis = no_norm_gt_pts[view_idx].norm(dim=-1) # (B, H, W) + valid_masks[view_idx] = valid_masks[view_idx] & (dis <= dist_clip) + + # Get predicted points + no_norm_pr_pts = [] + for view_idx in range(n_views): + no_norm_pr_pts.append(preds[view_idx]["pts3d"]) + + if not self.norm_all: + if self.max_metric_scale: + B = valid_masks[0].shape[0] + # Calculate distances to camera for all views + dists_to_cam1 = [] + for view_idx in range(n_views): + dist = torch.where( + valid_masks[view_idx], + torch.norm(no_norm_gt_pts[view_idx], dim=-1), + 0, + ).view(B, -1) + dists_to_cam1.append(dist) + + # Update metric scale flags + metric_scale_mask = batch[0]["is_metric_scale"] + for dist in dists_to_cam1: + metric_scale_mask = metric_scale_mask & ( + dist.max(dim=-1).values < self.max_metric_scale + ) + + for view_idx in range(n_views): + batch[view_idx]["is_metric_scale"] = metric_scale_mask + + non_metric_scale_mask = ~batch[0]["is_metric_scale"] + else: + non_metric_scale_mask = torch.ones_like(batch[0]["is_metric_scale"]) + + # Initialize normalized points + gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] + pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] + + # Normalize 3d points + if self.norm_mode and non_metric_scale_mask.any(): + normalized_pr_pts = normalize_multiple_pointclouds( + [pts[non_metric_scale_mask] for pts in no_norm_pr_pts], + [mask[non_metric_scale_mask] for mask in valid_masks], + self.norm_mode, + ) + for i in range(n_views): + pr_pts[i][non_metric_scale_mask] = normalized_pr_pts[i] + elif non_metric_scale_mask.any(): + for i in range(n_views): + pr_pts[i][non_metric_scale_mask] = no_norm_pr_pts[i][ + non_metric_scale_mask + ] + + if self.norm_mode and not self.gt_scale: + gt_normalization_output = normalize_multiple_pointclouds( + no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True + ) + normalized_gt_pts = gt_normalization_output[:-1] + norm_factor = gt_normalization_output[-1] + for i in range(n_views): + gt_pts[i] = normalized_gt_pts[i] + pr_pts[i][~non_metric_scale_mask] = ( + no_norm_pr_pts[i][~non_metric_scale_mask] + / norm_factor[~non_metric_scale_mask] + ) + elif ~non_metric_scale_mask.any(): + for i in range(n_views): + gt_pts[i] = no_norm_gt_pts[i] + pr_pts[i][~non_metric_scale_mask] = no_norm_pr_pts[i][ + ~non_metric_scale_mask + ] + else: + for i in range(n_views): + gt_pts[i] = no_norm_gt_pts[i] + + # Get ambiguous masks + ambiguous_masks = [] + for view_idx in range(n_views): + ambiguous_masks.append( + 
(~batch[view_idx]["non_ambiguous_mask"]) & (~valid_masks[view_idx]) + ) + + return gt_pts, pr_pts, valid_masks, ambiguous_masks, {} + + def compute_loss(self, batch, preds, **kw): + gt_pts, pred_pts, masks, ambiguous_masks, monitoring = self.get_all_info( + batch, preds, **kw + ) + n_views = len(batch) + + if self.ambiguous_loss_value > 0: + assert self.criterion.reduction == "none", ( + "ambiguous_loss_value should be 0 if no conf loss" + ) + # Add the ambiguous pixels as "valid" pixels + masks = [mask | amb_mask for mask, amb_mask in zip(masks, ambiguous_masks)] + + losses = [] + details = {} + running_avg_dict = {} + self_name = type(self).__name__ + + if not self.flatten_across_image_only: + for view_idx in range(n_views): + pred = pred_pts[view_idx][masks[view_idx]] + gt = gt_pts[view_idx][masks[view_idx]] + + if self.loss_in_log: + pred = apply_log_to_norm(pred) + gt = apply_log_to_norm(gt) + + loss = self.criterion(pred, gt) + + if self.ambiguous_loss_value > 0: + loss = torch.where( + ambiguous_masks[view_idx][masks[view_idx]], + self.ambiguous_loss_value, + loss, + ) + + losses.append((loss, masks[view_idx], "pts3d")) + if loss.numel() > 0: + loss_mean = float(loss.mean()) + details[f"{self_name}_pts3d_view{view_idx + 1}"] = loss_mean + # Initialize or update running average directly + avg_key = f"{self_name}_pts3d_avg" + if avg_key not in details: + details[avg_key] = loss_mean + running_avg_dict[f"{self_name}_pts3d_valid_views"] = 1 + else: + valid_views = ( + running_avg_dict[f"{self_name}_pts3d_valid_views"] + 1 + ) + running_avg_dict[f"{self_name}_pts3d_valid_views"] = valid_views + details[avg_key] += (loss_mean - details[avg_key]) / valid_views + else: + batch_size, _, _, dim = gt_pts[0].shape + + for view_idx in range(n_views): + gt = gt_pts[view_idx].view(batch_size, -1, dim) + pred = pred_pts[view_idx].view(batch_size, -1, dim) + view_mask = masks[view_idx].view(batch_size, -1) + amb_mask = ambiguous_masks[view_idx].view(batch_size, -1) + + if self.loss_in_log: + pred = apply_log_to_norm(pred) + gt = apply_log_to_norm(gt) + + loss = self.criterion(pred, gt) + + if self.ambiguous_loss_value > 0: + loss = torch.where(amb_mask, self.ambiguous_loss_value, loss) + + losses.append((loss, view_mask, "pts3d")) + loss_after_masking = loss[view_mask] + if loss_after_masking.numel() > 0: + loss_mean = float(loss_after_masking.mean()) + details[f"{self_name}_pts3d_view{view_idx + 1}"] = loss_mean + # Initialize or update running average directly + avg_key = f"{self_name}_pts3d_avg" + if avg_key not in details: + details[avg_key] = loss_mean + running_avg_dict[f"{self_name}_pts3d_valid_views"] = 1 + else: + valid_views = ( + running_avg_dict[f"{self_name}_pts3d_valid_views"] + 1 + ) + running_avg_dict[f"{self_name}_pts3d_valid_views"] = valid_views + details[avg_key] += (loss_mean - details[avg_key]) / valid_views + + return Sum(*losses), (details | monitoring) + + +class PointsPlusScaleRegr3D(Criterion, MultiLoss): + """ + Regression Loss for World Frame Pointmaps & Scale. + """ + + def __init__( + self, + criterion, + norm_predictions=True, + norm_mode="avg_dis", + ambiguous_loss_value=0, + loss_in_log=True, + flatten_across_image_only=False, + world_frame_points_loss_weight=1, + scale_loss_weight=1, + ): + """ + Initialize the loss criterion for World Frame Pointmaps & Scale. + The predicted scene representation is always normalized w.r.t. the frame of view0. + Loss is applied between the predicted metric scale and the ground truth metric scale. 
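+        Concretely (as implemented in get_all_info below): when a metric_scaling_factor
+        is predicted, the predicted points are divided by it before the normalized
+        pointmap loss, detaching the scaling factor from the geometry loss, while the
+        scale loss compares the normalization factor of the detached, re-scaled
+        predicted points against the ground truth normalization factor.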
+ + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. + norm_predictions (bool): If True, normalize the predictions before computing the loss. + norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". + ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. + If 0, ambiguous pixels are ignored. Default: 0. + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth, pointmaps and scale. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. + scale_loss_weight (float): Weight to use for the scale loss. Default: 1. + """ + super().__init__(criterion) + self.norm_predictions = norm_predictions + self.norm_mode = norm_mode + self.ambiguous_loss_value = ambiguous_loss_value + self.loss_in_log = loss_in_log + self.flatten_across_image_only = flatten_across_image_only + self.world_frame_points_loss_weight = world_frame_points_loss_weight + self.scale_loss_weight = scale_loss_weight + + def get_all_info(self, batch, preds, dist_clip=None): + """ + Function to get all the information needed to compute the loss. + Returns all quantities normalized w.r.t. camera of view0. + """ + n_views = len(batch) + + # Everything is normalized w.r.t. camera of view0 + # Initialize lists to store data for all views + # Ground truth quantities + in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) + no_norm_gt_pts = [] + valid_masks = [] + # Predicted quantities + no_norm_pr_pts = [] + metric_pr_pts_to_compute_scale = [] + + # Get ground truth & prediction info for all views + for i in range(n_views): + # Get the ground truth + no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"])) + valid_masks.append(batch[i]["valid_mask"].clone()) + + # Get predictions for normalized loss + if "metric_scaling_factor" in preds[i].keys(): + # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans + # This detaches the predicted metric scaling factor from the geometry based loss + curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + else: + curr_view_no_norm_pr_pts = preds[i]["pts3d"] + no_norm_pr_pts.append(curr_view_no_norm_pr_pts) + + # Get the predicted metric scale points + if "metric_scaling_factor" in preds[i].keys(): + # Detach the raw predicted points so that the scale loss is only applied to the scaling factor + curr_view_metric_pr_pts_to_compute_scale = ( + curr_view_no_norm_pr_pts.detach() + * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1) + ) + else: + curr_view_metric_pr_pts_to_compute_scale = ( + curr_view_no_norm_pr_pts.clone() + ) + metric_pr_pts_to_compute_scale.append( + curr_view_metric_pr_pts_to_compute_scale + ) + + if dist_clip is not None: + # Points that are too far-away == invalid + for i in range(n_views): + dis = no_norm_gt_pts[i].norm(dim=-1) + valid_masks[i] = valid_masks[i] & (dis <= dist_clip) + + # Initialize normalized tensors + gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] + pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] + + # Normalize the predicted points if specified + if self.norm_predictions: + pr_normalization_output = 
normalize_multiple_pointclouds( + no_norm_pr_pts, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_pts_norm = pr_normalization_output[:-1] + + # Normalize the ground truth points + gt_normalization_output = normalize_multiple_pointclouds( + no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True + ) + gt_pts_norm = gt_normalization_output[:-1] + gt_norm_factor = gt_normalization_output[-1] + + for i in range(n_views): + if self.norm_predictions: + # Assign the normalized predictions + pr_pts[i] = pr_pts_norm[i] + else: + pr_pts[i] = no_norm_pr_pts[i] + # Assign the normalized ground truth quantities + gt_pts[i] = gt_pts_norm[i] + + # Get the mask indicating ground truth metric scale quantities + metric_scale_mask = batch[0]["is_metric_scale"] + valid_gt_norm_factor_mask = ( + gt_norm_factor[:, 0, 0, 0] > 1e-8 + ) # Mask out cases where depth for all views is invalid + valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask + + if valid_metric_scale_mask.any(): + # Compute the scale norm factor using the predicted metric scale points + metric_pr_normalization_output = normalize_multiple_pointclouds( + metric_pr_pts_to_compute_scale, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_metric_norm_factor = metric_pr_normalization_output[-1] + + # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities + gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask] + pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask] + else: + gt_metric_norm_factor = None + pr_metric_norm_factor = None + + # Get ambiguous masks + ambiguous_masks = [] + for i in range(n_views): + ambiguous_masks.append( + (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i]) + ) + + # Pack into info dicts + gt_info = [] + pred_info = [] + for i in range(n_views): + gt_info.append( + { + "pts3d": gt_pts[i], + } + ) + pred_info.append( + { + "pts3d": pr_pts[i], + } + ) + + return ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) + + def compute_loss(self, batch, preds, **kw): + ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) = self.get_all_info(batch, preds, **kw) + n_views = len(batch) + + if self.ambiguous_loss_value > 0: + assert self.criterion.reduction == "none", ( + "ambiguous_loss_value should be 0 if no conf loss" + ) + # Add the ambiguous pixel as "valid" pixels... 
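+            # (they are merged into the valid masks here and later receive the constant
+            # ambiguous_loss_value via torch.where instead of being dropped)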
+ valid_masks = [ + mask | ambig_mask + for mask, ambig_mask in zip(valid_masks, ambiguous_masks) + ] + + pts3d_losses = [] + + for i in range(n_views): + # Get the predicted dense quantities + if not self.flatten_across_image_only: + # Flatten the points across the entire batch with the masks + pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] + gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W + batch_size, _, _, pts_dim = gt_info[i]["pts3d"].shape + gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) + pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space if specified + if self.loss_in_log: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_pts3d = apply_log_to_norm(pred_pts3d) + + # Compute point loss + pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") + pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight + pts3d_losses.append(pts3d_loss) + + # Handle ambiguous pixels + if self.ambiguous_loss_value > 0: + if not self.flatten_across_image_only: + pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + pts3d_losses[i], + ) + else: + pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + pts3d_losses[i], + ) + + # Compute the scale loss + if gt_metric_norm_factor is not None: + if self.loss_in_log: + gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) + pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) + scale_loss = ( + self.criterion( + pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" + ) + * self.scale_loss_weight + ) + else: + scale_loss = None + + # Use helper function to generate loss terms and details + + losses_dict = { + "pts3d": { + "values": pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + "scale": { + "values": scale_loss, + "use_mask": False, + "is_multi_view": False, + }, + } + + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) + + +class NormalGMLoss(MultiLoss): + """ + Normal & Gradient Matching Loss for Monocular Depth Training. + """ + + def __init__( + self, + norm_predictions=True, + norm_mode="avg_dis", + apply_normal_and_gm_loss_to_synthetic_data_only=True, + ): + """ + Initialize the loss criterion for Normal & Gradient Matching Loss (currently only valid for 1 view). + Computes: + (1) Normal Loss over the PointMap (naturally will be in local frame) in euclidean coordinates, + (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space) + + Args: + norm_predictions (bool): If True, normalize the predictions before computing the loss. + norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". + apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. + If False, apply the normal and gm loss to all data. Default: True. 
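+
+            Example (illustrative sketch):
+                loss = NormalGMLoss(
+                    norm_mode="avg_dis",
+                    apply_normal_and_gm_loss_to_synthetic_data_only=True,
+                )
+                # supervises pointmap surface normals and matches depth gradients
+                # in log space, here on synthetic samples only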
+ """ + super().__init__() + self.norm_predictions = norm_predictions + self.norm_mode = norm_mode + self.apply_normal_and_gm_loss_to_synthetic_data_only = ( + apply_normal_and_gm_loss_to_synthetic_data_only + ) + + def get_all_info(self, batch, preds, dist_clip=None): + """ + Function to get all the information needed to compute the loss. + Returns all quantities normalized. + """ + n_views = len(batch) + assert n_views == 1, ( + "Normal & Gradient Matching Loss Class only supports 1 view" + ) + + # Everything is normalized w.r.t. camera of view1 + in_camera1 = closed_form_pose_inverse(batch[0]["camera_pose"]) + + # Initialize lists to store data for all views + no_norm_gt_pts = [] + valid_masks = [] + no_norm_pr_pts = [] + + # Get ground truth & prediction info for all views + for i in range(n_views): + # Get ground truth + no_norm_gt_pts.append(geotrf(in_camera1, batch[i]["pts3d"])) + valid_masks.append(batch[i]["valid_mask"].clone()) + + # Get predictions for normalized loss + if "metric_scaling_factor" in preds[i].keys(): + # Divide by the predicted metric scaling factor to get the raw predicted points + # This detaches the predicted metric scaling factor from the geometry based loss + curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + else: + curr_view_no_norm_pr_pts = preds[i]["pts3d"] + no_norm_pr_pts.append(curr_view_no_norm_pr_pts) + + if dist_clip is not None: + # Points that are too far-away == invalid + for i in range(n_views): + dis = no_norm_gt_pts[i].norm(dim=-1) + valid_masks[i] = valid_masks[i] & (dis <= dist_clip) + + # Initialize normalized tensors + gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] + pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] + + # Normalize the predicted points if specified + if self.norm_predictions: + pr_normalization_output = normalize_multiple_pointclouds( + no_norm_pr_pts, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_pts_norm = pr_normalization_output[:-1] + + # Normalize the ground truth points + gt_normalization_output = normalize_multiple_pointclouds( + no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True + ) + gt_pts_norm = gt_normalization_output[:-1] + + for i in range(n_views): + if self.norm_predictions: + # Assign the normalized predictions + pr_pts[i] = pr_pts_norm[i] + else: + # Assign the raw predicted points + pr_pts[i] = no_norm_pr_pts[i] + # Assign the normalized ground truth + gt_pts[i] = gt_pts_norm[i] + + return gt_pts, pr_pts, valid_masks + + def compute_loss(self, batch, preds, **kw): + gt_pts, pred_pts, valid_masks = self.get_all_info(batch, preds, **kw) + n_views = len(batch) + assert n_views == 1, ( + "Normal & Gradient Matching Loss Class only supports 1 view" + ) + + normal_losses = [] + gradient_matching_losses = [] + details = {} + running_avg_dict = {} + self_name = type(self).__name__ + + for i in range(n_views): + # Get the local frame points, log space depth_z & valid masks + pred_local_pts3d = pred_pts[i] + pred_depth_z = pred_local_pts3d[..., 2:] + pred_depth_z = apply_log_to_norm(pred_depth_z) + gt_local_pts3d = gt_pts[i] + gt_depth_z = gt_local_pts3d[..., 2:] + gt_depth_z = apply_log_to_norm(gt_depth_z) + valid_mask_for_normal_gm_loss = valid_masks[i].clone() + + # Update the validity mask for normal & gm loss based on the synthetic data mask if required + if self.apply_normal_and_gm_loss_to_synthetic_data_only: + synthetic_mask = batch[i]["is_synthetic"] # (B, ) + synthetic_mask = 
synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1) + synthetic_mask = synthetic_mask.expand( + -1, pred_depth_z.shape[1], pred_depth_z.shape[2] + ) # (B, H, W) + valid_mask_for_normal_gm_loss = ( + valid_mask_for_normal_gm_loss & synthetic_mask + ) + + # Compute the normal loss + normal_loss = compute_normal_loss( + pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone() + ) + normal_losses.append(normal_loss) + + # Compute the gradient matching loss + gradient_matching_loss = compute_gradient_matching_loss( + pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone() + ) + gradient_matching_losses.append(gradient_matching_loss) + + # Add loss details if only valid values are present + # Initialize or update running average directly + # Normal loss details + if float(normal_loss) > 0: + details[f"{self_name}_normal_view{i + 1}"] = float(normal_loss) + normal_avg_key = f"{self_name}_normal_avg" + if normal_avg_key not in details: + details[normal_avg_key] = float(normal_losses[i]) + running_avg_dict[f"{self_name}_normal_valid_views"] = 1 + else: + normal_valid_views = ( + running_avg_dict[f"{self_name}_normal_valid_views"] + 1 + ) + running_avg_dict[f"{self_name}_normal_valid_views"] = ( + normal_valid_views + ) + details[normal_avg_key] += ( + float(normal_losses[i]) - details[normal_avg_key] + ) / normal_valid_views + + # Gradient Matching loss details + if float(gradient_matching_loss) > 0: + details[f"{self_name}_gradient_matching_view{i + 1}"] = float( + gradient_matching_loss + ) + # For gradient matching loss + gm_avg_key = f"{self_name}_gradient_matching_avg" + if gm_avg_key not in details: + details[gm_avg_key] = float(gradient_matching_losses[i]) + running_avg_dict[f"{self_name}_gm_valid_views"] = 1 + else: + gm_valid_views = running_avg_dict[f"{self_name}_gm_valid_views"] + 1 + running_avg_dict[f"{self_name}_gm_valid_views"] = gm_valid_views + details[gm_avg_key] += ( + float(gradient_matching_losses[i]) - details[gm_avg_key] + ) / gm_valid_views + + # Put the losses together + loss_terms = [] + for i in range(n_views): + loss_terms.append((normal_losses[i], None, "normal")) + loss_terms.append((gradient_matching_losses[i], None, "gradient_matching")) + losses = Sum(*loss_terms) + + return losses, details + + +class FactoredGeometryRegr3D(Criterion, MultiLoss): + """ + Regression Loss for Factored Geometry. + """ + + def __init__( + self, + criterion, + norm_mode="?avg_dis", + gt_scale=False, + ambiguous_loss_value=0, + max_metric_scale=False, + loss_in_log=True, + flatten_across_image_only=False, + depth_type_for_loss="depth_along_ray", + cam_frame_points_loss_weight=1, + depth_loss_weight=1, + ray_directions_loss_weight=1, + pose_quats_loss_weight=1, + pose_trans_loss_weight=1, + compute_pairwise_relative_pose_loss=False, + convert_predictions_to_view0_frame=False, + compute_world_frame_points_loss=True, + world_frame_points_loss_weight=1, + ): + """ + Initialize the loss criterion for Factored Geometry (Ray Directions, Depth, Pose), + and the Collective Geometry i.e. Local Frame Pointmaps & optionally World Frame Pointmaps. + If world-frame pointmap loss is computed, the pixel-level losses are computed in the following order: + (1) world points, (2) cam points, (3) depth, (4) ray directions, (5) pose quats, (6) pose trans. + Else, the pixel-level losses are returned in the following order: + (1) cam points, (2) depth, (3) ray directions, (4) pose quats, (5) pose trans. 
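+        For example, with 2 views and compute_world_frame_points_loss=True, the loss
+        terms are flattened set-major into 6 x 2 = 12 (loss, mask, rep_type) tuples:
+        world points (view 1, view 2), cam points (view 1, view 2), and so on in the
+        order above.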
+
+        Args:
+            criterion (BaseCriterion): The base criterion to use for computing the loss.
+            norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis".
+                If prefixed with "?", normalization is only applied to non-metric scale data.
+            gt_scale (bool): If True, enforce predictions to have the same scale as ground truth.
+                If False, both GT and predictions are normalized independently. Default: False.
+            ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+                If 0, ambiguous pixels are ignored. Default: 0.
+            max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this
+                value, it will be treated as non-metric. Default: False (no limit).
+            loss_in_log (bool): If True, apply logarithmic transformation to input before
+                computing the loss for depth and pointmaps. Default: True.
+            flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+                the loss. If False, flatten across batch and spatial dimensions. Default: False.
+            depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+                Options: "depth_along_ray", "depth_z"
+            cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
+            depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+            ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+            pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+            pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+            compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
+                exhaustive pairwise relative poses. Default: False.
+            convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
+                Use this if the predictions are not already in the view0 frame. Default: False.
+            compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
+            world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
+        """
+        super().__init__(criterion)
+        if norm_mode.startswith("?"):
+            # Do not norm pts from metric scale datasets
+            self.norm_all = False
+            self.norm_mode = norm_mode[1:]
+        else:
+            self.norm_all = True
+            self.norm_mode = norm_mode
+        self.gt_scale = gt_scale
+        self.ambiguous_loss_value = ambiguous_loss_value
+        self.max_metric_scale = max_metric_scale
+        self.loss_in_log = loss_in_log
+        self.flatten_across_image_only = flatten_across_image_only
+        self.depth_type_for_loss = depth_type_for_loss
+        assert self.depth_type_for_loss in [
+            "depth_along_ray",
+            "depth_z",
+        ], "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']"
+        self.cam_frame_points_loss_weight = cam_frame_points_loss_weight
+        self.depth_loss_weight = depth_loss_weight
+        self.ray_directions_loss_weight = ray_directions_loss_weight
+        self.pose_quats_loss_weight = pose_quats_loss_weight
+        self.pose_trans_loss_weight = pose_trans_loss_weight
+        self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
+        self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame
+        self.compute_world_frame_points_loss = compute_world_frame_points_loss
+        self.world_frame_points_loss_weight = world_frame_points_loss_weight
+
+    def get_all_info(self, batch, preds, dist_clip=None):
+        """
+        Function to get all the information needed to compute the loss.
+        Returns all quantities normalized w.r.t. camera of view0.
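+
+        Returns:
+            gt_info, pred_info (list[dict]): Per-view dicts with keys "ray_directions",
+                the configured depth key ("depth_along_ray" or "depth_z"), "pose_trans",
+                "pose_quats", "pts3d" and "pts3d_cam".
+            valid_masks (list[torch.Tensor]): Per-view validity masks.
+            ambiguous_masks (list[torch.Tensor]): Per-view ambiguous pixel masks.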
+ """ + n_views = len(batch) + + # Everything is normalized w.r.t. camera of view0 + # Initialize lists to store data for all views + # Ground truth quantities + in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) + no_norm_gt_pts = [] + no_norm_gt_pts_cam = [] + no_norm_gt_depth = [] + no_norm_gt_pose_trans = [] + valid_masks = [] + gt_ray_directions = [] + gt_pose_quats = [] + # Predicted quantities + if self.convert_predictions_to_view0_frame: + # Get the camera transform to convert quantities to view0 frame + pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze( + 0 + ) + batch_size = preds[0]["cam_quats"].shape[0] + pred_camera0 = pred_camera0.repeat(batch_size, 1, 1) + pred_camera0_rot = quaternion_to_rotation_matrix( + preds[0]["cam_quats"].clone() + ) + pred_camera0[..., :3, :3] = pred_camera0_rot + pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone() + pred_in_camera0 = closed_form_pose_inverse(pred_camera0) + no_norm_pr_pts = [] + no_norm_pr_pts_cam = [] + no_norm_pr_depth = [] + no_norm_pr_pose_trans = [] + pr_ray_directions = [] + pr_pose_quats = [] + + # Get ground truth & prediction info for all views + for i in range(n_views): + # Get ground truth + no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"])) + valid_masks.append(batch[i]["valid_mask"].clone()) + no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"]) + gt_ray_directions.append(batch[i]["ray_directions_cam"]) + if self.depth_type_for_loss == "depth_along_ray": + no_norm_gt_depth.append(batch[i]["depth_along_ray"]) + elif self.depth_type_for_loss == "depth_z": + no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:]) + if i == 0: + # For view0, initialize identity pose + gt_pose_quats.append( + torch.tensor( + [0, 0, 0, 1], + dtype=gt_ray_directions[0].dtype, + device=gt_ray_directions[0].device, + ) + .unsqueeze(0) + .repeat(gt_ray_directions[0].shape[0], 1) + ) + no_norm_gt_pose_trans.append( + torch.tensor( + [0, 0, 0], + dtype=gt_ray_directions[0].dtype, + device=gt_ray_directions[0].device, + ) + .unsqueeze(0) + .repeat(gt_ray_directions[0].shape[0], 1) + ) + else: + # For other views, transform pose to view0's frame + gt_pose_quats_world = batch[i]["camera_pose_quats"] + no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"] + gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = ( + transform_pose_using_quats_and_trans_2_to_1( + batch[0]["camera_pose_quats"], + batch[0]["camera_pose_trans"], + gt_pose_quats_world, + no_norm_gt_pose_trans_world, + ) + ) + gt_pose_quats.append(gt_pose_quats_in_view0) + no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0) + + # Get the local predictions + no_norm_pr_pts_cam.append(preds[i]["pts3d_cam"]) + pr_ray_directions.append(preds[i]["ray_directions"]) + if self.depth_type_for_loss == "depth_along_ray": + no_norm_pr_depth.append(preds[i]["depth_along_ray"]) + elif self.depth_type_for_loss == "depth_z": + no_norm_pr_depth.append(preds[i]["pts3d_cam"][..., 2:]) + + # Get the predicted global predictions in view0's frame + if self.convert_predictions_to_view0_frame: + # Convert predictions to view0 frame + pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"]) + pr_pose_quats_in_view0, pr_pose_trans_in_view0 = ( + transform_pose_using_quats_and_trans_2_to_1( + preds[0]["cam_quats"], + preds[0]["cam_trans"], + preds[i]["cam_quats"], + preds[i]["cam_trans"], + ) + ) + no_norm_pr_pts.append(pr_pts3d_in_view0) + no_norm_pr_pose_trans.append(pr_pose_trans_in_view0) + pr_pose_quats.append(pr_pose_quats_in_view0) + else: + 
# Predictions are already in view0 frame + no_norm_pr_pts.append(preds[i]["pts3d"]) + no_norm_pr_pose_trans.append(preds[i]["cam_trans"]) + pr_pose_quats.append(preds[i]["cam_quats"]) + + if dist_clip is not None: + # Points that are too far-away == invalid + for i in range(n_views): + dis = no_norm_gt_pts[i].norm(dim=-1) + valid_masks[i] = valid_masks[i] & (dis <= dist_clip) + + # Handle metric scale + if not self.norm_all: + if self.max_metric_scale: + B = valid_masks[0].shape[0] + dists_to_cam1 = [] + for i in range(n_views): + dists_to_cam1.append( + torch.where( + valid_masks[i], torch.norm(no_norm_gt_pts[i], dim=-1), 0 + ).view(B, -1) + ) + + batch[0]["is_metric_scale"] = batch[0]["is_metric_scale"] + for dist in dists_to_cam1: + batch[0]["is_metric_scale"] &= ( + dist.max(dim=-1).values < self.max_metric_scale + ) + + for i in range(1, n_views): + batch[i]["is_metric_scale"] = batch[0]["is_metric_scale"] + + non_metric_scale_mask = ~batch[0]["is_metric_scale"] + else: + non_metric_scale_mask = torch.ones_like(batch[0]["is_metric_scale"]) + + # Initialize normalized tensors + gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] + gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam] + gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth] + gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans] + + pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] + pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam] + pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth] + pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans] + + # Normalize points + if self.norm_mode and non_metric_scale_mask.any(): + pr_normalization_output = normalize_multiple_pointclouds( + [pts[non_metric_scale_mask] for pts in no_norm_pr_pts], + [mask[non_metric_scale_mask] for mask in valid_masks], + self.norm_mode, + ret_factor=True, + ) + pr_pts_norm = pr_normalization_output[:-1] + pr_norm_factor = pr_normalization_output[-1] + + for i in range(n_views): + pr_pts[i][non_metric_scale_mask] = pr_pts_norm[i] + pr_pts_cam[i][non_metric_scale_mask] = ( + no_norm_pr_pts_cam[i][non_metric_scale_mask] / pr_norm_factor + ) + pr_depth[i][non_metric_scale_mask] = ( + no_norm_pr_depth[i][non_metric_scale_mask] / pr_norm_factor + ) + pr_pose_trans[i][non_metric_scale_mask] = ( + no_norm_pr_pose_trans[i][non_metric_scale_mask] + / pr_norm_factor[:, :, 0, 0] + ) + + elif non_metric_scale_mask.any(): + for i in range(n_views): + pr_pts[i][non_metric_scale_mask] = no_norm_pr_pts[i][ + non_metric_scale_mask + ] + pr_pts_cam[i][non_metric_scale_mask] = no_norm_pr_pts_cam[i][ + non_metric_scale_mask + ] + pr_depth[i][non_metric_scale_mask] = no_norm_pr_depth[i][ + non_metric_scale_mask + ] + pr_pose_trans[i][non_metric_scale_mask] = no_norm_pr_pose_trans[i][ + non_metric_scale_mask + ] + + if self.norm_mode and not self.gt_scale: + gt_normalization_output = normalize_multiple_pointclouds( + no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True + ) + gt_pts_norm = gt_normalization_output[:-1] + norm_factor = gt_normalization_output[-1] + + for i in range(n_views): + gt_pts[i] = gt_pts_norm[i] + gt_pts_cam[i] = no_norm_gt_pts_cam[i] / norm_factor + gt_depth[i] = no_norm_gt_depth[i] / norm_factor + gt_pose_trans[i] = no_norm_gt_pose_trans[i] / norm_factor[:, :, 0, 0] + + pr_pts[i][~non_metric_scale_mask] = ( + no_norm_pr_pts[i][~non_metric_scale_mask] + / norm_factor[~non_metric_scale_mask] + ) + 
pr_pts_cam[i][~non_metric_scale_mask] = ( + no_norm_pr_pts_cam[i][~non_metric_scale_mask] + / norm_factor[~non_metric_scale_mask] + ) + pr_depth[i][~non_metric_scale_mask] = ( + no_norm_pr_depth[i][~non_metric_scale_mask] + / norm_factor[~non_metric_scale_mask] + ) + pr_pose_trans[i][~non_metric_scale_mask] = ( + no_norm_pr_pose_trans[i][~non_metric_scale_mask] + / norm_factor[~non_metric_scale_mask][:, :, 0, 0] + ) + + elif ~non_metric_scale_mask.any(): + for i in range(n_views): + gt_pts[i] = no_norm_gt_pts[i] + gt_pts_cam[i] = no_norm_gt_pts_cam[i] + gt_depth[i] = no_norm_gt_depth[i] + gt_pose_trans[i] = no_norm_gt_pose_trans[i] + pr_pts[i][~non_metric_scale_mask] = no_norm_pr_pts[i][ + ~non_metric_scale_mask + ] + pr_pts_cam[i][~non_metric_scale_mask] = no_norm_pr_pts_cam[i][ + ~non_metric_scale_mask + ] + pr_depth[i][~non_metric_scale_mask] = no_norm_pr_depth[i][ + ~non_metric_scale_mask + ] + pr_pose_trans[i][~non_metric_scale_mask] = no_norm_pr_pose_trans[i][ + ~non_metric_scale_mask + ] + else: + for i in range(n_views): + gt_pts[i] = no_norm_gt_pts[i] + gt_pts_cam[i] = no_norm_gt_pts_cam[i] + gt_depth[i] = no_norm_gt_depth[i] + gt_pose_trans[i] = no_norm_gt_pose_trans[i] + + # Get ambiguous masks + ambiguous_masks = [] + for i in range(n_views): + ambiguous_masks.append( + (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i]) + ) + + # Pack into info dicts + gt_info = [] + pred_info = [] + for i in range(n_views): + gt_info.append( + { + "ray_directions": gt_ray_directions[i], + self.depth_type_for_loss: gt_depth[i], + "pose_trans": gt_pose_trans[i], + "pose_quats": gt_pose_quats[i], + "pts3d": gt_pts[i], + "pts3d_cam": gt_pts_cam[i], + } + ) + pred_info.append( + { + "ray_directions": pr_ray_directions[i], + self.depth_type_for_loss: pr_depth[i], + "pose_trans": pr_pose_trans[i], + "pose_quats": pr_pose_quats[i], + "pts3d": pr_pts[i], + "pts3d_cam": pr_pts_cam[i], + } + ) + + return gt_info, pred_info, valid_masks, ambiguous_masks + + def compute_loss(self, batch, preds, **kw): + gt_info, pred_info, valid_masks, ambiguous_masks = self.get_all_info( + batch, preds, **kw + ) + n_views = len(batch) + + # Mask out samples in the batch where the gt depth validity mask is entirely zero + valid_norm_factor_masks = [ + mask.sum(dim=(1, 2)) > 0 for mask in valid_masks + ] # List of (B,) + + if self.ambiguous_loss_value > 0: + assert self.criterion.reduction == "none", ( + "ambiguous_loss_value should be 0 if no conf loss" + ) + # Add the ambiguous pixel as "valid" pixels... 
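+            # (merged so the dense losses below can assign them the constant
+            # ambiguous_loss_value instead of masking them out)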
+ valid_masks = [ + mask | ambig_mask + for mask, ambig_mask in zip(valid_masks, ambiguous_masks) + ] + + pose_trans_losses = [] + pose_quats_losses = [] + ray_directions_losses = [] + depth_losses = [] + cam_pts3d_losses = [] + if self.compute_world_frame_points_loss: + pts3d_losses = [] + + for i in range(n_views): + # Get the predicted dense quantities + if not self.flatten_across_image_only: + # Flatten the points across the entire batch with the masks + pred_ray_directions = pred_info[i]["ray_directions"] + gt_ray_directions = gt_info[i]["ray_directions"] + pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]] + gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]] + pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]] + gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]] + if self.compute_world_frame_points_loss: + pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] + gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W + batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape + gt_ray_directions = gt_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + pred_ray_directions = pred_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1] + gt_depth = gt_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + pred_depth = pred_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1] + gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim) + pred_cam_pts3d = pred_info[i]["pts3d_cam"].view( + batch_size, -1, cam_pts_dim + ) + if self.compute_world_frame_points_loss: + pts_dim = gt_info[i]["pts3d"].shape[-1] + gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) + pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space for depth if specified + if self.loss_in_log: + gt_depth = apply_log_to_norm(gt_depth) + pred_depth = apply_log_to_norm(pred_depth) + gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d) + pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d) + if self.compute_world_frame_points_loss: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_pts3d = apply_log_to_norm(pred_pts3d) + + if self.compute_pairwise_relative_pose_loss: + # Get the inverse of current view predicted pose + pred_inv_curr_view_pose_quats = quaternion_inverse( + pred_info[i]["pose_quats"] + ) + pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + pred_inv_curr_view_pose_quats + ) + pred_inv_curr_view_pose_trans = -1 * ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the inverse of the current view GT pose + gt_inv_curr_view_pose_quats = quaternion_inverse( + gt_info[i]["pose_quats"] + ) + gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + gt_inv_curr_view_pose_quats + ) + gt_inv_curr_view_pose_trans = -1 * ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the other N-1 relative poses using the current pose as reference frame + pred_rel_pose_quats = [] + pred_rel_pose_trans = [] + gt_rel_pose_quats = [] + gt_rel_pose_trans = [] + for ov_idx in range(n_views): + if ov_idx == i: + continue + # Get the relative predicted pose + pred_ov_rel_pose_quats = quaternion_multiply( + pred_inv_curr_view_pose_quats, 
pred_info[ov_idx]["pose_quats"] + ) + pred_ov_rel_pose_trans = ( + ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + pred_inv_curr_view_pose_trans + ) + + # Get the relative GT pose + gt_ov_rel_pose_quats = quaternion_multiply( + gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] + ) + gt_ov_rel_pose_trans = ( + ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + gt_inv_curr_view_pose_trans + ) + + # Get the valid translations using valid_norm_factor_masks for current view and other view + overall_valid_mask_for_trans = ( + valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] + ) + + # Append the relative poses + pred_rel_pose_quats.append(pred_ov_rel_pose_quats) + pred_rel_pose_trans.append( + pred_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + gt_rel_pose_quats.append(gt_ov_rel_pose_quats) + gt_rel_pose_trans.append( + gt_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + + # Cat the N-1 relative poses along the batch dimension + pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) + pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) + gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) + gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion( + pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" + ), + self.criterion( + pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + else: + # Get the pose info for the current view + pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] + gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] + pred_pose_quats = pred_info[i]["pose_quats"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_pose_trans, gt_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), + self.criterion( + pred_pose_quats, -gt_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions, gt_ray_directions, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute camera frame point loss + cam_pts3d_loss = self.criterion( + pred_cam_pts3d, gt_cam_pts3d, factor="points" + ) + cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight + 
cam_pts3d_losses.append(cam_pts3d_loss) + + if self.compute_world_frame_points_loss: + # Compute point loss + pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") + pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight + pts3d_losses.append(pts3d_loss) + + # Handle ambiguous pixels + if self.ambiguous_loss_value > 0: + if not self.flatten_across_image_only: + depth_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + pts3d_losses[i], + ) + else: + depth_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + pts3d_losses[i], + ) + + # Use helper function to generate loss terms and details + if self.compute_world_frame_points_loss: + losses_dict = { + "pts3d": { + "values": pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + } + else: + losses_dict = {} + losses_dict.update( + { + "cam_pts3d": { + "values": cam_pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + self.depth_type_for_loss: { + "values": depth_losses, + "use_mask": True, + "is_multi_view": True, + }, + "ray_directions": { + "values": ray_directions_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_quats": { + "values": pose_quats_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_trans": { + "values": pose_trans_losses, + "use_mask": False, + "is_multi_view": True, + }, + } + ) + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) + + +class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D): + """ + Regression, Normals & Gradient Matching Loss for Factored Geometry. + """ + + def __init__( + self, + criterion, + norm_mode="?avg_dis", + gt_scale=False, + ambiguous_loss_value=0, + max_metric_scale=False, + loss_in_log=True, + flatten_across_image_only=False, + depth_type_for_loss="depth_along_ray", + cam_frame_points_loss_weight=1, + depth_loss_weight=1, + ray_directions_loss_weight=1, + pose_quats_loss_weight=1, + pose_trans_loss_weight=1, + compute_pairwise_relative_pose_loss=False, + convert_predictions_to_view0_frame=False, + compute_world_frame_points_loss=True, + world_frame_points_loss_weight=1, + apply_normal_and_gm_loss_to_synthetic_data_only=True, + normal_loss_weight=1, + gm_loss_weight=1, + ): + """ + Initialize the loss criterion for Factored Geometry (see parent class for details). + Additionally computes: + (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates, + (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space) + + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. + norm_mode (str): Normalization mode for scene representation. 
Default: "avg_dis". + If prefixed with "?", normalization is only applied to non-metric scale data. + gt_scale (bool): If True, enforce predictions to have the same scale as ground truth. + If False, both GT and predictions are normalized independently. Default: False. + ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. + If 0, ambiguous pixels are ignored. Default: 0. + max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this + value, it will be treated as non-metric. Default: False (no limit). + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth and pointmaps. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". + Options: "depth_along_ray", "depth_z" + cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1. + depth_loss_weight (float): Weight to use for the depth loss. Default: 1. + ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. + pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. + pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the + exhaustive pairwise relative poses. Default: False. + convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. + Use this if the predictions are not already in the view0 frame. Default: False. + compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True. + world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. + apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. + If False, apply the normal and gm loss to all data. Default: True. + normal_loss_weight (float): Weight to use for the normal loss. Default: 1. + gm_loss_weight (float): Weight to use for the gm loss. Default: 1. 
+ """ + super().__init__( + criterion=criterion, + norm_mode=norm_mode, + gt_scale=gt_scale, + ambiguous_loss_value=ambiguous_loss_value, + max_metric_scale=max_metric_scale, + loss_in_log=loss_in_log, + flatten_across_image_only=flatten_across_image_only, + depth_type_for_loss=depth_type_for_loss, + cam_frame_points_loss_weight=cam_frame_points_loss_weight, + depth_loss_weight=depth_loss_weight, + ray_directions_loss_weight=ray_directions_loss_weight, + pose_quats_loss_weight=pose_quats_loss_weight, + pose_trans_loss_weight=pose_trans_loss_weight, + compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss, + convert_predictions_to_view0_frame=convert_predictions_to_view0_frame, + compute_world_frame_points_loss=compute_world_frame_points_loss, + world_frame_points_loss_weight=world_frame_points_loss_weight, + ) + self.apply_normal_and_gm_loss_to_synthetic_data_only = ( + apply_normal_and_gm_loss_to_synthetic_data_only + ) + self.normal_loss_weight = normal_loss_weight + self.gm_loss_weight = gm_loss_weight + + def compute_loss(self, batch, preds, **kw): + gt_info, pred_info, valid_masks, ambiguous_masks = self.get_all_info( + batch, preds, **kw + ) + n_views = len(batch) + + # Mask out samples in the batch where the gt depth validity mask is entirely zero + valid_norm_factor_masks = [ + mask.sum(dim=(1, 2)) > 0 for mask in valid_masks + ] # List of (B,) + + if self.ambiguous_loss_value > 0: + assert self.criterion.reduction == "none", ( + "ambiguous_loss_value should be 0 if no conf loss" + ) + # Add the ambiguous pixel as "valid" pixels... + valid_masks = [ + mask | ambig_mask + for mask, ambig_mask in zip(valid_masks, ambiguous_masks) + ] + + normal_losses = [] + gradient_matching_losses = [] + pose_trans_losses = [] + pose_quats_losses = [] + ray_directions_losses = [] + depth_losses = [] + cam_pts3d_losses = [] + if self.compute_world_frame_points_loss: + pts3d_losses = [] + + for i in range(n_views): + # Get the camera frame points, log space depth_z & valid masks + pred_local_pts3d = pred_info[i]["pts3d_cam"] + pred_depth_z = pred_local_pts3d[..., 2:] + pred_depth_z = apply_log_to_norm(pred_depth_z) + gt_local_pts3d = gt_info[i]["pts3d_cam"] + gt_depth_z = gt_local_pts3d[..., 2:] + gt_depth_z = apply_log_to_norm(gt_depth_z) + valid_mask_for_normal_gm_loss = valid_masks[i].clone() + + # Update the validity mask for normal & gm loss based on the synthetic data mask if required + if self.apply_normal_and_gm_loss_to_synthetic_data_only: + synthetic_mask = batch[i]["is_synthetic"] # (B, ) + synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1) + synthetic_mask = synthetic_mask.expand( + -1, pred_depth_z.shape[1], pred_depth_z.shape[2] + ) # (B, H, W) + valid_mask_for_normal_gm_loss = ( + valid_mask_for_normal_gm_loss & synthetic_mask + ) + + # Compute the normal loss + normal_loss = compute_normal_loss( + pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone() + ) + normal_loss = normal_loss * self.normal_loss_weight + normal_losses.append(normal_loss) + + # Compute the gradient matching loss + gradient_matching_loss = compute_gradient_matching_loss( + pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone() + ) + gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight + gradient_matching_losses.append(gradient_matching_loss) + + # Get the predicted dense quantities + if not self.flatten_across_image_only: + # Flatten the points across the entire batch with the masks + pred_ray_directions = 
pred_info[i]["ray_directions"] + gt_ray_directions = gt_info[i]["ray_directions"] + pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]] + gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]] + pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]] + gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]] + if self.compute_world_frame_points_loss: + pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] + gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W + batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape + gt_ray_directions = gt_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + pred_ray_directions = pred_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1] + gt_depth = gt_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + pred_depth = pred_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1] + gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim) + pred_cam_pts3d = pred_info[i]["pts3d_cam"].view( + batch_size, -1, cam_pts_dim + ) + if self.compute_world_frame_points_loss: + pts_dim = gt_info[i]["pts3d"].shape[-1] + gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) + pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space for depth if specified + if self.loss_in_log: + gt_depth = apply_log_to_norm(gt_depth) + pred_depth = apply_log_to_norm(pred_depth) + gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d) + pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d) + if self.compute_world_frame_points_loss: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_pts3d = apply_log_to_norm(pred_pts3d) + + if self.compute_pairwise_relative_pose_loss: + # Get the inverse of current view predicted pose + pred_inv_curr_view_pose_quats = quaternion_inverse( + pred_info[i]["pose_quats"] + ) + pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + pred_inv_curr_view_pose_quats + ) + pred_inv_curr_view_pose_trans = -1 * ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the inverse of the current view GT pose + gt_inv_curr_view_pose_quats = quaternion_inverse( + gt_info[i]["pose_quats"] + ) + gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + gt_inv_curr_view_pose_quats + ) + gt_inv_curr_view_pose_trans = -1 * ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the other N-1 relative poses using the current pose as reference frame + pred_rel_pose_quats = [] + pred_rel_pose_trans = [] + gt_rel_pose_quats = [] + gt_rel_pose_trans = [] + for ov_idx in range(n_views): + if ov_idx == i: + continue + # Get the relative predicted pose + pred_ov_rel_pose_quats = quaternion_multiply( + pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"] + ) + pred_ov_rel_pose_trans = ( + ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + pred_inv_curr_view_pose_trans + ) + + # Get the relative GT pose + gt_ov_rel_pose_quats = quaternion_multiply( + gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] + ) + gt_ov_rel_pose_trans = ( + ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", 
+ ) + + gt_inv_curr_view_pose_trans + ) + + # Get the valid translations using valid_norm_factor_masks for current view and other view + overall_valid_mask_for_trans = ( + valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] + ) + + # Append the relative poses + pred_rel_pose_quats.append(pred_ov_rel_pose_quats) + pred_rel_pose_trans.append( + pred_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + gt_rel_pose_quats.append(gt_ov_rel_pose_quats) + gt_rel_pose_trans.append( + gt_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + + # Cat the N-1 relative poses along the batch dimension + pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) + pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) + gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) + gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion( + pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" + ), + self.criterion( + pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + else: + # Get the pose info for the current view + pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] + gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] + pred_pose_quats = pred_info[i]["pose_quats"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_pose_trans, gt_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), + self.criterion( + pred_pose_quats, -gt_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions, gt_ray_directions, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute camera frame point loss + cam_pts3d_loss = self.criterion( + pred_cam_pts3d, gt_cam_pts3d, factor="points" + ) + cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight + cam_pts3d_losses.append(cam_pts3d_loss) + + if self.compute_world_frame_points_loss: + # Compute point loss + pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") + pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight + pts3d_losses.append(pts3d_loss) + + # Handle ambiguous pixels + if self.ambiguous_loss_value > 0: + if not self.flatten_across_image_only: + depth_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + depth_losses[i], + ) + 
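# Note on the pose-rotation loss above: a unit quaternion q and its negation -q
# encode the same rotation, so the target is only defined up to sign and the loss
# takes the element-wise minimum over both signs. A minimal self-contained check of
# this two-to-one mapping (assuming xyzw ordering, consistent with the identity
# quaternion [0, 0, 0, 1] used in this module; sketch only, not part of the loss):
#
#   import torch
#
#   def quat_xyzw_to_rotmat(q):
#       x, y, z, w = q.unbind(-1)
#       return torch.stack([
#           1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
#           2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
#           2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
#       ], dim=-1).reshape(*q.shape[:-1], 3, 3)
#
#   q = torch.nn.functional.normalize(torch.randn(4), dim=0)
#   assert torch.allclose(quat_xyzw_to_rotmat(q), quat_xyzw_to_rotmat(-q), atol=1e-6)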
cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + pts3d_losses[i], + ) + else: + depth_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + pts3d_losses[i], + ) + + # Use helper function to generate loss terms and details + if self.compute_world_frame_points_loss: + losses_dict = { + "pts3d": { + "values": pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + } + else: + losses_dict = {} + losses_dict.update( + { + "cam_pts3d": { + "values": cam_pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + self.depth_type_for_loss: { + "values": depth_losses, + "use_mask": True, + "is_multi_view": True, + }, + "ray_directions": { + "values": ray_directions_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_quats": { + "values": pose_quats_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_trans": { + "values": pose_trans_losses, + "use_mask": False, + "is_multi_view": True, + }, + "normal": { + "values": normal_losses, + "use_mask": False, + "is_multi_view": True, + }, + "gradient_matching": { + "values": gradient_matching_losses, + "use_mask": False, + "is_multi_view": True, + }, + } + ) + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) + + +class FactoredGeometryScaleRegr3D(Criterion, MultiLoss): + """ + Regression Loss for Factored Geometry & Scale. + """ + + def __init__( + self, + criterion, + norm_predictions=True, + norm_mode="avg_dis", + ambiguous_loss_value=0, + loss_in_log=True, + flatten_across_image_only=False, + depth_type_for_loss="depth_along_ray", + cam_frame_points_loss_weight=1, + depth_loss_weight=1, + ray_directions_loss_weight=1, + pose_quats_loss_weight=1, + pose_trans_loss_weight=1, + scale_loss_weight=1, + compute_pairwise_relative_pose_loss=False, + convert_predictions_to_view0_frame=False, + compute_world_frame_points_loss=True, + world_frame_points_loss_weight=1, + ): + """ + Initialize the loss criterion for Factored Geometry (Ray Directions, Depth, Pose), Scale + and the Collective Geometry i.e. Local Frame Pointmaps & optionally World Frame Pointmaps. + If world-frame pointmap loss is computed, the pixel-level losses are computed in the following order: + (1) world points, (2) cam points, (3) depth, (4) ray directions, (5) pose quats, (6) pose trans, (7) scale. + Else, the pixel-level losses are returned in the following order: + (1) cam points, (2) depth, (3) ray directions, (4) pose quats, (5) pose trans, (6) scale. + The predicited scene representation is always normalized w.r.t. the frame of view0. + Loss is applied between the predicted metric scale and the ground truth metric scale. + + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. 
+ norm_predictions (bool): If True, normalize the predictions before computing the loss. + norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". + ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. + If 0, ambiguous pixels are ignored. Default: 0. + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth, pointmaps and scale. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". + Options: "depth_along_ray", "depth_z" + cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1. + depth_loss_weight (float): Weight to use for the depth loss. Default: 1. + ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. + pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. + pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + scale_loss_weight (float): Weight to use for the scale loss. Default: 1. + compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the + exhaustive pairwise relative poses. Default: False. + convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. + Use this if the predictions are not already in the view0 frame. Default: False. + compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True. + world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. + """ + super().__init__(criterion) + self.norm_predictions = norm_predictions + self.norm_mode = norm_mode + self.ambiguous_loss_value = ambiguous_loss_value + self.loss_in_log = loss_in_log + self.flatten_across_image_only = flatten_across_image_only + self.depth_type_for_loss = depth_type_for_loss + assert self.depth_type_for_loss in [ + "depth_along_ray", + "depth_z", + ], "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" + self.cam_frame_points_loss_weight = cam_frame_points_loss_weight + self.depth_loss_weight = depth_loss_weight + self.ray_directions_loss_weight = ray_directions_loss_weight + self.pose_quats_loss_weight = pose_quats_loss_weight + self.pose_trans_loss_weight = pose_trans_loss_weight + self.scale_loss_weight = scale_loss_weight + self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss + self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame + self.compute_world_frame_points_loss = compute_world_frame_points_loss + self.world_frame_points_loss_weight = world_frame_points_loss_weight + + def get_all_info(self, batch, preds, dist_clip=None): + """ + Function to get all the information needed to compute the loss. + Returns all quantities normalized w.r.t. camera of view0. + """ + n_views = len(batch) + + # Everything is normalized w.r.t. 
camera of view0 + # Intialize lists to store data for all views + # Ground truth quantities + in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) + no_norm_gt_pts = [] + no_norm_gt_pts_cam = [] + no_norm_gt_depth = [] + no_norm_gt_pose_trans = [] + valid_masks = [] + gt_ray_directions = [] + gt_pose_quats = [] + # Predicted quantities + if self.convert_predictions_to_view0_frame: + # Get the camera transform to convert quantities to view0 frame + pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze( + 0 + ) + batch_size = preds[0]["cam_quats"].shape[0] + pred_camera0 = pred_camera0.repeat(batch_size, 1, 1) + pred_camera0_rot = quaternion_to_rotation_matrix( + preds[0]["cam_quats"].clone() + ) + pred_camera0[..., :3, :3] = pred_camera0_rot + pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone() + pred_in_camera0 = closed_form_pose_inverse(pred_camera0) + no_norm_pr_pts = [] + no_norm_pr_pts_cam = [] + no_norm_pr_depth = [] + no_norm_pr_pose_trans = [] + pr_ray_directions = [] + pr_pose_quats = [] + metric_pr_pts_to_compute_scale = [] + + # Get ground truth & prediction info for all views + for i in range(n_views): + # Get the ground truth + no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"])) + valid_masks.append(batch[i]["valid_mask"].clone()) + no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"]) + gt_ray_directions.append(batch[i]["ray_directions_cam"]) + if self.depth_type_for_loss == "depth_along_ray": + no_norm_gt_depth.append(batch[i]["depth_along_ray"]) + elif self.depth_type_for_loss == "depth_z": + no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:]) + if i == 0: + # For view0, initialize identity pose + gt_pose_quats.append( + torch.tensor( + [0, 0, 0, 1], + dtype=gt_ray_directions[0].dtype, + device=gt_ray_directions[0].device, + ) + .unsqueeze(0) + .repeat(gt_ray_directions[0].shape[0], 1) + ) + no_norm_gt_pose_trans.append( + torch.tensor( + [0, 0, 0], + dtype=gt_ray_directions[0].dtype, + device=gt_ray_directions[0].device, + ) + .unsqueeze(0) + .repeat(gt_ray_directions[0].shape[0], 1) + ) + else: + # For other views, transform pose to view0's frame + gt_pose_quats_world = batch[i]["camera_pose_quats"] + no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"] + gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = ( + transform_pose_using_quats_and_trans_2_to_1( + batch[0]["camera_pose_quats"], + batch[0]["camera_pose_trans"], + gt_pose_quats_world, + no_norm_gt_pose_trans_world, + ) + ) + gt_pose_quats.append(gt_pose_quats_in_view0) + no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0) + + # Get the global predictions in view0's frame + if self.convert_predictions_to_view0_frame: + # Convert predictions to view0 frame + pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"]) + pr_pose_quats_in_view0, pr_pose_trans_in_view0 = ( + transform_pose_using_quats_and_trans_2_to_1( + preds[0]["cam_quats"], + preds[0]["cam_trans"], + preds[i]["cam_quats"], + preds[i]["cam_trans"], + ) + ) + else: + # Predictions are already in view0 frame + pr_pts3d_in_view0 = preds[i]["pts3d"] + pr_pose_trans_in_view0 = preds[i]["cam_trans"] + pr_pose_quats_in_view0 = preds[i]["cam_quats"] + + # Get predictions for normalized loss + if self.depth_type_for_loss == "depth_along_ray": + curr_view_no_norm_depth = preds[i]["depth_along_ray"] + elif self.depth_type_for_loss == "depth_z": + curr_view_no_norm_depth = preds[i]["pts3d_cam"][..., 2:] + if "metric_scaling_factor" in preds[i].keys(): + # Divide by the predicted metric scaling 
factor to get the raw predicted points, depth_along_ray, and pose_trans + # This detaches the predicted metric scaling factor from the geometry based loss + curr_view_no_norm_pr_pts = pr_pts3d_in_view0 / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + curr_view_no_norm_depth = curr_view_no_norm_depth / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + curr_view_no_norm_pr_pose_trans = ( + pr_pose_trans_in_view0 / preds[i]["metric_scaling_factor"] + ) + else: + curr_view_no_norm_pr_pts = pr_pts3d_in_view0 + curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] + curr_view_no_norm_depth = curr_view_no_norm_depth + curr_view_no_norm_pr_pose_trans = pr_pose_trans_in_view0 + no_norm_pr_pts.append(curr_view_no_norm_pr_pts) + no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam) + no_norm_pr_depth.append(curr_view_no_norm_depth) + no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans) + pr_ray_directions.append(preds[i]["ray_directions"]) + pr_pose_quats.append(pr_pose_quats_in_view0) + + # Get the predicted metric scale points + if "metric_scaling_factor" in preds[i].keys(): + # Detach the raw predicted points so that the scale loss is only applied to the scaling factor + curr_view_metric_pr_pts_to_compute_scale = ( + curr_view_no_norm_pr_pts.detach() + * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1) + ) + else: + curr_view_metric_pr_pts_to_compute_scale = ( + curr_view_no_norm_pr_pts.clone() + ) + metric_pr_pts_to_compute_scale.append( + curr_view_metric_pr_pts_to_compute_scale + ) + + if dist_clip is not None: + # Points that are too far-away == invalid + for i in range(n_views): + dis = no_norm_gt_pts[i].norm(dim=-1) + valid_masks[i] = valid_masks[i] & (dis <= dist_clip) + + # Initialize normalized tensors + gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] + gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam] + gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth] + gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans] + + pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] + pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam] + pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth] + pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans] + + # Normalize the predicted points if specified + if self.norm_predictions: + pr_normalization_output = normalize_multiple_pointclouds( + no_norm_pr_pts, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_pts_norm = pr_normalization_output[:-1] + pr_norm_factor = pr_normalization_output[-1] + + # Normalize the ground truth points + gt_normalization_output = normalize_multiple_pointclouds( + no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True + ) + gt_pts_norm = gt_normalization_output[:-1] + gt_norm_factor = gt_normalization_output[-1] + + for i in range(n_views): + if self.norm_predictions: + # Assign the normalized predictions + pr_pts[i] = pr_pts_norm[i] + pr_pts_cam[i] = no_norm_pr_pts_cam[i] / pr_norm_factor + pr_depth[i] = no_norm_pr_depth[i] / pr_norm_factor + pr_pose_trans[i] = no_norm_pr_pose_trans[i] / pr_norm_factor[:, :, 0, 0] + else: + pr_pts[i] = no_norm_pr_pts[i] + pr_pts_cam[i] = no_norm_pr_pts_cam[i] + pr_depth[i] = no_norm_pr_depth[i] + pr_pose_trans[i] = no_norm_pr_pose_trans[i] + # Assign the 
normalized ground truth quantities + gt_pts[i] = gt_pts_norm[i] + gt_pts_cam[i] = no_norm_gt_pts_cam[i] / gt_norm_factor + gt_depth[i] = no_norm_gt_depth[i] / gt_norm_factor + gt_pose_trans[i] = no_norm_gt_pose_trans[i] / gt_norm_factor[:, :, 0, 0] + + # Get the mask indicating ground truth metric scale quantities + metric_scale_mask = batch[0]["is_metric_scale"] + valid_gt_norm_factor_mask = ( + gt_norm_factor[:, 0, 0, 0] > 1e-8 + ) # Mask out cases where depth for all views is invalid + valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask + + if valid_metric_scale_mask.any(): + # Compute the scale norm factor using the predicted metric scale points + metric_pr_normalization_output = normalize_multiple_pointclouds( + metric_pr_pts_to_compute_scale, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_metric_norm_factor = metric_pr_normalization_output[-1] + + # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities + gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask] + pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask] + else: + gt_metric_norm_factor = None + pr_metric_norm_factor = None + + # Get ambiguous masks + ambiguous_masks = [] + for i in range(n_views): + ambiguous_masks.append( + (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i]) + ) + + # Pack into info dicts + gt_info = [] + pred_info = [] + for i in range(n_views): + gt_info.append( + { + "ray_directions": gt_ray_directions[i], + self.depth_type_for_loss: gt_depth[i], + "pose_trans": gt_pose_trans[i], + "pose_quats": gt_pose_quats[i], + "pts3d": gt_pts[i], + "pts3d_cam": gt_pts_cam[i], + } + ) + pred_info.append( + { + "ray_directions": pr_ray_directions[i], + self.depth_type_for_loss: pr_depth[i], + "pose_trans": pr_pose_trans[i], + "pose_quats": pr_pose_quats[i], + "pts3d": pr_pts[i], + "pts3d_cam": pr_pts_cam[i], + } + ) + + return ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) + + def compute_loss(self, batch, preds, **kw): + ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) = self.get_all_info(batch, preds, **kw) + n_views = len(batch) + + # Mask out samples in the batch where the gt depth validity mask is entirely zero + valid_norm_factor_masks = [ + mask.sum(dim=(1, 2)) > 0 for mask in valid_masks + ] # List of (B,) + + if self.ambiguous_loss_value > 0: + assert self.criterion.reduction == "none", ( + "ambiguous_loss_value should be 0 if no conf loss" + ) + # Add the ambiguous pixel as "valid" pixels... 
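# With a non-zero ambiguous_loss_value, ambiguous pixels are merged into the valid
# masks here so that they survive the masked flattening below; their per-pixel losses
# are then overwritten with the constant ambiguous_loss_value via torch.where in the
# "Handle ambiguous pixels" block. This relies on the criterion using
# reduction="none" (checked by the assert above) so per-pixel values are available.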
+ valid_masks = [ + mask | ambig_mask + for mask, ambig_mask in zip(valid_masks, ambiguous_masks) + ] + + pose_trans_losses = [] + pose_quats_losses = [] + ray_directions_losses = [] + depth_losses = [] + cam_pts3d_losses = [] + if self.compute_world_frame_points_loss: + pts3d_losses = [] + + for i in range(n_views): + # Get the predicted dense quantities + if not self.flatten_across_image_only: + # Flatten the points across the entire batch with the masks + pred_ray_directions = pred_info[i]["ray_directions"] + gt_ray_directions = gt_info[i]["ray_directions"] + pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]] + gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]] + pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]] + gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]] + if self.compute_world_frame_points_loss: + pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] + gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W + batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape + gt_ray_directions = gt_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + pred_ray_directions = pred_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1] + gt_depth = gt_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + pred_depth = pred_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1] + gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim) + pred_cam_pts3d = pred_info[i]["pts3d_cam"].view( + batch_size, -1, cam_pts_dim + ) + if self.compute_world_frame_points_loss: + pts_dim = gt_info[i]["pts3d"].shape[-1] + gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) + pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space for depth if specified + if self.loss_in_log: + gt_depth = apply_log_to_norm(gt_depth) + pred_depth = apply_log_to_norm(pred_depth) + gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d) + pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d) + if self.compute_world_frame_points_loss: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_pts3d = apply_log_to_norm(pred_pts3d) + + if self.compute_pairwise_relative_pose_loss: + # Get the inverse of current view predicted pose + pred_inv_curr_view_pose_quats = quaternion_inverse( + pred_info[i]["pose_quats"] + ) + pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + pred_inv_curr_view_pose_quats + ) + pred_inv_curr_view_pose_trans = -1 * ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the inverse of the current view GT pose + gt_inv_curr_view_pose_quats = quaternion_inverse( + gt_info[i]["pose_quats"] + ) + gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + gt_inv_curr_view_pose_quats + ) + gt_inv_curr_view_pose_trans = -1 * ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the other N-1 relative poses using the current pose as reference frame + pred_rel_pose_quats = [] + pred_rel_pose_trans = [] + gt_rel_pose_quats = [] + gt_rel_pose_trans = [] + for ov_idx in range(n_views): + if ov_idx == i: + continue + # Get the relative predicted pose + pred_ov_rel_pose_quats = quaternion_multiply( + pred_inv_curr_view_pose_quats, 
pred_info[ov_idx]["pose_quats"] + ) + pred_ov_rel_pose_trans = ( + ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + pred_inv_curr_view_pose_trans + ) + + # Get the relative GT pose + gt_ov_rel_pose_quats = quaternion_multiply( + gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] + ) + gt_ov_rel_pose_trans = ( + ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + gt_inv_curr_view_pose_trans + ) + + # Get the valid translations using valid_norm_factor_masks for current view and other view + overall_valid_mask_for_trans = ( + valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] + ) + + # Append the relative poses + pred_rel_pose_quats.append(pred_ov_rel_pose_quats) + pred_rel_pose_trans.append( + pred_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + gt_rel_pose_quats.append(gt_ov_rel_pose_quats) + gt_rel_pose_trans.append( + gt_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + + # Cat the N-1 relative poses along the batch dimension + pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) + pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) + gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) + gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion( + pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" + ), + self.criterion( + pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + else: + # Get the pose info for the current view + pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] + gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] + pred_pose_quats = pred_info[i]["pose_quats"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_pose_trans, gt_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), + self.criterion( + pred_pose_quats, -gt_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions, gt_ray_directions, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute camera frame point loss + cam_pts3d_loss = self.criterion( + pred_cam_pts3d, gt_cam_pts3d, factor="points" + ) + cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight + 
cam_pts3d_losses.append(cam_pts3d_loss) + + if self.compute_world_frame_points_loss: + # Compute point loss + pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") + pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight + pts3d_losses.append(pts3d_loss) + + # Handle ambiguous pixels + if self.ambiguous_loss_value > 0: + if not self.flatten_across_image_only: + depth_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + pts3d_losses[i], + ) + else: + depth_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + pts3d_losses[i], + ) + + # Compute the scale loss + if gt_metric_norm_factor is not None: + if self.loss_in_log: + gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) + pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) + scale_loss = ( + self.criterion( + pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" + ) + * self.scale_loss_weight + ) + else: + scale_loss = None + + # Use helper function to generate loss terms and details + if self.compute_world_frame_points_loss: + losses_dict = { + "pts3d": { + "values": pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + } + else: + losses_dict = {} + losses_dict.update( + { + "cam_pts3d": { + "values": cam_pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + self.depth_type_for_loss: { + "values": depth_losses, + "use_mask": True, + "is_multi_view": True, + }, + "ray_directions": { + "values": ray_directions_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_quats": { + "values": pose_quats_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_trans": { + "values": pose_trans_losses, + "use_mask": False, + "is_multi_view": True, + }, + "scale": { + "values": scale_loss, + "use_mask": False, + "is_multi_view": False, + }, + } + ) + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) + + +class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D): + """ + Regression, Normals & Gradient Matching Loss for Factored Geometry & Scale. 
+ """ + + def __init__( + self, + criterion, + norm_predictions=True, + norm_mode="avg_dis", + ambiguous_loss_value=0, + loss_in_log=True, + flatten_across_image_only=False, + depth_type_for_loss="depth_along_ray", + cam_frame_points_loss_weight=1, + depth_loss_weight=1, + ray_directions_loss_weight=1, + pose_quats_loss_weight=1, + pose_trans_loss_weight=1, + scale_loss_weight=1, + compute_pairwise_relative_pose_loss=False, + convert_predictions_to_view0_frame=False, + compute_world_frame_points_loss=True, + world_frame_points_loss_weight=1, + apply_normal_and_gm_loss_to_synthetic_data_only=True, + normal_loss_weight=1, + gm_loss_weight=1, + ): + """ + Initialize the loss criterion for Ray Directions, Depth, Pose, Pointmaps & Scale. + Additionally computes: + (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates, + (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space) + + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. + norm_predictions (bool): If True, normalize the predictions before computing the loss. + norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". + ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. + If 0, ambiguous pixels are ignored. Default: 0. + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth, pointmaps and scale. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". + Options: "depth_along_ray", "depth_z" + cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1. + depth_loss_weight (float): Weight to use for the depth loss. Default: 1. + ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. + pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. + pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + scale_loss_weight (float): Weight to use for the scale loss. Default: 1. + compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the + exhaustive pairwise relative poses. Default: False. + convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. + Use this if the predictions are not already in the view0 frame. Default: False. + compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True. + world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. + apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. + If False, apply the normal and gm loss to all data. Default: True. + normal_loss_weight (float): Weight to use for the normal loss. Default: 1. + gm_loss_weight (float): Weight to use for the gm loss. Default: 1. 
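Note:
    When apply_normal_and_gm_loss_to_synthetic_data_only is True, the per-sample
    `is_synthetic` flag is broadcast over the image and intersected with the pixel
    validity mask before the normal and gradient-matching terms are computed,
    roughly (sketch, with H and W the image height and width):

        synthetic_mask = batch[i]["is_synthetic"]                    # (B,)
        synthetic_mask = synthetic_mask[:, None, None].expand(-1, H, W)
        normal_gm_mask = valid_mask & synthetic_mask                 # (B, H, W)

    so real-world captures contribute only through the regression terms, where the
    ground-truth depth is typically noisier.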
+ """ + super().__init__( + criterion=criterion, + norm_predictions=norm_predictions, + norm_mode=norm_mode, + ambiguous_loss_value=ambiguous_loss_value, + loss_in_log=loss_in_log, + flatten_across_image_only=flatten_across_image_only, + depth_type_for_loss=depth_type_for_loss, + cam_frame_points_loss_weight=cam_frame_points_loss_weight, + depth_loss_weight=depth_loss_weight, + ray_directions_loss_weight=ray_directions_loss_weight, + pose_quats_loss_weight=pose_quats_loss_weight, + pose_trans_loss_weight=pose_trans_loss_weight, + scale_loss_weight=scale_loss_weight, + compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss, + convert_predictions_to_view0_frame=convert_predictions_to_view0_frame, + compute_world_frame_points_loss=compute_world_frame_points_loss, + world_frame_points_loss_weight=world_frame_points_loss_weight, + ) + self.apply_normal_and_gm_loss_to_synthetic_data_only = ( + apply_normal_and_gm_loss_to_synthetic_data_only + ) + self.normal_loss_weight = normal_loss_weight + self.gm_loss_weight = gm_loss_weight + + def compute_loss(self, batch, preds, **kw): + ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) = self.get_all_info(batch, preds, **kw) + n_views = len(batch) + + # Mask out samples in the batch where the gt depth validity mask is entirely zero + valid_norm_factor_masks = [ + mask.sum(dim=(1, 2)) > 0 for mask in valid_masks + ] # List of (B,) + + if self.ambiguous_loss_value > 0: + assert self.criterion.reduction == "none", ( + "ambiguous_loss_value should be 0 if no conf loss" + ) + # Add the ambiguous pixel as "valid" pixels... + valid_masks = [ + mask | ambig_mask + for mask, ambig_mask in zip(valid_masks, ambiguous_masks) + ] + + normal_losses = [] + gradient_matching_losses = [] + pose_trans_losses = [] + pose_quats_losses = [] + ray_directions_losses = [] + depth_losses = [] + cam_pts3d_losses = [] + if self.compute_world_frame_points_loss: + pts3d_losses = [] + + for i in range(n_views): + # Get the camera frame points, log space depth_z & valid masks + pred_local_pts3d = pred_info[i]["pts3d_cam"] + pred_depth_z = pred_local_pts3d[..., 2:] + pred_depth_z = apply_log_to_norm(pred_depth_z) + gt_local_pts3d = gt_info[i]["pts3d_cam"] + gt_depth_z = gt_local_pts3d[..., 2:] + gt_depth_z = apply_log_to_norm(gt_depth_z) + valid_mask_for_normal_gm_loss = valid_masks[i].clone() + + # Update the validity mask for normal & gm loss based on the synthetic data mask if required + if self.apply_normal_and_gm_loss_to_synthetic_data_only: + synthetic_mask = batch[i]["is_synthetic"] # (B, ) + synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1) + synthetic_mask = synthetic_mask.expand( + -1, pred_depth_z.shape[1], pred_depth_z.shape[2] + ) # (B, H, W) + valid_mask_for_normal_gm_loss = ( + valid_mask_for_normal_gm_loss & synthetic_mask + ) + + # Compute the normal loss + normal_loss = compute_normal_loss( + pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone() + ) + normal_loss = normal_loss * self.normal_loss_weight + normal_losses.append(normal_loss) + + # Compute the gradient matching loss + gradient_matching_loss = compute_gradient_matching_loss( + pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone() + ) + gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight + gradient_matching_losses.append(gradient_matching_loss) + + # Get the predicted dense quantities + if not self.flatten_across_image_only: + # Flatten the points across the 
entire batch with the masks and compute the metrics + pred_ray_directions = pred_info[i]["ray_directions"] + gt_ray_directions = gt_info[i]["ray_directions"] + pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]] + gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]] + pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]] + gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]] + if self.compute_world_frame_points_loss: + pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] + gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W and compute the metrics + batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape + gt_ray_directions = gt_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + pred_ray_directions = pred_info[i]["ray_directions"].view( + batch_size, -1, direction_dim + ) + depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1] + gt_depth = gt_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + pred_depth = pred_info[i][self.depth_type_for_loss].view( + batch_size, -1, depth_dim + ) + cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1] + gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim) + pred_cam_pts3d = pred_info[i]["pts3d_cam"].view( + batch_size, -1, cam_pts_dim + ) + if self.compute_world_frame_points_loss: + pts_dim = gt_info[i]["pts3d"].shape[-1] + gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) + pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space for depth if specified + if self.loss_in_log: + gt_depth = apply_log_to_norm(gt_depth) + pred_depth = apply_log_to_norm(pred_depth) + gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d) + pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d) + if self.compute_world_frame_points_loss: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_pts3d = apply_log_to_norm(pred_pts3d) + + if self.compute_pairwise_relative_pose_loss: + # Get the inverse of current view predicted pose + pred_inv_curr_view_pose_quats = quaternion_inverse( + pred_info[i]["pose_quats"] + ) + pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + pred_inv_curr_view_pose_quats + ) + pred_inv_curr_view_pose_trans = -1 * ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the inverse of the current view GT pose + gt_inv_curr_view_pose_quats = quaternion_inverse( + gt_info[i]["pose_quats"] + ) + gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + gt_inv_curr_view_pose_quats + ) + gt_inv_curr_view_pose_trans = -1 * ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the other N-1 relative poses using the current pose as reference frame + pred_rel_pose_quats = [] + pred_rel_pose_trans = [] + gt_rel_pose_quats = [] + gt_rel_pose_trans = [] + for ov_idx in range(n_views): + if ov_idx == i: + continue + # Get the relative predicted pose + pred_ov_rel_pose_quats = quaternion_multiply( + pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"] + ) + pred_ov_rel_pose_trans = ( + ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + pred_inv_curr_view_pose_trans + ) + + # Get the relative GT pose + gt_ov_rel_pose_quats = quaternion_multiply( + gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] + ) + gt_ov_rel_pose_trans = ( + 
ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + gt_inv_curr_view_pose_trans + ) + + # Get the valid translations using valid_norm_factor_masks for current view and other view + overall_valid_mask_for_trans = ( + valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] + ) + + # Append the relative poses + pred_rel_pose_quats.append(pred_ov_rel_pose_quats) + pred_rel_pose_trans.append( + pred_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + gt_rel_pose_quats.append(gt_ov_rel_pose_quats) + gt_rel_pose_trans.append( + gt_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + + # Cat the N-1 relative poses along the batch dimension + pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) + pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) + gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) + gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion( + pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" + ), + self.criterion( + pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + else: + # Get the pose info for the current view + pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] + gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] + pred_pose_quats = pred_info[i]["pose_quats"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_pose_trans, gt_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), + self.criterion( + pred_pose_quats, -gt_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions, gt_ray_directions, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute camera frame point loss + cam_pts3d_loss = self.criterion( + pred_cam_pts3d, gt_cam_pts3d, factor="points" + ) + cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight + cam_pts3d_losses.append(cam_pts3d_loss) + + if self.compute_world_frame_points_loss: + # Compute point loss + pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") + pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight + pts3d_losses.append(pts3d_loss) + + # Handle ambiguous pixels + if self.ambiguous_loss_value > 0: + if not self.flatten_across_image_only: + depth_losses[i] = torch.where( + 
ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + pts3d_losses[i], + ) + else: + depth_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + pts3d_losses[i], + ) + + # Compute the scale loss + if gt_metric_norm_factor is not None: + if self.loss_in_log: + gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) + pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) + scale_loss = ( + self.criterion( + pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" + ) + * self.scale_loss_weight + ) + else: + scale_loss = None + + # Use helper function to generate loss terms and details + if self.compute_world_frame_points_loss: + losses_dict = { + "pts3d": { + "values": pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + } + else: + losses_dict = {} + losses_dict.update( + { + "cam_pts3d": { + "values": cam_pts3d_losses, + "use_mask": True, + "is_multi_view": True, + }, + self.depth_type_for_loss: { + "values": depth_losses, + "use_mask": True, + "is_multi_view": True, + }, + "ray_directions": { + "values": ray_directions_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_quats": { + "values": pose_quats_losses, + "use_mask": False, + "is_multi_view": True, + }, + "pose_trans": { + "values": pose_trans_losses, + "use_mask": False, + "is_multi_view": True, + }, + "scale": { + "values": scale_loss, + "use_mask": False, + "is_multi_view": False, + }, + "normal": { + "values": normal_losses, + "use_mask": False, + "is_multi_view": True, + }, + "gradient_matching": { + "values": gradient_matching_losses, + "use_mask": False, + "is_multi_view": True, + }, + } + ) + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) + + +class DisentangledFactoredGeometryScaleRegr3D(Criterion, MultiLoss): + """ + Disentangled Regression Loss for Factored Geometry & Scale. + """ + + def __init__( + self, + criterion, + norm_predictions=True, + norm_mode="avg_dis", + loss_in_log=True, + flatten_across_image_only=False, + depth_type_for_loss="depth_along_ray", + depth_loss_weight=1, + ray_directions_loss_weight=1, + pose_quats_loss_weight=1, + pose_trans_loss_weight=1, + scale_loss_weight=1, + ): + """ + Initialize the disentangled loss criterion for Factored Geometry (Ray Directions, Depth, Pose) & Scale. + It isolates/disentangles the contribution of each factor to the final task of 3D reconstruction. + All the losses are in the same space where the loss for each factor is computed by constructing world-frame pointmaps. + This sidesteps the difficulty of finding a proper weighting. 
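Schematically (an illustrative sketch only; compose_world_points is a placeholder here
for lifting ray directions and depth to camera-frame points and applying the pose, not
a function defined in this module):

    pts_for_ray_loss   = compose_world_points(pred_rays, gt_depth,   gt_pose)
    pts_for_depth_loss = compose_world_points(gt_rays,   pred_depth, gt_pose)
    pts_for_pose_loss  = compose_world_points(gt_rays,   gt_depth,   pred_pose)

Each of these pointmaps is compared against the GT world-frame pointmap with the same
criterion, so every factor is penalized in the same pointmap units.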
+ For insance, for predicted rays, the GT depth & pose is used to construct the predicted world-frame pointmaps on which the loss is computed. + Inspired by https://openaccess.thecvf.com/content_ICCV_2019/papers/Simonelli_Disentangling_Monocular_3D_Object_Detection_ICCV_2019_paper.pdf + + The pixel-level losses are computed in the following order: + (1) depth, (2) ray directions, (3) pose quats, (4) pose trans, (5) scale. + The predicited scene representation is always normalized w.r.t. the frame of view0. + Loss is applied between the predicted metric scale and the ground truth metric scale. + + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. + norm_predictions (bool): If True, normalize the predictions before computing the loss. + norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth, pointmaps and scale. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". + Options: "depth_along_ray", "depth_z" + depth_loss_weight (float): Weight to use for the depth loss. Default: 1. + ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. + pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. + pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + scale_loss_weight (float): Weight to use for the scale loss. Default: 1. + """ + super().__init__(criterion) + self.norm_predictions = norm_predictions + self.norm_mode = norm_mode + self.loss_in_log = loss_in_log + self.flatten_across_image_only = flatten_across_image_only + self.depth_type_for_loss = depth_type_for_loss + assert self.depth_type_for_loss in [ + "depth_along_ray", + "depth_z", + ], "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" + self.depth_loss_weight = depth_loss_weight + self.ray_directions_loss_weight = ray_directions_loss_weight + self.pose_quats_loss_weight = pose_quats_loss_weight + self.pose_trans_loss_weight = pose_trans_loss_weight + self.scale_loss_weight = scale_loss_weight + + def get_all_info(self, batch, preds, dist_clip=None): + """ + Function to get all the information needed to compute the loss. + Returns all quantities normalized w.r.t. camera of view0. + """ + n_views = len(batch) + + # Everything is normalized w.r.t. 
camera of view0 + # Intialize lists to store data for all views + # Ground truth quantities + in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) + no_norm_gt_pts = [] + no_norm_gt_pts_cam = [] + no_norm_gt_depth = [] + no_norm_gt_pose_trans = [] + valid_masks = [] + gt_ray_directions = [] + gt_pose_quats = [] + # Predicted quantities + no_norm_pr_pts = [] + no_norm_pr_pts_cam = [] + no_norm_pr_depth = [] + no_norm_pr_pose_trans = [] + pr_ray_directions = [] + pr_pose_quats = [] + metric_pr_pts_to_compute_scale = [] + + # Get ground truth & prediction info for all views + for i in range(n_views): + # Get the ground truth + no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"])) + valid_masks.append(batch[i]["valid_mask"].clone()) + no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"]) + gt_ray_directions.append(batch[i]["ray_directions_cam"]) + if self.depth_type_for_loss == "depth_along_ray": + no_norm_gt_depth.append(batch[i]["depth_along_ray"]) + elif self.depth_type_for_loss == "depth_z": + no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:]) + if i == 0: + # For view0, initialize identity pose + gt_pose_quats.append( + torch.tensor( + [0, 0, 0, 1], + dtype=gt_ray_directions[0].dtype, + device=gt_ray_directions[0].device, + ) + .unsqueeze(0) + .repeat(gt_ray_directions[0].shape[0], 1) + ) + no_norm_gt_pose_trans.append( + torch.tensor( + [0, 0, 0], + dtype=gt_ray_directions[0].dtype, + device=gt_ray_directions[0].device, + ) + .unsqueeze(0) + .repeat(gt_ray_directions[0].shape[0], 1) + ) + else: + # For other views, transform pose to view0's frame + gt_pose_quats_world = batch[i]["camera_pose_quats"] + no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"] + gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = ( + transform_pose_using_quats_and_trans_2_to_1( + batch[0]["camera_pose_quats"], + batch[0]["camera_pose_trans"], + gt_pose_quats_world, + no_norm_gt_pose_trans_world, + ) + ) + gt_pose_quats.append(gt_pose_quats_in_view0) + no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0) + + # Get predictions for normalized loss + if self.depth_type_for_loss == "depth_along_ray": + curr_view_no_norm_depth = preds[i]["depth_along_ray"] + elif self.depth_type_for_loss == "depth_z": + curr_view_no_norm_depth = preds[i]["pts3d_cam"][..., 2:] + if "metric_scaling_factor" in preds[i].keys(): + # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans + # This detaches the predicted metric scaling factor from the geometry based loss + curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + curr_view_no_norm_depth = curr_view_no_norm_depth / preds[i][ + "metric_scaling_factor" + ].unsqueeze(-1).unsqueeze(-1) + curr_view_no_norm_pr_pose_trans = ( + preds[i]["cam_trans"] / preds[i]["metric_scaling_factor"] + ) + else: + curr_view_no_norm_pr_pts = preds[i]["pts3d"] + curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] + curr_view_no_norm_depth = curr_view_no_norm_depth + curr_view_no_norm_pr_pose_trans = preds[i]["cam_trans"] + no_norm_pr_pts.append(curr_view_no_norm_pr_pts) + no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam) + no_norm_pr_depth.append(curr_view_no_norm_depth) + no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans) + pr_ray_directions.append(preds[i]["ray_directions"]) + 
pr_pose_quats.append(preds[i]["cam_quats"]) + + # Get the predicted metric scale points + if "metric_scaling_factor" in preds[i].keys(): + # Detach the raw predicted points so that the scale loss is only applied to the scaling factor + curr_view_metric_pr_pts_to_compute_scale = ( + curr_view_no_norm_pr_pts.detach() + * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1) + ) + else: + curr_view_metric_pr_pts_to_compute_scale = ( + curr_view_no_norm_pr_pts.clone() + ) + metric_pr_pts_to_compute_scale.append( + curr_view_metric_pr_pts_to_compute_scale + ) + + if dist_clip is not None: + # Points that are too far-away == invalid + for i in range(n_views): + dis = no_norm_gt_pts[i].norm(dim=-1) + valid_masks[i] = valid_masks[i] & (dis <= dist_clip) + + # Initialize normalized tensors + gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] + gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam] + gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth] + gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans] + + pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] + pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam] + pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth] + pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans] + + # Normalize the predicted points if specified + if self.norm_predictions: + pr_normalization_output = normalize_multiple_pointclouds( + no_norm_pr_pts, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_pts_norm = pr_normalization_output[:-1] + pr_norm_factor = pr_normalization_output[-1] + + # Normalize the ground truth points + gt_normalization_output = normalize_multiple_pointclouds( + no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True + ) + gt_pts_norm = gt_normalization_output[:-1] + gt_norm_factor = gt_normalization_output[-1] + + for i in range(n_views): + if self.norm_predictions: + # Assign the normalized predictions + pr_pts[i] = pr_pts_norm[i] + pr_pts_cam[i] = no_norm_pr_pts_cam[i] / pr_norm_factor + pr_depth[i] = no_norm_pr_depth[i] / pr_norm_factor + pr_pose_trans[i] = no_norm_pr_pose_trans[i] / pr_norm_factor[:, :, 0, 0] + else: + pr_pts[i] = no_norm_pr_pts[i] + pr_pts_cam[i] = no_norm_pr_pts_cam[i] + pr_depth[i] = no_norm_pr_depth[i] + pr_pose_trans[i] = no_norm_pr_pose_trans[i] + # Assign the normalized ground truth quantities + gt_pts[i] = gt_pts_norm[i] + gt_pts_cam[i] = no_norm_gt_pts_cam[i] / gt_norm_factor + gt_depth[i] = no_norm_gt_depth[i] / gt_norm_factor + gt_pose_trans[i] = no_norm_gt_pose_trans[i] / gt_norm_factor[:, :, 0, 0] + + # Get the mask indicating ground truth metric scale quantities + metric_scale_mask = batch[0]["is_metric_scale"] + valid_gt_norm_factor_mask = ( + gt_norm_factor[:, 0, 0, 0] > 1e-8 + ) # Mask out cases where depth for all views is invalid + valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask + + if valid_metric_scale_mask.any(): + # Compute the scale norm factor using the predicted metric scale points + metric_pr_normalization_output = normalize_multiple_pointclouds( + metric_pr_pts_to_compute_scale, + valid_masks, + self.norm_mode, + ret_factor=True, + ) + pr_metric_norm_factor = metric_pr_normalization_output[-1] + + # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities + gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask] + pr_metric_norm_factor = 
pr_metric_norm_factor[valid_metric_scale_mask] + else: + gt_metric_norm_factor = None + pr_metric_norm_factor = None + + # Get ambiguous masks + ambiguous_masks = [] + for i in range(n_views): + ambiguous_masks.append( + (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i]) + ) + + # Pack into info dicts + gt_info = [] + pred_info = [] + for i in range(n_views): + gt_info.append( + { + "ray_directions": gt_ray_directions[i], + self.depth_type_for_loss: gt_depth[i], + "pose_trans": gt_pose_trans[i], + "pose_quats": gt_pose_quats[i], + "pts3d": gt_pts[i], + "pts3d_cam": gt_pts_cam[i], + } + ) + pred_info.append( + { + "ray_directions": pr_ray_directions[i], + self.depth_type_for_loss: pr_depth[i], + "pose_trans": pr_pose_trans[i], + "pose_quats": pr_pose_quats[i], + "pts3d": pr_pts[i], + "pts3d_cam": pr_pts_cam[i], + } + ) + + return ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) + + def compute_loss(self, batch, preds, **kw): + ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) = self.get_all_info(batch, preds, **kw) + n_views = len(batch) + + pose_trans_losses = [] + pose_quats_losses = [] + ray_directions_losses = [] + depth_losses = [] + + for i in range(n_views): + # Get the GT factored quantities for the current view + gt_pts3d = gt_info[i]["pts3d"] + gt_ray_directions = gt_info[i]["ray_directions"] + gt_depth = gt_info[i][self.depth_type_for_loss] + gt_pose_trans = gt_info[i]["pose_trans"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Get the predicted factored quantities for the current view + pred_ray_directions = pred_info[i]["ray_directions"] + pred_depth = pred_info[i][self.depth_type_for_loss] + pred_pose_trans = pred_info[i]["pose_trans"] + pred_pose_quats = pred_info[i]["pose_quats"] + + # Get the predicted world-frame pointmaps using the different factors + if self.depth_type_for_loss == "depth_along_ray": + pred_ray_directions_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + pred_ray_directions, + gt_depth, + gt_pose_trans, + gt_pose_quats, + ) + ) + pred_depth_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + gt_ray_directions, + pred_depth, + gt_pose_trans, + gt_pose_quats, + ) + ) + pred_pose_trans_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + gt_ray_directions, + gt_depth, + pred_pose_trans, + gt_pose_quats, + ) + ) + pred_pose_quats_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + gt_ray_directions, + gt_depth, + gt_pose_trans, + pred_pose_quats, + ) + ) + else: + raise NotImplementedError + + # Mask out the valid quantities as required + if not self.flatten_across_image_only: + # Flatten the points across the entire batch with the masks + pred_ray_directions_pts3d = pred_ray_directions_pts3d[valid_masks[i]] + pred_depth_pts3d = pred_depth_pts3d[valid_masks[i]] + pred_pose_trans_pts3d = pred_pose_trans_pts3d[valid_masks[i]] + pred_pose_quats_pts3d = pred_pose_quats_pts3d[valid_masks[i]] + gt_pts3d = gt_pts3d[valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W + batch_size, _, _, pts_dim = gt_pts3d.shape + pred_ray_directions_pts3d = pred_ray_directions_pts3d.view( + batch_size, -1, pts_dim + ) + pred_depth_pts3d = pred_depth_pts3d.view(batch_size, -1, pts_dim) + pred_pose_trans_pts3d = pred_pose_trans_pts3d.view( + batch_size, -1, pts_dim + ) + pred_pose_quats_pts3d = pred_pose_quats_pts3d.view( + batch_size, -1, 
pts_dim + ) + gt_pts3d = gt_pts3d.view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space if specified + if self.loss_in_log: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_ray_directions_pts3d = apply_log_to_norm(pred_ray_directions_pts3d) + pred_depth_pts3d = apply_log_to_norm(pred_depth_pts3d) + pred_pose_trans_pts3d = apply_log_to_norm(pred_pose_trans_pts3d) + pred_pose_quats_pts3d = apply_log_to_norm(pred_pose_quats_pts3d) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_pose_trans_pts3d, gt_pts3d, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + pose_quats_loss = self.criterion( + pred_pose_quats_pts3d, gt_pts3d, factor="pose_quats" + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions_pts3d, gt_pts3d, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth_pts3d, gt_pts3d, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute the scale loss + if gt_metric_norm_factor is not None: + if self.loss_in_log: + gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) + pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) + scale_loss = ( + self.criterion( + pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" + ) + * self.scale_loss_weight + ) + else: + scale_loss = None + + # Use helper function to generate loss terms and details + losses_dict = {} + losses_dict.update( + { + self.depth_type_for_loss: { + "values": depth_losses, + "use_mask": True, + "is_multi_view": True, + }, + "ray_directions": { + "values": ray_directions_losses, + "use_mask": True, + "is_multi_view": True, + }, + "pose_quats": { + "values": pose_quats_losses, + "use_mask": True, + "is_multi_view": True, + }, + "pose_trans": { + "values": pose_trans_losses, + "use_mask": True, + "is_multi_view": True, + }, + "scale": { + "values": scale_loss, + "use_mask": False, + "is_multi_view": False, + }, + } + ) + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) + + +class DisentangledFactoredGeometryScaleRegr3DPlusNormalGMLoss( + DisentangledFactoredGeometryScaleRegr3D +): + """ + Disentangled Regression, Normals & Gradient Matching Loss for Factored Geometry & Scale. + """ + + def __init__( + self, + criterion, + norm_predictions=True, + norm_mode="avg_dis", + loss_in_log=True, + flatten_across_image_only=False, + depth_type_for_loss="depth_along_ray", + depth_loss_weight=1, + ray_directions_loss_weight=1, + pose_quats_loss_weight=1, + pose_trans_loss_weight=1, + scale_loss_weight=1, + apply_normal_and_gm_loss_to_synthetic_data_only=True, + normal_loss_weight=1, + gm_loss_weight=1, + ): + """ + Initialize the disentangled loss criterion for Factored Geometry (Ray Directions, Depth, Pose) & Scale. + See parent class (DisentangledFactoredGeometryScaleRegr3D) for more details. 
+ Additionally computes: + (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates, + (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space) + + Args: + criterion (BaseCriterion): The base criterion to use for computing the loss. + norm_predictions (bool): If True, normalize the predictions before computing the loss. + norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". + loss_in_log (bool): If True, apply logarithmic transformation to input before + computing the loss for depth, pointmaps and scale. Default: True. + flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing + the loss. If False, flatten across batch and spatial dimensions. Default: False. + depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". + Options: "depth_along_ray", "depth_z" + depth_loss_weight (float): Weight to use for the depth loss. Default: 1. + ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. + pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. + pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + scale_loss_weight (float): Weight to use for the scale loss. Default: 1. + apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. + If False, apply the normal and gm loss to all data. Default: True. + normal_loss_weight (float): Weight to use for the normal loss. Default: 1. + gm_loss_weight (float): Weight to use for the gm loss. Default: 1. + """ + super().__init__( + criterion=criterion, + norm_predictions=norm_predictions, + norm_mode=norm_mode, + loss_in_log=loss_in_log, + flatten_across_image_only=flatten_across_image_only, + depth_type_for_loss=depth_type_for_loss, + depth_loss_weight=depth_loss_weight, + ray_directions_loss_weight=ray_directions_loss_weight, + pose_quats_loss_weight=pose_quats_loss_weight, + pose_trans_loss_weight=pose_trans_loss_weight, + scale_loss_weight=scale_loss_weight, + ) + self.apply_normal_and_gm_loss_to_synthetic_data_only = ( + apply_normal_and_gm_loss_to_synthetic_data_only + ) + self.normal_loss_weight = normal_loss_weight + self.gm_loss_weight = gm_loss_weight + + def compute_loss(self, batch, preds, **kw): + ( + gt_info, + pred_info, + valid_masks, + ambiguous_masks, + gt_metric_norm_factor, + pr_metric_norm_factor, + ) = self.get_all_info(batch, preds, **kw) + n_views = len(batch) + + normal_losses = [] + gradient_matching_losses = [] + pose_trans_losses = [] + pose_quats_losses = [] + ray_directions_losses = [] + depth_losses = [] + + for i in range(n_views): + # Get the camera frame points, log space depth_z & valid masks + pred_local_pts3d = pred_info[i]["pts3d_cam"] + pred_depth_z = pred_local_pts3d[..., 2:] + pred_depth_z = apply_log_to_norm(pred_depth_z) + gt_local_pts3d = gt_info[i]["pts3d_cam"] + gt_depth_z = gt_local_pts3d[..., 2:] + gt_depth_z = apply_log_to_norm(gt_depth_z) + valid_mask_for_normal_gm_loss = valid_masks[i].clone() + + # Update the validity mask for normal & gm loss based on the synthetic data mask if required + if self.apply_normal_and_gm_loss_to_synthetic_data_only: + synthetic_mask = batch[i]["is_synthetic"] # (B, ) + synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1) + synthetic_mask = synthetic_mask.expand( + -1, pred_depth_z.shape[1], pred_depth_z.shape[2] 
+ ) # (B, H, W) + valid_mask_for_normal_gm_loss = ( + valid_mask_for_normal_gm_loss & synthetic_mask + ) + + # Compute the normal loss + normal_loss = compute_normal_loss( + pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone() + ) + normal_loss = normal_loss * self.normal_loss_weight + normal_losses.append(normal_loss) + + # Compute the gradient matching loss + gradient_matching_loss = compute_gradient_matching_loss( + pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone() + ) + gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight + gradient_matching_losses.append(gradient_matching_loss) + + # Get the GT factored quantities for the current view + gt_pts3d = gt_info[i]["pts3d"] + gt_ray_directions = gt_info[i]["ray_directions"] + gt_depth = gt_info[i][self.depth_type_for_loss] + gt_pose_trans = gt_info[i]["pose_trans"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Get the predicted factored quantities for the current view + pred_ray_directions = pred_info[i]["ray_directions"] + pred_depth = pred_info[i][self.depth_type_for_loss] + pred_pose_trans = pred_info[i]["pose_trans"] + pred_pose_quats = pred_info[i]["pose_quats"] + + # Get the predicted world-frame pointmaps using the different factors + if self.depth_type_for_loss == "depth_along_ray": + pred_ray_directions_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + pred_ray_directions, + gt_depth, + gt_pose_trans, + gt_pose_quats, + ) + ) + pred_depth_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + gt_ray_directions, + pred_depth, + gt_pose_trans, + gt_pose_quats, + ) + ) + pred_pose_trans_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + gt_ray_directions, + gt_depth, + pred_pose_trans, + gt_pose_quats, + ) + ) + pred_pose_quats_pts3d = ( + convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + gt_ray_directions, + gt_depth, + gt_pose_trans, + pred_pose_quats, + ) + ) + else: + raise NotImplementedError + + # Mask out the valid quantities as required + if not self.flatten_across_image_only: + # Flatten the points across the entire batch with the masks + pred_ray_directions_pts3d = pred_ray_directions_pts3d[valid_masks[i]] + pred_depth_pts3d = pred_depth_pts3d[valid_masks[i]] + pred_pose_trans_pts3d = pred_pose_trans_pts3d[valid_masks[i]] + pred_pose_quats_pts3d = pred_pose_quats_pts3d[valid_masks[i]] + gt_pts3d = gt_pts3d[valid_masks[i]] + else: + # Flatten the H x W dimensions to H*W + batch_size, _, _, pts_dim = gt_pts3d.shape + pred_ray_directions_pts3d = pred_ray_directions_pts3d.view( + batch_size, -1, pts_dim + ) + pred_depth_pts3d = pred_depth_pts3d.view(batch_size, -1, pts_dim) + pred_pose_trans_pts3d = pred_pose_trans_pts3d.view( + batch_size, -1, pts_dim + ) + pred_pose_quats_pts3d = pred_pose_quats_pts3d.view( + batch_size, -1, pts_dim + ) + gt_pts3d = gt_pts3d.view(batch_size, -1, pts_dim) + valid_masks[i] = valid_masks[i].view(batch_size, -1) + + # Apply loss in log space if specified + if self.loss_in_log: + gt_pts3d = apply_log_to_norm(gt_pts3d) + pred_ray_directions_pts3d = apply_log_to_norm(pred_ray_directions_pts3d) + pred_depth_pts3d = apply_log_to_norm(pred_depth_pts3d) + pred_pose_trans_pts3d = apply_log_to_norm(pred_pose_trans_pts3d) + pred_pose_quats_pts3d = apply_log_to_norm(pred_pose_quats_pts3d) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_pose_trans_pts3d, gt_pts3d, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * 
self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + pose_quats_loss = self.criterion( + pred_pose_quats_pts3d, gt_pts3d, factor="pose_quats" + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions_pts3d, gt_pts3d, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth_pts3d, gt_pts3d, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute the scale loss + if gt_metric_norm_factor is not None: + if self.loss_in_log: + gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) + pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) + scale_loss = ( + self.criterion( + pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" + ) + * self.scale_loss_weight + ) + else: + scale_loss = None + + # Use helper function to generate loss terms and details + losses_dict = {} + losses_dict.update( + { + self.depth_type_for_loss: { + "values": depth_losses, + "use_mask": True, + "is_multi_view": True, + }, + "ray_directions": { + "values": ray_directions_losses, + "use_mask": True, + "is_multi_view": True, + }, + "pose_quats": { + "values": pose_quats_losses, + "use_mask": True, + "is_multi_view": True, + }, + "pose_trans": { + "values": pose_trans_losses, + "use_mask": True, + "is_multi_view": True, + }, + "scale": { + "values": scale_loss, + "use_mask": False, + "is_multi_view": False, + }, + "normal": { + "values": normal_losses, + "use_mask": False, + "is_multi_view": True, + }, + "gradient_matching": { + "values": gradient_matching_losses, + "use_mask": False, + "is_multi_view": True, + }, + } + ) + loss_terms, details = get_loss_terms_and_details( + losses_dict, + valid_masks, + type(self).__name__, + n_views, + self.flatten_across_image_only, + ) + losses = Sum(*loss_terms) + + return losses, (details | {}) diff --git a/mapanything/train/profile_dataloading.py b/mapanything/train/profile_dataloading.py new file mode 100644 index 0000000000000000000000000000000000000000..881eb21fe131d37d6331a0f503cbea4d5bac1eda --- /dev/null +++ b/mapanything/train/profile_dataloading.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Debug script to profile dataloading for MapAnything training. + +This script measures and analyzes the performance of data loading operations +for MapAnything training workflows. It simulates the training process without +actual model training to isolate and profile the data loading components. 
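+
+The entry point is profile_dataloading(args); args is the same style of config object
+used for training and must expose at least dataset, train_params, distributed and
+output_dir (see the docstring of profile_dataloading below for the exact fields used).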
+""" + +import datetime +import json +import os +import time +from pathlib import Path +from typing import Sized + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter + +import mapanything.utils.train_tools as train_tools +from mapanything.datasets import get_test_data_loader, get_train_data_loader +from mapanything.datasets.base.base_dataset import view_name + +# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12) +if hasattr(torch.backends.cuda, "matmul") and hasattr( + torch.backends.cuda.matmul, "allow_tf32" +): + torch.backends.cuda.matmul.allow_tf32 = True + + +def profile_dataloading(args): + """ + Main profiling function that simulates the training process to measure data loading performance. + + This function initializes the distributed environment, sets up datasets and data loaders, + and runs through training epochs to profile the data loading operations. It measures + the time taken for data loading without performing actual model training or optimization. + + In this simulation, an epoch represents a complete pass through a chunk of the dataset. + + Args: + args: Configuration object containing all parameters including: + - dataset: Dataset configuration (train_dataset, test_dataset, num_workers) + - train_params: Training parameters (batch_size, epochs, seed, etc.) + - distributed: Distributed training configuration + - output_dir: Directory for saving logs and profiling results + """ + # Initialize distributed training if required + train_tools.init_distributed_mode(args.distributed) + global_rank = train_tools.get_rank() + world_size = train_tools.get_world_size() # noqa + + # Init output directory and device + print("output_dir: " + args.output_dir) + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(", ", ",\n")) + + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # Fix the seed + seed = args.train_params.seed + train_tools.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = not args.train_params.disable_cudnn_benchmark + + # Datasets and Dataloaders + print("Building train dataset {:s}".format(args.dataset.train_dataset)) + data_loader_train = build_dataset( + dataset=args.dataset.train_dataset, + num_workers=args.dataset.num_workers, + test=False, + max_num_of_imgs_per_gpu=args.train_params.max_num_of_imgs_per_gpu, + ) + print("Building test dataset {:s}".format(args.dataset.test_dataset)) + test_batch_size = 2 * ( + args.train_params.max_num_of_imgs_per_gpu // args.dataset.num_views + ) # Since we don't have any backward overhead + data_loader_test = { + dataset.split("(")[0]: build_dataset( + dataset=dataset, + num_workers=args.dataset.num_workers, + test=True, + batch_size=test_batch_size, + ) + for dataset in args.dataset.test_dataset.split("+") + if "(" in dataset + } + + def write_log_stats(epoch, train_stats, test_stats): + """ + Writes profiling statistics to log files and TensorBoard. + + This function collects metrics from the training and testing phases and writes them + to log files and TensorBoard for visualization and analysis. It only executes on the + main process in a distributed setting. 
+ + Args: + epoch: int, current epoch number + train_stats: dict, containing training metrics and timing information + test_stats: dict, containing testing metrics for each test dataset + """ + if train_tools.is_main_process(): + if log_writer is not None: + log_writer.flush() + + log_stats = dict( + epoch=epoch, **{f"train_{k}": v for k, v in train_stats.items()} + ) + for test_name in data_loader_test: + if test_name not in test_stats: + continue + log_stats.update( + {test_name + "_" + k: v for k, v in test_stats[test_name].items()} + ) + + with open( + os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8" + ) as f: + f.write(json.dumps(log_stats) + "\n") + + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir) + else: + log_writer = None + + print(f"Start training for {args.train_params.epochs} epochs") + start_time = time.time() + train_stats = test_stats = {} + args.train_params.start_epoch = 0 + for epoch in range(args.train_params.start_epoch, args.train_params.epochs + 1): + # Save more stuff + write_log_stats(epoch, train_stats, test_stats) + + if epoch >= args.train_params.epochs: + break # exit after writing last test to disk + + # Train + train_stats = train_one_epoch( + data_loader_train, + device, + epoch, + log_writer=log_writer, + args=args, + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("Training time {}".format(total_time_str)) + + +def build_dataset( + dataset, num_workers, test, batch_size=None, max_num_of_imgs_per_gpu=None +): + """ + Builds data loaders for training or testing. + + Args: + dataset: Dataset specification string. + num_workers: Number of worker processes for data loading. + test: Boolean flag indicating whether this is a test dataset. + batch_size: Number of samples per batch. Defaults to None. Used only for testing. + max_num_of_imgs_per_gpu: Maximum number of images per GPU. Defaults to None. Used only for training. + + Returns: + DataLoader: PyTorch DataLoader configured for the specified dataset. + """ + split = ["Train", "Test"][test] + print(f"Building {split} Data loader for dataset: ", dataset) + if test: + assert batch_size is not None, ( + "batch_size must be specified for testing dataloader" + ) + loader = get_test_data_loader( + dataset=dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_mem=True, + shuffle=False, + drop_last=False, + ) + else: + assert max_num_of_imgs_per_gpu is not None, ( + "max_num_of_imgs_per_gpu must be specified for training dataloader" + ) + loader = get_train_data_loader( + dataset=dataset, + max_num_of_imgs_per_gpu=max_num_of_imgs_per_gpu, + num_workers=num_workers, + pin_mem=True, + shuffle=True, + drop_last=True, + ) + + print(f"{split} dataset length: ", len(loader)) + return loader + + +def train_one_epoch( + data_loader: Sized, + device: torch.device, + epoch: int, + args, + log_writer=None, +): + """ + Simulates training for one epoch to profile data loading performance. + + This function runs through a single epoch, simulating the data loading and device transfer + operations that would occur during actual training. It measures and logs the time taken + for these operations without performing actual model training. 
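+
+ Concretely, each iteration only moves the batch tensors (minus a few metadata keys) to
+ the target device and prints the rank, number of views, batch shape and first sample
+ name, so the per-iteration timings of the MetricLogger reflect dataloading and
+ host-to-device transfer rather than any model computation.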
+ + Args: + data_loader: Sized, DataLoader providing the training data + device: torch.device, device to transfer data to (CPU or GPU) + epoch: int, current epoch number + args: object, configuration object containing training parameters including: + - train_params.print_freq: frequency of logging during the epoch + log_writer: Optional[SummaryWriter], TensorBoard SummaryWriter for logging metrics + + Returns: + dict: Dictionary containing profiling metrics averaged over the epoch + """ + metric_logger = train_tools.MetricLogger(delimiter=" ") + header = "Epoch: [{}]".format(epoch) + + if log_writer is not None: + print("log_dir: {}".format(log_writer.log_dir)) + + if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"): + data_loader.dataset.set_epoch(epoch) + if hasattr(data_loader, "sampler") and hasattr(data_loader.sampler, "set_epoch"): + data_loader.sampler.set_epoch(epoch) + if hasattr(data_loader, "batch_sampler") and hasattr( + data_loader.batch_sampler, "set_epoch" + ): + data_loader.batch_sampler.set_epoch(epoch) + + for data_iter_step, batch in enumerate( + metric_logger.log_every(data_loader, args.train_params.print_freq, header) + ): + epoch_f = epoch + data_iter_step / len(data_loader) + + # Simulate the device loading in loss_of_one_batch_multi_view + ignore_keys = set( + [ + "depthmap", + "dataset", + "label", + "instance", + "idx", + "true_shape", + "rng", + "data_norm_type", + ] + ) + for view in batch: + for name in view.keys(): + if name in ignore_keys: + continue + view[name] = view[name].to(device, non_blocking=True) + + local_rank = train_tools.get_rank() + n_views = len(batch) + batch_shape = batch[0]["img"].shape + first_sample_name = view_name(batch[0], batch_index=0) + print( + f"Rank: {local_rank}, Num views: {n_views}, Batch Shape: {batch_shape}, First Sample Name: {first_sample_name}", + force=True, + ) + + del batch + + metric_logger.update(epoch=epoch_f) + metric_logger.update(loss=0) + + # # Gather the stats from all processes + # metric_logger.synchronize_between_processes() + # print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} diff --git a/mapanything/train/training.py b/mapanything/train/training.py new file mode 100644 index 0000000000000000000000000000000000000000..109057bc67022ead3634ada41458ed77b7403bd7 --- /dev/null +++ b/mapanything/train/training.py @@ -0,0 +1,664 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Training Code for MapAnything. 
+ +References: +DUSt3R: https://github.com/naver/dust3r +""" + +import datetime +import json +import math +import os +import pickle +import sys +import time +from collections import defaultdict +from pathlib import Path +from typing import Sized + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter + +import mapanything.utils.train_tools as train_tools +from mapanything.datasets import get_test_data_loader, get_train_data_loader +from mapanything.models import init_model +from mapanything.train.losses import * # noqa +from mapanything.utils.inference import loss_of_one_batch_multi_view +from mapanything.utils.train_tools import NativeScalerWithGradNormCount as NativeScaler + +# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12) +if hasattr(torch.backends.cuda, "matmul") and hasattr( + torch.backends.cuda.matmul, "allow_tf32" +): + torch.backends.cuda.matmul.allow_tf32 = True + + +def train(args): + """ + Main training function that handles the entire training process. + + This function initializes the distributed training environment, sets up datasets, + initializes the model, optimizer, and loss functions, and manages the training + and evaluation loop across multiple epochs. + + In this training, an epoch is just a chunk of the entire dataset. + + Args: + args: Configuration object containing all training parameters including + dataset configs, model configs, training parameters, and loss functions. + """ + # Initialize distributed training if required + train_tools.init_distributed_mode(args.distributed) + global_rank = train_tools.get_rank() + world_size = train_tools.get_world_size() # noqa + + # Init output directory and device + print("output_dir: " + args.output_dir) + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(", ", ",\n")) + + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # Fix the seed + seed = args.train_params.seed + train_tools.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = not args.train_params.disable_cudnn_benchmark + + # Datasets and Dataloaders + print("Building train dataset {:s}".format(args.dataset.train_dataset)) + data_loader_train = build_dataset( + dataset=args.dataset.train_dataset, + num_workers=args.dataset.num_workers, + test=False, + max_num_of_imgs_per_gpu=args.train_params.max_num_of_imgs_per_gpu, + ) + print("Building test dataset {:s}".format(args.dataset.test_dataset)) + test_batch_size = 2 * ( + args.train_params.max_num_of_imgs_per_gpu // args.dataset.num_views + ) # Since we don't have any backward overhead + data_loader_test = { + dataset.split("(")[0]: build_dataset( + dataset=dataset, + num_workers=args.dataset.num_workers, + test=True, + batch_size=test_batch_size, + ) + for dataset in args.dataset.test_dataset.split("+") + if "(" in dataset + } + + # Load Model + if global_rank == 0: + model = init_model( + args.model.model_str, + args.model.model_config, + torch_hub_force_reload=args.model.torch_hub_force_reload, + ) + if torch.distributed.is_initialized(): + torch.distributed.barrier() # Make sure the model is initialized before proceeding + if global_rank != 0: + model = init_model( + args.model.model_str, args.model.model_config, torch_hub_force_reload=False + ) + model.to(device) # Move model to device + model_without_ddp = 
model + print("Model = %s" % str(model_without_ddp)) + + # Criterion + print(f">> Creating train criterion = {args.loss.train_criterion}") + train_criterion = eval(args.loss.train_criterion).to(device) + print( + f">> Creating test criterion = {args.loss.test_criterion or args.loss.train_criterion}" + ) + test_criterion = eval(args.loss.test_criterion or args.loss.train_criterion).to( + device + ) + + # Load pretrained model if provided + if args.model.pretrained: + print("Loading pretrained: ", args.model.pretrained) + ckpt = torch.load( + args.model.pretrained, map_location=device, weights_only=False + ) + print(model.load_state_dict(ckpt["model"], strict=False)) + del ckpt # in case it occupies memory + + # Init model for DDP training + if args.distributed.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[args.distributed.gpu], + find_unused_parameters=True, + static_graph=False, + ) + model_without_ddp = model.module + + # Optimizer and loss scaler for gradient accumulation + # Following timm: set wd as 0 for bias and norm layers + param_groups, param_groups_name_to_idx_map, param_groups_idx_to_name_map = ( + train_tools.get_parameter_groups( + model_without_ddp, + args.train_params.lr, + args.train_params.weight_decay, + submodule_configs=args.train_params.submodule_configs, + warn_not_in_submodule=args.train_params.warn_not_in_submodule, + ) + ) + optimizer = torch.optim.AdamW( + param_groups, lr=args.train_params.lr, betas=(0.9, 0.95) + ) + print(optimizer) + loss_scaler = NativeScaler() + + def write_log_stats(epoch, train_stats, test_stats): + """ + Writes training and testing statistics to log files and TensorBoard. + + Args: + epoch: Current epoch number. + train_stats: Dictionary containing training metrics. + test_stats: Dictionary containing testing metrics for each test dataset. + """ + if train_tools.is_main_process(): + if log_writer is not None: + log_writer.flush() + + log_stats = dict( + epoch=epoch, **{f"train_{k}": v for k, v in train_stats.items()} + ) + for test_name in data_loader_test: + if test_name not in test_stats: + continue + log_stats.update( + {test_name + "_" + k: v for k, v in test_stats[test_name].items()} + ) + + with open( + os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8" + ) as f: + f.write(json.dumps(log_stats) + "\n") + + def save_model(epoch, fname, best_so_far): + """ + Saves model checkpoint to disk. + + Args: + epoch: Current epoch number. + fname: Filename or identifier for the checkpoint. + best_so_far: Best validation metric achieved so far. 
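+
+ Note: this is a thin wrapper around train_tools.save_model; with fname="last" it should
+ correspond to the checkpoint-last.pth file in args.output_dir that the resume logic
+ further below looks for (the exact naming is inferred from that resume logic).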
+ """ + train_tools.save_model( + args=args, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + epoch=epoch, + fname=fname, + best_so_far=best_so_far, + ) + + # Resume from a checkpoint if needed + last_ckpt_fname = os.path.join(args.output_dir, "checkpoint-last.pth") + if args.train_params.resume and os.path.isfile(last_ckpt_fname): + args.train_params.resume_ckpt = last_ckpt_fname + else: + args.train_params.resume_ckpt = None + best_so_far = train_tools.load_model( + train_args=args.train_params, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + ) + if best_so_far is None: + best_so_far = float("inf") + + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir) + else: + log_writer = None + + print(f"Start training for {args.train_params.epochs} epochs") + start_time = time.time() + train_stats = test_stats = {} + for epoch in range(args.train_params.start_epoch, args.train_params.epochs + 1): + # Save immediately the last checkpoint + if epoch > args.train_params.start_epoch: + if ( + args.train_params.save_freq + and epoch % args.train_params.save_freq == 0 + or epoch == args.train_params.epochs + ): + save_model(epoch - 1, "last", best_so_far) + + # Test on multiple datasets + new_best = False + test_stats = {} + if ( + args.train_params.eval_freq > 0 + and epoch % args.train_params.eval_freq == 0 + and epoch > 0 + ): + for test_name, testset in data_loader_test.items(): + print(f"Testing on {test_name} ...") + stats = test_one_epoch( + model, + test_criterion, + testset, + device, + epoch, + log_writer=log_writer, + args=args, + prefix=test_name, + ) + test_stats[test_name] = stats + + # Calculate average test loss median + avg_test_loss_med = np.mean( + [stats["loss_med"] for stats in test_stats.values()] + ) + test_stats["Average Test Loss Median"] = avg_test_loss_med + # Save best + if avg_test_loss_med < best_so_far: + best_so_far = avg_test_loss_med + new_best = True + + # Save more stuff + write_log_stats(epoch, train_stats, test_stats) + + if epoch > args.train_params.start_epoch: + if args.train_params.keep_freq and epoch % args.train_params.keep_freq == 0: + save_model(epoch - 1, str(epoch), best_so_far) + if new_best: + save_model(epoch - 1, "best", best_so_far) + if epoch >= args.train_params.epochs: + break # exit after writing last test to disk + + # Train + train_stats = train_one_epoch( + model, + train_criterion, + data_loader_train, + optimizer, + device, + epoch, + loss_scaler, + log_writer=log_writer, + args=args, + param_groups_name_to_idx_map=param_groups_name_to_idx_map, + param_groups_idx_to_name_map=param_groups_idx_to_name_map, + model_without_ddp=model_without_ddp, + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("Training time {}".format(total_time_str)) + + save_final_model( + args, args.train_params.epochs, model_without_ddp, best_so_far=best_so_far + ) + + +def save_final_model(args, epoch, model_without_ddp, best_so_far=None): + """ + Saves the final model checkpoint after training completion. + + Args: + args: Configuration object containing output directory information. + epoch: Current epoch number. + model_without_ddp: Model state dictionary or model instance without DistributedDataParallel wrapper. + best_so_far: Optional; Best validation metric achieved during training. 
+ """ + output_dir = Path(args.output_dir) + checkpoint_path = output_dir / "checkpoint-final.pth" + to_save = { + "args": args, + "model": model_without_ddp + if isinstance(model_without_ddp, dict) + else model_without_ddp.cpu().state_dict(), + "epoch": epoch, + } + if best_so_far is not None: + to_save["best_so_far"] = best_so_far + print(f">> Saving model to {checkpoint_path} ...") + train_tools.save_on_master(to_save, checkpoint_path) + + +def build_dataset( + dataset, num_workers, test, batch_size=None, max_num_of_imgs_per_gpu=None +): + """ + Builds data loaders for training or testing. + + Args: + dataset: Dataset specification string. + num_workers: Number of worker processes for data loading. + test: Boolean flag indicating whether this is a test dataset. + batch_size: Number of samples per batch. Defaults to None. Used only for testing. + max_num_of_imgs_per_gpu: Maximum number of images per GPU. Defaults to None. Used only for training. + + Returns: + DataLoader: PyTorch DataLoader configured for the specified dataset. + """ + split = ["Train", "Test"][test] + print(f"Building {split} Data loader for dataset: ", dataset) + if test: + assert batch_size is not None, ( + "batch_size must be specified for testing dataloader" + ) + loader = get_test_data_loader( + dataset=dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_mem=True, + shuffle=False, + drop_last=False, + ) + else: + assert max_num_of_imgs_per_gpu is not None, ( + "max_num_of_imgs_per_gpu must be specified for training dataloader" + ) + loader = get_train_data_loader( + dataset=dataset, + max_num_of_imgs_per_gpu=max_num_of_imgs_per_gpu, + num_workers=num_workers, + pin_mem=True, + shuffle=True, + drop_last=True, + ) + + print(f"{split} dataset length: ", len(loader)) + return loader + + +def train_one_epoch( + model: torch.nn.Module, + criterion: torch.nn.Module, + data_loader: Sized, + optimizer: torch.optim.Optimizer, + device: torch.device, + epoch: int, + loss_scaler, + args, + log_writer=None, + param_groups_name_to_idx_map=None, + param_groups_idx_to_name_map=None, + model_without_ddp=None, +): + """ + Trains the model for one epoch. + Epoch is just a chunk of the entire dataset. + + This function handles the training loop for a single epoch, including forward/backward passes, + gradient accumulation, learning rate scheduling, and logging metrics. + + Args: + model: The neural network model to train. + criterion: Loss function to optimize. + data_loader: DataLoader providing the training data. + optimizer: Optimizer for updating model parameters. + device: Device to run training on (CPU or GPU). + epoch: Current epoch number. + loss_scaler: Scaler for gradient accumulation and mixed precision training. + args: Configuration object containing training parameters. + log_writer: Optional; TensorBoard SummaryWriter for logging. + param_groups_name_to_idx_map: Mapping from parameter group names to indices. + param_groups_idx_to_name_map: Mapping from parameter group indices to names. + model_without_ddp: Model without DistributedDataParallel wrapper for debugging. + + Returns: + dict: Dictionary containing training metrics averaged over the epoch. 
+ """ + model.train(True) + metric_logger = train_tools.MetricLogger(delimiter=" ") + for submodule_name in param_groups_name_to_idx_map: + lr_name = f"lr_{submodule_name}" if submodule_name != "default" else "lr" + metric_logger.add_meter( + lr_name, train_tools.SmoothedValue(window_size=1, fmt="{value:.6f}") + ) + header = "Epoch: [{}]".format(epoch) + accum_iter = args.train_params.accum_iter + + if log_writer is not None: + print("log_dir: {}".format(log_writer.log_dir)) + + if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"): + data_loader.dataset.set_epoch(epoch) + if hasattr(data_loader, "sampler") and hasattr(data_loader.sampler, "set_epoch"): + data_loader.sampler.set_epoch(epoch) + if hasattr(data_loader, "batch_sampler") and hasattr( + data_loader.batch_sampler, "set_epoch" + ): + data_loader.batch_sampler.set_epoch(epoch) + + optimizer.zero_grad() + + for data_iter_step, batch in enumerate( + metric_logger.log_every(data_loader, args.train_params.print_freq, header) + ): + n_views = len(batch) + epoch_f = epoch + data_iter_step / len(data_loader) + + # We use a per iteration (instead of per epoch) lr scheduler + if data_iter_step % accum_iter == 0: + train_tools.adjust_learning_rate( + optimizer, + epoch_f, + args.train_params, + param_groups_idx_to_name_map, + args.train_params.submodule_configs, + ) + + loss_tuple = loss_of_one_batch_multi_view( + batch, + model, + criterion, + device, + use_amp=bool(args.train_params.amp), + amp_dtype=args.train_params.amp_dtype, + ret="loss", + ) + loss, loss_details = loss_tuple # criterion returns two values + if n_views > 2: + loss = loss * ( + 2 / n_views + ) # scale the loss relative to the number of views (base is 2 views) + loss_value = float(loss) + + if not math.isfinite(loss_value) or (loss_value > 1000): + print("Loss is {}, stopping training".format(loss_value), force=True) + print(f"Loss Details: {loss_details}", force=True) + print(f"Epoch: {epoch}, Data Iteration: {data_iter_step}", force=True) + # Save the current batch to the output folder for further inspection + for view_idx, view in enumerate(batch): + view_cpu = {} + for k, v in view.items(): + view_cpu[k] = v.cpu() if isinstance(v, torch.Tensor) else v + with open( + os.path.join(args.output_dir, f"batch_view_{view_idx}.pkl"), "wb" + ) as f: + pickle.dump(view_cpu, f) + # Save the model to the output folder for further inspection + checkpoint_debug_path = os.path.join( + args.output_dir, "checkpoint-debug.pth" + ) + to_save_debug = { + "args": args, + "model": ( + model_without_ddp + if isinstance(model_without_ddp, dict) + else model_without_ddp.cpu().state_dict() + ), + "epoch": epoch, + "data_iter_step": data_iter_step, + } + torch.save(to_save_debug, checkpoint_debug_path) + print(f"Saved debugging material to {args.output_dir}", force=True) + sys.exit(1) + + # Scale the loss by the number of gradient accumulation iterations + loss /= accum_iter + + # Compute the scaled gradients (also clip the gradients to max norm of 1) + gradient_norm = loss_scaler( + loss, + optimizer, + parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0, + clip_grad=1.0, + ) + + # Zero out the gradients to prepare for the next iteration of gradient descent + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + del loss + del batch + + metric_logger.update(epoch=epoch_f) + for submodule_name in param_groups_name_to_idx_map: + lr_name = f"lr_{submodule_name}" if submodule_name != "default" else "lr" + log_lr = 
optimizer.param_groups[ + param_groups_name_to_idx_map[submodule_name][0] + ]["lr"] + metric_logger.meters[lr_name].update(log_lr) + metric_logger.update(loss=loss_value, **loss_details) + + if (data_iter_step + 1) % accum_iter == 0 and ( + (data_iter_step + 1) % (accum_iter * args.train_params.print_freq) + ) == 0: + loss_value_reduce = train_tools.all_reduce_mean( + loss_value + ) # MUST BE EXECUTED BY ALL NODES + if log_writer is None: + continue + """ + We use epoch_1000x as the x-axis in tensorboard. + This calibrates different curves when batch size changes. + """ + epoch_1000x = int(epoch_f * 1000) + log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x) + if gradient_norm is not None: + log_writer.add_scalar("train_grad_norm", gradient_norm, epoch_1000x) + for submodule_name in param_groups_name_to_idx_map: + lr_name = ( + f"train_lr_{submodule_name}" + if submodule_name != "default" + else "train_lr" + ) + log_lr = optimizer.param_groups[ + param_groups_name_to_idx_map[submodule_name][0] + ]["lr"] + log_writer.add_scalar(lr_name, log_lr, epoch_1000x) + log_writer.add_scalar("train_iter", epoch_1000x, epoch_1000x) + for name, val in loss_details.items(): + log_writer.add_scalar("train_" + name, val, epoch_1000x) + + # # Gather the stats from all processes + # metric_logger.synchronize_between_processes() + # print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def test_one_epoch( + model: torch.nn.Module, + criterion: torch.nn.Module, + data_loader: Sized, + device: torch.device, + epoch: int, + args, + log_writer=None, + prefix="test", +): + """ + Evaluates the model on a test dataset for one epoch. + Epoch is just a chunk of the entire dataset. + + This function runs evaluation on the test dataset without computing gradients, + and collects metrics for model performance assessment. + + Args: + model: The neural network model to evaluate. + criterion: Loss function for evaluation. + data_loader: DataLoader providing the test data. + device: Device to run evaluation on (CPU or GPU). + epoch: Current epoch number. + args: Configuration object containing evaluation parameters. + log_writer: Optional; TensorBoard SummaryWriter for logging. + prefix: String prefix for logging metrics. + + Returns: + dict: Dictionary containing evaluation metrics (average and median values). 
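+
+ Keys follow the pattern "<meter>_avg" and "<meter>_med" (e.g. "loss_avg", "loss_med");
+ the "_med" entries are what the training loop averages into "Average Test Loss Median"
+ when selecting the best checkpoint.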
+ """ + model.eval() + metric_logger = train_tools.MetricLogger(delimiter=" ") + metric_logger.meters = defaultdict( + lambda: train_tools.SmoothedValue(window_size=9**9) + ) + header = "Test Epoch: [{}]".format(epoch) + + if log_writer is not None: + print("log_dir: {}".format(log_writer.log_dir)) + + if args.train_params.freeze_val_samples_across_all_epochs: + dataloader_epoch = 0 + else: + dataloader_epoch = epoch + if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"): + data_loader.dataset.set_epoch(dataloader_epoch) + if hasattr(data_loader, "sampler") and hasattr(data_loader.sampler, "set_epoch"): + data_loader.sampler.set_epoch(dataloader_epoch) + if hasattr(data_loader, "batch_sampler") and hasattr( + data_loader.batch_sampler, "set_epoch" + ): + data_loader.batch_sampler.set_epoch(dataloader_epoch) + + for _, batch in enumerate( + metric_logger.log_every(data_loader, args.train_params.print_freq, header) + ): + n_views = len(batch) + loss_tuple = loss_of_one_batch_multi_view( + batch, + model, + criterion, + device, + use_amp=bool(args.train_params.amp), + amp_dtype=args.train_params.amp_dtype, + ret="loss", + ) + loss_value, loss_details = loss_tuple # criterion returns two values + if n_views > 2: + loss_value = loss_value * ( + 2 / n_views + ) # scale the loss relative to the number of views (base is 2 views) + metric_logger.update(loss=float(loss_value), **loss_details) + + # # Gather the stats from all processes + # metric_logger.synchronize_between_processes() + # print("Averaged stats:", metric_logger) + + aggs = [("avg", "global_avg"), ("med", "median")] + results = { + f"{k}_{tag}": getattr(meter, attr) + for k, meter in metric_logger.meters.items() + for tag, attr in aggs + } + + if log_writer is not None: + for name, val in results.items(): + log_writer.add_scalar(prefix + "_" + name, val, 1000 * epoch) + + return results diff --git a/mapanything/utils/__init__.py b/mapanything/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mapanything/utils/__pycache__/__init__.cpython-312.pyc b/mapanything/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37b07eef2465cac9967f2db197a9e1f13df33dae Binary files /dev/null and b/mapanything/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/mapanything/utils/__pycache__/cropping.cpython-312.pyc b/mapanything/utils/__pycache__/cropping.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..561d5b0d3c3296ce4d2144e7ce465905a0db719c Binary files /dev/null and b/mapanything/utils/__pycache__/cropping.cpython-312.pyc differ diff --git a/mapanything/utils/__pycache__/geometry.cpython-312.pyc b/mapanything/utils/__pycache__/geometry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d4e2645f45007db1fd058b3b0d0a0148a419e82 Binary files /dev/null and b/mapanything/utils/__pycache__/geometry.cpython-312.pyc differ diff --git a/mapanything/utils/__pycache__/image.cpython-312.pyc b/mapanything/utils/__pycache__/image.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4e770980ec6ce01e10d5ffe694be00d1832b653 Binary files /dev/null and b/mapanything/utils/__pycache__/image.cpython-312.pyc differ diff --git a/mapanything/utils/__pycache__/inference.cpython-312.pyc b/mapanything/utils/__pycache__/inference.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..115b694644f562cd681ccd9e49d4f9f711b03a34 Binary files /dev/null and b/mapanything/utils/__pycache__/inference.cpython-312.pyc differ diff --git a/mapanything/utils/__pycache__/misc.cpython-312.pyc b/mapanything/utils/__pycache__/misc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17ef08ae32e3dd1cb7d34f52833c28abd8c3c03f Binary files /dev/null and b/mapanything/utils/__pycache__/misc.cpython-312.pyc differ diff --git a/mapanything/utils/__pycache__/warnings.cpython-312.pyc b/mapanything/utils/__pycache__/warnings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b8b9aec877c55d5ece038971032547c7b481f24 Binary files /dev/null and b/mapanything/utils/__pycache__/warnings.cpython-312.pyc differ diff --git a/mapanything/utils/colmap.py b/mapanything/utils/colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..c1965c0f345088dfd7666e4c455d83e9defa7257 --- /dev/null +++ b/mapanything/utils/colmap.py @@ -0,0 +1,662 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. 
Schoenberger (jsch-at-demuc-dot-de)
+
+"""
+COLMAP Utils
+
+Source: https://github.com/colmap/colmap/blob/master/scripts/python/read_write_model.py (modified by Khiem Vuong)
+"""
+
+import argparse
+import collections
+import os
+import struct
+
+import numpy as np
+
+CameraModel = collections.namedtuple(
+    "CameraModel", ["model_id", "model_name", "num_params"]
+)
+Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"])
+BaseImage = collections.namedtuple(
+    "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]
+)
+Point3D = collections.namedtuple(
+    "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]
+)
+
+
+class Image(BaseImage):
+    def qvec2rotmat(self):
+        return qvec2rotmat(self.qvec)
+
+
+CAMERA_MODELS = {
+    CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
+    CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
+    CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
+    CameraModel(model_id=3, model_name="RADIAL", num_params=5),
+    CameraModel(model_id=4, model_name="OPENCV", num_params=8),
+    CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
+    CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
+    CameraModel(model_id=7, model_name="FOV", num_params=5),
+    CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
+    CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
+    CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
+}
+CAMERA_MODEL_IDS = dict(
+    [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]
+)
+CAMERA_MODEL_NAMES = dict(
+    [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]
+)
+
+
+def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
+    """Read and unpack the next bytes from a binary file.
+    :param fid:
+    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
+    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+    :param endian_character: Any of {@, =, <, >, !}
+    :return: Tuple of read and unpacked values.
+    """
+    data = fid.read(num_bytes)
+    return struct.unpack(endian_character + format_char_sequence, data)
+
+
+def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
+    """Pack and write to a binary file.
+    :param fid:
+    :param data: data to send; if multiple elements are sent at the same time,
+    they should be encapsulated either in a list or a tuple
+    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ should be the same length as the data list or tuple + :param endian_character: Any of {@, =, <, >, !} + """ + if isinstance(data, (list, tuple)): + bytes = struct.pack(endian_character + format_char_sequence, *data) + else: + bytes = struct.pack(endian_character + format_char_sequence, data) + fid.write(bytes) + + +def read_cameras_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera( + id=camera_id, model=model, width=width, height=height, params=params + ) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for camera_line_index in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ" + ) + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes( + fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params + ) + cameras[camera_id] = Camera( + id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params), + ) + assert len(cameras) == num_cameras + return cameras + + +def write_cameras_text(cameras, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + HEADER = [ + "# Camera list with one line of data per camera:\n" + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + "# Number of cameras: {}\n".format(len(cameras)) + ] + with open(path, "w") as fid: + fid.write("".join(HEADER)) + for _, cam in cameras.items(): + to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] + line = " ".join([str(elem) for elem in to_write]) + fid.write(line + "\n") + + +def write_cameras_binary(cameras, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(cameras), "Q") + for _, cam in cameras.items(): + model_id = CAMERA_MODEL_NAMES[cam.model].model_id + camera_properties = [cam.id, model_id, cam.width, cam.height] + write_next_bytes(fid, camera_properties, "iiQQ") + for p in cam.params: + write_next_bytes(fid, float(p), "d") + return cameras + + +def read_images_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = 
line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack( + [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))] + ) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for image_index in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi" + ) + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ + 0 + ] + x_y_id_s = read_next_bytes( + fid, + num_bytes=24 * num_points2D, + format_char_sequence="ddq" * num_points2D, + ) + xys = np.column_stack( + [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] + ) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def write_images_text(images, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + if len(images) == 0: + mean_observations = 0 + else: + mean_observations = sum( + (len(img.point3D_ids) for _, img in images.items()) + ) / len(images) + HEADER = [ + "# Image list with two lines of data per image:\n" + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + "# Number of images: {}, mean observations per image: {}\n".format( + len(images), mean_observations + ) + ] + + with open(path, "w") as fid: + fid.write("".join(HEADER)) + for _, img in images.items(): + image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] + first_line = " ".join(map(str, image_header)) + fid.write(first_line + "\n") + + points_strings = [] + for xy, point3D_id in zip(img.xys, img.point3D_ids): + points_strings.append(" ".join(map(str, [*xy, point3D_id]))) + fid.write(" ".join(points_strings) + "\n") + + +def write_images_binary(images, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(images), "Q") + for _, img in images.items(): + write_next_bytes(fid, img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + 
write_next_bytes(fid, img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for xy, p3d_id in zip(img.xys, img.point3D_ids): + write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def read_points3d_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for point_line_index in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd" + ) + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ + 0 + ] + track_elems = read_next_bytes( + fid, + num_bytes=8 * track_length, + format_char_sequence="ii" * track_length, + ) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def write_points3D_text(points3D, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + if len(points3D) == 0: + mean_track_length = 0 + else: + mean_track_length = sum( + (len(pt.image_ids) for _, pt in points3D.items()) + ) / len(points3D) + HEADER = [ + "# 3D point list with one line of data per point:\n" + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + "# Number of points: {}, mean track length: {}\n".format( + len(points3D), mean_track_length + ) + ] + + with open(path, "w") as fid: + fid.write("".join(HEADER)) + for _, pt in points3D.items(): + point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] + fid.write(" ".join(map(str, point_header)) + " ") + track_strings = [] + for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): + track_strings.append(" ".join(map(str, [image_id, point2D]))) + fid.write(" ".join(track_strings) + "\n") + + +def write_points3d_binary(points3D, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + 
void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + + +def read_model(path, ext): + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def write_model(cameras, images, points3D, path, ext): + if ext == ".txt": + write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) + write_images_text(images, os.path.join(path, "images" + ext)) + write_points3D_text(points3D, os.path.join(path, "points3D") + ext) + else: + write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) + write_images_binary(images, os.path.join(path, "images" + ext)) + write_points3d_binary(points3D, os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array( + [ + [ + 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, + ], + ] + ) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = ( + np.array( + [ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ] + ) + / 3.0 + ) + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +def get_camera_matrix(camera_params, camera_model): + """Get camera matrix and distortion coefficients from camera parameters in COLMAP format + Arguments + camera_params - Camera parameters in COLMAP format + camera_model - Camera model + Return + K - [3, 3] Camera matrix + dc - [12,] Distortion coefficients + """ + camera_params = np.asarray(camera_params) + K = np.zeros([3, 3]) + K[2, 2] = 1 + dc = np.zeros( + [ + 12, + ] + ) + if str.upper(camera_model) == "SIMPLE_PINHOLE": + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[0] + K[0, 2] = camera_params[1] + K[1, 2] = camera_params[2] + elif str.upper(camera_model) == "PINHOLE": + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[1] + K[0, 2] = camera_params[2] + K[1, 2] = camera_params[3] + elif str.upper(camera_model) in ("SIMPLE_RADIAL", "SIMPLE_RADIAL_FISHEYE"): + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[0] + K[0, 2] = camera_params[1] + K[1, 2] = camera_params[2] + dc[0] = camera_params[3] + elif str.upper(camera_model) 
in ("RADIAL", "RADIAL_FISHEYE"): + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[0] + K[0, 2] = camera_params[1] + K[1, 2] = camera_params[2] + dc[0:2] = camera_params[3:5] + elif str.upper(camera_model) == "OPENCV": + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[1] + K[0, 2] = camera_params[2] + K[1, 2] = camera_params[3] + dc[0:4] = camera_params[4:8] + elif str.upper(camera_model) == "FULL_OPENCV": + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[1] + K[0, 2] = camera_params[2] + K[1, 2] = camera_params[3] + dc[0:8] = camera_params[4:12] + elif str.upper(camera_model) == "OPENCV_FISHEYE": + K[0, 0] = camera_params[0] + K[1, 1] = camera_params[1] + K[0, 2] = camera_params[2] + K[1, 2] = camera_params[3] + dc[0:2] = camera_params[4:6] + dc[4:6] = camera_params[6:8] + else: + raise ValueError("Unsupported camera model: " + camera_model) + return K, dc + + +def main(): + parser = argparse.ArgumentParser( + description="Read and write COLMAP binary and text models" + ) + parser.add_argument("--input_model", help="path to input model folder") + parser.add_argument( + "--input_format", + choices=[".bin", ".txt"], + help="input model format", + default="", + ) + parser.add_argument("--output_model", help="path to output model folder") + parser.add_argument( + "--output_format", + choices=[".bin", ".txt"], + help="output model format", + default=".txt", + ) + args = parser.parse_args() + + cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model( + cameras, images, points3D, path=args.output_model, ext=args.output_format + ) + + +def main(): # noqa + parser = argparse.ArgumentParser( + description="Read and write COLMAP binary and text models" + ) + parser.add_argument("input_model", help="path to input model folder") + parser.add_argument( + "input_format", choices=[".bin", ".txt"], help="input model format" + ) + parser.add_argument( + "--output_model", metavar="PATH", help="path to output model folder" + ) + parser.add_argument( + "--output_format", + choices=[".bin", ".txt"], + help="output model format", + default=".txt", + ) + args = parser.parse_args() + + cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model( + cameras, images, points3D, path=args.output_model, ext=args.output_format + ) + + +if __name__ == "__main__": + main() diff --git a/mapanything/utils/cropping.py b/mapanything/utils/cropping.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6c011f5b755f1a73907e95bd3f1f5f14dd3e78 --- /dev/null +++ b/mapanything/utils/cropping.py @@ -0,0 +1,467 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Utility functions for cropping and resizing data while maintaining proper cameras. 
+ +References: DUSt3R +""" + +import cv2 +import numpy as np +import PIL.Image + +try: + lanczos = PIL.Image.Resampling.LANCZOS + bicubic = PIL.Image.Resampling.BICUBIC +except AttributeError: + lanczos = PIL.Image.LANCZOS + bicubic = PIL.Image.BICUBIC + +from mapanything.utils.geometry import ( + colmap_to_opencv_intrinsics, + opencv_to_colmap_intrinsics, +) + + +class ImageList: + """ + Convenience class to apply the same operation to a whole set of images. + + This class wraps a list of PIL.Image objects and provides methods to perform + operations on all images simultaneously. + """ + + def __init__(self, images): + if not isinstance(images, (tuple, list, set)): + images = [images] + self.images = [] + for image in images: + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + self.images.append(image) + + def __len__(self): + """Return the number of images in the list.""" + return len(self.images) + + def to_pil(self): + """ + Convert ImageList back to PIL Image(s). + + Returns: + PIL.Image.Image or tuple: Single PIL Image if list contains one image, + or tuple of PIL Images if multiple images + """ + return tuple(self.images) if len(self.images) > 1 else self.images[0] + + @property + def size(self): + """ + Get the size of images in the list. + + Returns: + tuple: (width, height) of the images + + Raises: + AssertionError: If images have different sizes + """ + sizes = [im.size for im in self.images] + assert all(sizes[0] == s for s in sizes), "All images must have the same size" + return sizes[0] + + def resize(self, *args, **kwargs): + """ + Resize all images with the same parameters. + + Args: + *args, **kwargs: Arguments passed to PIL.Image.resize() + + Returns: + ImageList: New ImageList containing resized images + """ + return ImageList(self._dispatch("resize", *args, **kwargs)) + + def crop(self, *args, **kwargs): + """ + Crop all images with the same parameters. + + Args: + *args, **kwargs: Arguments passed to PIL.Image.crop() + + Returns: + ImageList: New ImageList containing cropped images + """ + return ImageList(self._dispatch("crop", *args, **kwargs)) + + def _dispatch(self, func, *args, **kwargs): + """ + Apply a PIL.Image method to all images in the list. + + Args: + func (str): Name of the PIL.Image method to call + *args, **kwargs: Arguments to pass to the method + + Returns: + list: List of results from applying the method to each image + """ + return [getattr(im, func)(*args, **kwargs) for im in self.images] + + +def resize_with_nearest_interpolation_to_match_aspect_ratio(input_data, img_h, img_w): + """ + Resize input map to match the aspect ratio of an image while ensuring + the input resolution never increases beyond the original. + Uses nearest interpolation for resizing. 
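+
+    Example (illustrative numbers; a 100x200 map matched to a 480x640 image
+    reaches the 4:3 aspect ratio by shrinking the width, never upscaling):
+
+        >>> data = np.zeros((100, 200), dtype=np.float32)
+        >>> _, h, w = resize_with_nearest_interpolation_to_match_aspect_ratio(
+        ...     data, img_h=480, img_w=640
+        ... )
+        >>> (h, w)
+        (100, 133)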
+ + Args: + input_data (np.ndarray): The input map to resize + img_h (int): Height of the target image + img_w (int): Width of the target image + + Returns: + tuple: (resized_input, target_h, target_w) + - resized_input: The resized input map + - target_h: The target height used for resizing + - target_w: The target width used for resizing + """ + # Get the dimensions of the input map + input_h, input_w = input_data.shape[:2] + + # Calculate aspect ratios + img_aspect = img_w / img_h + + # Option 1: Keep input_w fixed and calculate new height + option1_h = int(input_w / img_aspect) + # Option 2: Keep input_h fixed and calculate new width + option2_w = int(input_h * img_aspect) + + # Check if either option would increase a dimension + option1_increases = option1_h > input_h + option2_increases = option2_w > input_w + + if option1_increases and option2_increases: + # Both options would increase a dimension, so we need to scale down both dimensions + # Find the scaling factor that preserves aspect ratio and ensures no dimension increases + scale_h = input_h / img_h + scale_w = input_w / img_w + scale = min(scale_h, scale_w) + + target_input_h = int(img_h * scale) + target_input_w = int(img_w * scale) + elif option1_increases: + # Option 1 would increase height, so use option 2 + target_input_h = input_h + target_input_w = option2_w + elif option2_increases: + # Option 2 would increase width, so use option 1 + target_input_w = input_w + target_input_h = option1_h + else: + # Neither option increases dimensions, choose the one that maintains resolution better + if abs(input_h * input_w - input_w * option1_h) < abs( + input_h * input_w - option2_w * input_h + ): + # Option 1 is better: keep width fixed, adjust height + target_input_w = input_w + target_input_h = option1_h + else: + # Option 2 is better: keep height fixed, adjust width + target_input_h = input_h + target_input_w = option2_w + + # Resize input using nearest interpolation to maintain input values + if target_input_h != input_h or target_input_w != input_w: + resized_input = cv2.resize( + input_data, + (target_input_w, target_input_h), + interpolation=cv2.INTER_NEAREST, + ) + else: + resized_input = input_data + + return resized_input, target_input_h, target_input_w + + +def rescale_image_and_other_optional_info( + image, + output_resolution, + depthmap=None, + camera_intrinsics=None, + force=True, + additional_quantities_to_be_resized_with_nearest=None, +): + """ + Rescale the image and depthmap to the output resolution. + If the image is larger than the output resolution, it is rescaled with lanczos interpolation. + If force is false and the image is smaller than the output resolution, it is not rescaled. + If force is true and the image is smaller than the output resolution, it is rescaled with bicubic interpolation. + Depth and other quantities are rescaled with nearest interpolation. + + Args: + image (PIL.Image.Image or np.ndarray): The input image to be rescaled. + output_resolution (tuple): The desired output resolution as a tuple (width, height). + depthmap (np.ndarray, optional): The depth map associated with the image. Defaults to None. + camera_intrinsics (np.ndarray, optional): The camera intrinsics matrix. Defaults to None. + force (bool, optional): If True, force rescaling even if the image is smaller than the output resolution. Defaults to True. + additional_quantities_to_be_resized_with_nearest (list of np.ndarray, optional): Additional quantities to be rescaled using nearest interpolation. Defaults to None. 
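+
+    Example (illustrative): with a 640x480 input and output_resolution=(512, 384),
+    scale_final is 0.8, so the image is Lanczos-downsampled to 512x384, the
+    depthmap (if given) is resized with nearest interpolation, and the
+    intrinsics are scaled accordingly.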
+ + Returns: + tuple: A tuple containing: + - The rescaled image (PIL.Image.Image) + - The rescaled depthmap (numpy.ndarray or None) + - The updated camera intrinsics (numpy.ndarray or None) + - The list of rescaled additional quantities (list of numpy.ndarray or None) + """ + image = ImageList(image) + input_resolution = np.array(image.size) # (W, H) + output_resolution = np.array(output_resolution) + if depthmap is not None: + assert tuple(depthmap.shape[:2]) == image.size[::-1] + if additional_quantities_to_be_resized_with_nearest is not None: + assert all( + tuple(additional_quantity.shape[:2]) == image.size[::-1] + for additional_quantity in additional_quantities_to_be_resized_with_nearest + ) + + # Define output resolution + assert output_resolution.shape == (2,) + scale_final = max(output_resolution / image.size) + 1e-8 + if scale_final >= 1 and not force: # image is already smaller than what is asked + output = ( + image.to_pil(), + depthmap, + camera_intrinsics, + additional_quantities_to_be_resized_with_nearest, + ) + return output + output_resolution = np.floor(input_resolution * scale_final).astype(int) + + # First rescale the image so that it contains the crop + image = image.resize( + tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic + ) + if depthmap is not None: + depthmap = cv2.resize( + depthmap, + output_resolution, + fx=scale_final, + fy=scale_final, + interpolation=cv2.INTER_NEAREST, + ) + if additional_quantities_to_be_resized_with_nearest is not None: + resized_additional_quantities = [] + for quantity in additional_quantities_to_be_resized_with_nearest: + resized_additional_quantities.append( + cv2.resize( + quantity, + output_resolution, + fx=scale_final, + fy=scale_final, + interpolation=cv2.INTER_NEAREST, + ) + ) + additional_quantities_to_be_resized_with_nearest = resized_additional_quantities + + # No offset here; simple rescaling + if camera_intrinsics is not None: + camera_intrinsics = camera_matrix_of_crop( + camera_intrinsics, input_resolution, output_resolution, scaling=scale_final + ) + + # Return + return ( + image.to_pil(), + depthmap, + camera_intrinsics, + additional_quantities_to_be_resized_with_nearest, + ) + + +def camera_matrix_of_crop( + input_camera_matrix, + input_resolution, + output_resolution, + scaling=1, + offset_factor=0.5, + offset=None, +): + """ + Calculate the camera matrix for a cropped image. + + Args: + input_camera_matrix (numpy.ndarray): Original camera intrinsics matrix + input_resolution (tuple or numpy.ndarray): Original image resolution as (width, height) + output_resolution (tuple or numpy.ndarray): Target image resolution as (width, height) + scaling (float, optional): Scaling factor for the image. Defaults to 1. + offset_factor (float, optional): Factor to determine crop offset. Defaults to 0.5 (centered). + offset (tuple or numpy.ndarray, optional): Explicit offset to use. If None, calculated from offset_factor. 
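+
+    The update is done in the COLMAP convention: the first two rows of the
+    matrix are multiplied by `scaling` and the principal point is then shifted
+    by `-offset`; with the default offset_factor of 0.5 the crop stays centered.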
+ + Returns: + numpy.ndarray: Updated camera matrix for the cropped image + """ + # Margins to offset the origin + margins = np.asarray(input_resolution) * scaling - output_resolution + assert np.all(margins >= 0.0) + if offset is None: + offset = offset_factor * margins + + # Generate new camera parameters + output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) + output_camera_matrix_colmap[:2, :] *= scaling + output_camera_matrix_colmap[:2, 2] -= offset + output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) + + return output_camera_matrix + + +def crop_image_and_other_optional_info( + image, + crop_bbox, + depthmap=None, + camera_intrinsics=None, + additional_quantities=None, +): + """ + Return a crop of the input view and associated data. + + Args: + image (PIL.Image.Image or numpy.ndarray): The input image to be cropped + crop_bbox (tuple): Crop bounding box as (left, top, right, bottom) + depthmap (numpy.ndarray, optional): Depth map associated with the image + camera_intrinsics (numpy.ndarray, optional): Camera intrinsics matrix + additional_quantities (list of numpy.ndarray, optional): Additional data arrays to crop + + Returns: + tuple: A tuple containing: + - The cropped image + - The cropped depth map (if provided or None) + - Updated camera intrinsics (if provided or None) + - List of cropped additional quantities (if provided or None) + """ + image = ImageList(image) + left, top, right, bottom = crop_bbox + + image = image.crop((left, top, right, bottom)) + if depthmap is not None: + depthmap = depthmap[top:bottom, left:right] + if additional_quantities is not None: + additional_quantities = [ + quantity[top:bottom, left:right] for quantity in additional_quantities + ] + + if camera_intrinsics is not None: + camera_intrinsics = camera_intrinsics.copy() + camera_intrinsics[0, 2] -= left + camera_intrinsics[1, 2] -= top + + return (image.to_pil(), depthmap, camera_intrinsics, additional_quantities) + + +def bbox_from_intrinsics_in_out( + input_camera_matrix, output_camera_matrix, output_resolution +): + """ + Calculate the bounding box for cropping based on input and output camera intrinsics. + + Args: + input_camera_matrix (numpy.ndarray): Original camera intrinsics matrix + output_camera_matrix (numpy.ndarray): Target camera intrinsics matrix + output_resolution (tuple): Target resolution as (width, height) + + Returns: + tuple: Crop bounding box as (left, top, right, bottom) + """ + out_width, out_height = output_resolution + left, top = np.int32( + np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]) + ) + crop_bbox = (left, top, left + out_width, top + out_height) + return crop_bbox + + +def crop_resize_if_necessary( + image, + resolution, + depthmap=None, + intrinsics=None, + additional_quantities=None, +): + """ + First downsample image using LANCZOS and then crop if necessary to achieve target resolution. + + This function performs high-quality downsampling followed by cropping to achieve the + desired output resolution while maintaining proper camera intrinsics. 
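+
+    Example (illustrative): a 1280x720 input with resolution=(512, 288) is
+    Lanczos-rescaled straight to 512x288 (same 16:9 aspect ratio), so the
+    centered crop is a no-op; for a 4:3 target such as (512, 384), the image
+    is first rescaled to 682x384 and then center-cropped horizontally to 512.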
+ + Args: + image (PIL.Image.Image or numpy.ndarray): The input image to be processed + resolution (tuple): Target resolution as (width, height) + depthmap (numpy.ndarray, optional): Depth map associated with the image + intrinsics (numpy.ndarray, optional): Camera intrinsics matrix + additional_quantities (list of numpy.ndarray, optional): Additional data arrays to process + + Returns: + tuple: A tuple containing the processed image and any provided additional data + (depthmap, intrinsics, additional_quantities) that have been similarly processed + """ + # Convert image to PIL.Image.Image if necessary + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # Get width and height of image + original_width, original_height = image.size + + # High-quality Lanczos down-scaling + target_rescale_resolution = np.array(resolution) + image, depthmap, intrinsics, additional_quantities = ( + rescale_image_and_other_optional_info( + image=image, + output_resolution=target_rescale_resolution, + depthmap=depthmap, + camera_intrinsics=intrinsics, + additional_quantities_to_be_resized_with_nearest=additional_quantities, + ) + ) + + # Actual cropping (if necessary) + if intrinsics is not None: + new_intrinsics = camera_matrix_of_crop( + input_camera_matrix=intrinsics, + input_resolution=image.size, + output_resolution=resolution, + offset_factor=0.5, + ) + crop_bbox = bbox_from_intrinsics_in_out( + input_camera_matrix=intrinsics, + output_camera_matrix=new_intrinsics, + output_resolution=resolution, + ) + else: + # Create a centered crop if no intrinsics are available + w, h = image.size + target_w, target_h = resolution + left = (w - target_w) // 2 + top = (h - target_h) // 2 + crop_bbox = (left, top, left + target_w, top + target_h) + + image, depthmap, new_intrinsics, additional_quantities = ( + crop_image_and_other_optional_info( + image=image, + crop_bbox=crop_bbox, + depthmap=depthmap, + camera_intrinsics=intrinsics, + additional_quantities=additional_quantities, + ) + ) + + # Return the output + output = (image,) + if depthmap is not None: + output += (depthmap,) + if new_intrinsics is not None: + output += (new_intrinsics,) + if additional_quantities is not None: + output += (additional_quantities,) + return output diff --git a/mapanything/utils/device.py b/mapanything/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..d28180666c386af9b675fd8dc29be28fe42d9f9f --- /dev/null +++ b/mapanything/utils/device.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Utility functions for managing computation device +""" + +import numpy as np +import torch + + +def to_device(batch, device, callback=None, non_blocking=False): + """ + Transfer data to another device (i.e. GPU, CPU:torch, CPU:numpy). + + This function recursively processes nested data structures (lists, tuples, dicts) + and transfers each tensor to the specified device. 
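+
+    Example (illustrative):
+
+        >>> batch = {"img": torch.ones(2, 3), "ids": [torch.zeros(1)]}
+        >>> out = to_device(batch, "numpy")
+        >>> type(out["img"]), type(out["ids"][0])
+        (<class 'numpy.ndarray'>, <class 'numpy.ndarray'>)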
+
+    Args:
+        batch: Data to transfer (list, tuple, dict of tensors or other objects)
+        device: Target device - pytorch device (e.g., 'cuda', 'cpu') or 'numpy'
+        callback: Optional function applied to the top-level batch before the transfer
+        non_blocking: If True, allows asynchronous copy to GPU (may be faster)
+
+    Returns:
+        Data with the same structure as input but with tensors transferred to target device
+    """
+    if callback:
+        batch = callback(batch)
+
+    if isinstance(batch, dict):
+        return {
+            k: to_device(v, device, non_blocking=non_blocking) for k, v in batch.items()
+        }
+
+    if isinstance(batch, (tuple, list)):
+        return type(batch)(
+            to_device(x, device, non_blocking=non_blocking) for x in batch
+        )
+
+    x = batch
+    if device == "numpy":
+        if isinstance(x, torch.Tensor):
+            x = x.detach().cpu().numpy()
+    elif x is not None:
+        if isinstance(x, np.ndarray):
+            x = torch.from_numpy(x)
+        if torch.is_tensor(x):
+            x = x.to(device, non_blocking=non_blocking)
+    return x
+
+
+def to_numpy(x):
+    """Convert data to numpy arrays.
+
+    Args:
+        x: Input data (can be tensor, array, or nested structure)
+
+    Returns:
+        Data with the same structure but with tensors converted to numpy arrays
+    """
+    return to_device(x, "numpy")
+
+
+def to_cpu(x):
+    """Transfer data to CPU.
+
+    Args:
+        x: Input data (can be tensor, array, or nested structure)
+
+    Returns:
+        Data with the same structure but with tensors moved to CPU
+    """
+    return to_device(x, "cpu")
+
+
+def to_cuda(x):
+    """Transfer data to CUDA device (GPU).
+
+    Args:
+        x: Input data (can be tensor, array, or nested structure)
+
+    Returns:
+        Data with the same structure but with tensors moved to GPU
+    """
+    return to_device(x, "cuda")
diff --git a/mapanything/utils/geometry.py b/mapanything/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..9452a13e7a9f97ef59cd4ae48f8d271bbac38792
--- /dev/null
+++ b/mapanything/utils/geometry.py
@@ -0,0 +1,2188 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+"""
+Utilities for geometry operations.
+
+References: DUSt3R, MoGe
+"""
+
+from numbers import Number
+from typing import Tuple, Union
+
+import einops as ein
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from mapanything.utils.misc import invalid_to_zeros
+from mapanything.utils.warnings import no_warnings
+
+
+def depthmap_to_camera_frame(depthmap, intrinsics):
+    """
+    Convert depth image to a pointcloud in camera frame.
+
+    Args:
+    - depthmap: HxW or BxHxW torch tensor
+    - intrinsics: 3x3 or Bx3x3 torch tensor
+
+    Returns:
+        pointmap in camera frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels.
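+
+    Each pixel (u, v) with z-depth z is back-projected with the pinhole model:
+    x = (u - cx) * z / fx, y = (v - cy) * z / fy, z unchanged.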
+ """ + # Add batch dimension if not present + if depthmap.dim() == 2: + depthmap = depthmap.unsqueeze(0) + intrinsics = intrinsics.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + batch_size, height, width = depthmap.shape + device = depthmap.device + + # Compute 3D point in camera frame associated with each pixel + x_grid, y_grid = torch.meshgrid( + torch.arange(width, device=device).float(), + torch.arange(height, device=device).float(), + indexing="xy", + ) + x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1) + y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1) + + fx = intrinsics[:, 0, 0].view(-1, 1, 1) + fy = intrinsics[:, 1, 1].view(-1, 1, 1) + cx = intrinsics[:, 0, 2].view(-1, 1, 1) + cy = intrinsics[:, 1, 2].view(-1, 1, 1) + + depth_z = depthmap + xx = (x_grid - cx) * depth_z / fx + yy = (y_grid - cy) * depth_z / fy + pts3d_cam = torch.stack((xx, yy, depth_z), dim=-1) + + # Compute mask of valid non-zero depth pixels + valid_mask = depthmap > 0.0 + + # Remove batch dimension if it was added + if squeeze_batch_dim: + pts3d_cam = pts3d_cam.squeeze(0) + valid_mask = valid_mask.squeeze(0) + + return pts3d_cam, valid_mask + + +def depthmap_to_world_frame(depthmap, intrinsics, camera_pose=None): + """ + Convert depth image to a pointcloud in world frame. + + Args: + - depthmap: HxW or BxHxW torch tensor + - intrinsics: 3x3 or Bx3x3 torch tensor + - camera_pose: 4x4 or Bx4x4 torch tensor + + Returns: + pointmap in world frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels. + """ + pts3d_cam, valid_mask = depthmap_to_camera_frame(depthmap, intrinsics) + + if camera_pose is not None: + # Add batch dimension if not present + if camera_pose.dim() == 2: + camera_pose = camera_pose.unsqueeze(0) + pts3d_cam = pts3d_cam.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Convert points from camera frame to world frame + pts3d_cam_homo = torch.cat( + [pts3d_cam, torch.ones_like(pts3d_cam[..., :1])], dim=-1 + ) + pts3d_world = ein.einsum( + camera_pose, pts3d_cam_homo, "b i k, b h w k -> b h w i" + ) + pts3d_world = pts3d_world[..., :3] + + # Remove batch dimension if it was added + if squeeze_batch_dim: + pts3d_world = pts3d_world.squeeze(0) + else: + pts3d_world = pts3d_cam + + return pts3d_world, valid_mask + + +def transform_pts3d(pts3d, transformation): + """ + Transform 3D points using a 4x4 transformation matrix. + + Args: + - pts3d: HxWx3 or BxHxWx3 torch tensor + - transformation: 4x4 or Bx4x4 torch tensor + + Returns: + transformed points (HxWx3 or BxHxWx3 tensor) + """ + # Add batch dimension if not present + if pts3d.dim() == 3: + pts3d = pts3d.unsqueeze(0) + transformation = transformation.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Convert points to homogeneous coordinates + pts3d_homo = torch.cat([pts3d, torch.ones_like(pts3d[..., :1])], dim=-1) + + # Transform points + transformed_pts3d = ein.einsum( + transformation, pts3d_homo, "b i k, b h w k -> b h w i" + ) + transformed_pts3d = transformed_pts3d[..., :3] + + # Remove batch dimension if it was added + if squeeze_batch_dim: + transformed_pts3d = transformed_pts3d.squeeze(0) + + return transformed_pts3d + + +def project_pts3d_to_image(pts3d, intrinsics, return_z_dim): + """ + Project 3D points to image plane (assumes pinhole camera model with no distortion). 
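+    Pixel coordinates follow u = fx * X / Z + cx and v = fy * Y / Z + cy,
+    with Z clamped to a small positive epsilon to avoid division by zero.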
+ + Args: + - pts3d: HxWx3 or BxHxWx3 torch tensor + - intrinsics: 3x3 or Bx3x3 torch tensor + - return_z_dim: bool, whether to return the third dimension of the projected points + + Returns: + projected points (HxWx2) + """ + if pts3d.dim() == 3: + pts3d = pts3d.unsqueeze(0) + intrinsics = intrinsics.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Project points to image plane + projected_pts2d = ein.einsum(intrinsics, pts3d, "b i k, b h w k -> b h w i") + projected_pts2d[..., :2] /= projected_pts2d[..., 2].unsqueeze(-1).clamp(min=1e-6) + + # Remove the z dimension if not required + if not return_z_dim: + projected_pts2d = projected_pts2d[..., :2] + + # Remove batch dimension if it was added + if squeeze_batch_dim: + projected_pts2d = projected_pts2d.squeeze(0) + + return projected_pts2d + + +def get_rays_in_camera_frame(intrinsics, height, width, normalize_to_unit_sphere): + """ + Convert camera intrinsics to a raymap (ray origins + directions) in camera frame. + Note: Currently only supports pinhole camera model. + + Args: + - intrinsics: 3x3 or Bx3x3 torch tensor + - height: int + - width: int + - normalize_to_unit_sphere: bool + + Returns: + - ray_origins: (HxWx3 or BxHxWx3) tensor + - ray_directions: (HxWx3 or BxHxWx3) tensor + """ + # Add batch dimension if not present + if intrinsics.dim() == 2: + intrinsics = intrinsics.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + batch_size = intrinsics.shape[0] + device = intrinsics.device + + # Compute rays in camera frame associated with each pixel + x_grid, y_grid = torch.meshgrid( + torch.arange(width, device=device).float(), + torch.arange(height, device=device).float(), + indexing="xy", + ) + x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1) + y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1) + + fx = intrinsics[:, 0, 0].view(-1, 1, 1) + fy = intrinsics[:, 1, 1].view(-1, 1, 1) + cx = intrinsics[:, 0, 2].view(-1, 1, 1) + cy = intrinsics[:, 1, 2].view(-1, 1, 1) + + ray_origins = torch.zeros((batch_size, height, width, 3), device=device) + xx = (x_grid - cx) / fx + yy = (y_grid - cy) / fy + ray_directions = torch.stack((xx, yy, torch.ones_like(xx)), dim=-1) + + # Normalize ray directions to unit sphere if required (else rays will lie on unit plane) + if normalize_to_unit_sphere: + ray_directions = ray_directions / torch.norm( + ray_directions, dim=-1, keepdim=True + ) + + # Remove batch dimension if it was added + if squeeze_batch_dim: + ray_origins = ray_origins.squeeze(0) + ray_directions = ray_directions.squeeze(0) + + return ray_origins, ray_directions + + +def get_rays_in_world_frame( + intrinsics, height, width, normalize_to_unit_sphere, camera_pose=None +): + """ + Convert camera intrinsics & camera_pose (if provided) to a raymap (ray origins + directions) in camera or world frame (if camera_pose is provided). + Note: Currently only supports pinhole camera model. 
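+    Ray origins are transformed as homogeneous points (fourth coordinate 1)
+    and ray directions as vectors (fourth coordinate 0), so only the origins
+    pick up the camera translation.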
+ + Args: + - intrinsics: 3x3 or Bx3x3 torch tensor + - height: int + - width: int + - normalize_to_unit_sphere: bool + - camera_pose: 4x4 or Bx4x4 torch tensor + + Returns: + - ray_origins: (HxWx3 or BxHxWx3) tensor + - ray_directions: (HxWx3 or BxHxWx3) tensor + """ + # Get rays in camera frame + ray_origins, ray_directions = get_rays_in_camera_frame( + intrinsics, height, width, normalize_to_unit_sphere + ) + + if camera_pose is not None: + # Add batch dimension if not present + if camera_pose.dim() == 2: + camera_pose = camera_pose.unsqueeze(0) + ray_origins = ray_origins.unsqueeze(0) + ray_directions = ray_directions.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Convert rays from camera frame to world frame + ray_origins_homo = torch.cat( + [ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1 + ) + ray_directions_homo = torch.cat( + [ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1 + ) + ray_origins_world = ein.einsum( + camera_pose, ray_origins_homo, "b i k, b h w k -> b h w i" + ) + ray_directions_world = ein.einsum( + camera_pose, ray_directions_homo, "b i k, b h w k -> b h w i" + ) + ray_origins_world = ray_origins_world[..., :3] + ray_directions_world = ray_directions_world[..., :3] + + # Remove batch dimension if it was added + if squeeze_batch_dim: + ray_origins_world = ray_origins_world.squeeze(0) + ray_directions_world = ray_directions_world.squeeze(0) + else: + ray_origins_world = ray_origins + ray_directions_world = ray_directions + + return ray_origins_world, ray_directions_world + + +def recover_pinhole_intrinsics_from_ray_directions( + ray_directions, use_geometric_calculation=False +): + """ + Recover pinhole camera intrinsics from ray directions, supporting both batched and non-batched inputs. 
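+
+    For images larger than ~1 MP (or when use_geometric_calculation is True),
+    the intrinsics are recovered directly from rays sampled at the image
+    center and quarter points; otherwise a least-squares fit of
+    u = cx + fx * (dx / dz) (and likewise for v) is solved over a subsampled
+    pixel grid.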
+
+    Args:
+        ray_directions: Tensor of shape [H, W, 3] or [B, H, W, 3] containing unit normalized ray directions
+        use_geometric_calculation: If True, use the direct geometric solution even at standard resolutions
+
+    Returns:
+        Pinhole intrinsics as a tensor of shape [3, 3] or [B, 3, 3], with fx, fy, cx, cy filled in
+    """
+    # Add batch dimension if not present
+    if ray_directions.dim() == 3:  # [H, W, 3]
+        ray_directions = ray_directions.unsqueeze(0)  # [1, H, W, 3]
+        squeeze_batch_dim = True
+    else:
+        squeeze_batch_dim = False
+
+    batch_size, height, width, _ = ray_directions.shape
+    device = ray_directions.device
+
+    # Create pixel coordinate grid
+    x_grid, y_grid = torch.meshgrid(
+        torch.arange(width, device=device).float(),
+        torch.arange(height, device=device).float(),
+        indexing="xy",
+    )
+
+    # Expand grid for all batches
+    x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)  # [B, H, W]
+    y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)  # [B, H, W]
+
+    # Determine if high resolution or not
+    is_high_res = height * width > 1000000
+
+    if is_high_res or use_geometric_calculation:
+        # For high-resolution cases, use direct geometric calculation
+        # Define key points
+        center_h, center_w = height // 2, width // 2
+        quarter_w, three_quarter_w = width // 4, 3 * width // 4
+        quarter_h, three_quarter_h = height // 4, 3 * height // 4
+
+        # Get rays at key points
+        center_rays = ray_directions[:, center_h, center_w, :].clone()  # [B, 3]
+        left_rays = ray_directions[:, center_h, quarter_w, :].clone()  # [B, 3]
+        right_rays = ray_directions[:, center_h, three_quarter_w, :].clone()  # [B, 3]
+        top_rays = ray_directions[:, quarter_h, center_w, :].clone()  # [B, 3]
+        bottom_rays = ray_directions[:, three_quarter_h, center_w, :].clone()  # [B, 3]
+
+        # Normalize rays to have dz = 1
+        center_rays = center_rays / center_rays[:, 2].unsqueeze(1)  # [B, 3]
+        left_rays = left_rays / left_rays[:, 2].unsqueeze(1)  # [B, 3]
+        right_rays = right_rays / right_rays[:, 2].unsqueeze(1)  # [B, 3]
+        top_rays = top_rays / top_rays[:, 2].unsqueeze(1)  # [B, 3]
+        bottom_rays = bottom_rays / bottom_rays[:, 2].unsqueeze(1)  # [B, 3]
+
+        # Calculate fx directly (vectorized across batch)
+        fx_left = (quarter_w - center_w) / (left_rays[:, 0] - center_rays[:, 0])
+        fx_right = (three_quarter_w - center_w) / (right_rays[:, 0] - center_rays[:, 0])
+        fx = (fx_left + fx_right) / 2  # Average for robustness
+
+        # Calculate cx
+        cx = center_w - fx * center_rays[:, 0]
+
+        # Calculate fy and cy
+        fy_top = (quarter_h - center_h) / (top_rays[:, 1] - center_rays[:, 1])
+        fy_bottom = (three_quarter_h - center_h) / (
+            bottom_rays[:, 1] - center_rays[:, 1]
+        )
+        fy = (fy_top + fy_bottom) / 2
+
+        cy = center_h - fy * center_rays[:, 1]
+    else:
+        # For standard resolution, use regression with sampling for efficiency
+        # Sample a grid of points (but more dense than for high-res)
+        step_h = max(1, height // 50)
+        step_w = max(1, width // 50)
+
+        h_indices = torch.arange(0, height, step_h, device=device)
+        w_indices = torch.arange(0, width, step_w, device=device)
+
+        # Extract subset of coordinates
+        x_sampled = x_grid[:, h_indices[:, None], w_indices[None, :]]  # [B, H', W']
+        y_sampled = y_grid[:, h_indices[:, None], w_indices[None, :]]  # [B, H', W']
+        rays_sampled = ray_directions[
+            :, h_indices[:, None], w_indices[None, :], :
+        ]  # [B, H', W', 3]
+
+        # Reshape for linear regression
+        x_flat = x_sampled.reshape(batch_size, -1)  # [B, N]
+        y_flat = y_sampled.reshape(batch_size, -1)  # [B, N]
+
+        # Extract ray direction components
+        dx = rays_sampled[..., 0].reshape(batch_size, -1)  # [B, N]
+        dy = rays_sampled[..., 1].reshape(batch_size, -1)  # [B, N]
+        dz = rays_sampled[..., 2].reshape(batch_size, -1)  # [B, N]
+
+        # Compute ratios for linear regression
+        ratio_x = dx / dz  # [B, N]
+        ratio_y = dy / dz  # [B, N]
+
+        # Solve each per-batch least-squares fit via the normal equations
+        # (A^T A x = A^T b), which batch cleanly with torch.bmm
+        # For x-direction: x = cx + fx * (dx/dz)
+        # Create design matrices
+        ones = torch.ones_like(x_flat)  # [B, N]
+        A_x = torch.stack([ones, ratio_x], dim=2)  # [B, N, 2]
+        b_x = x_flat.unsqueeze(2)  # [B, N, 1]
+
+        # Compute A^T A and A^T b for each batch
+        ATA_x = torch.bmm(A_x.transpose(1, 2), A_x)  # [B, 2, 2]
+        ATb_x = torch.bmm(A_x.transpose(1, 2), b_x)  # [B, 2, 1]
+
+        # Solve the system for each batch
+        solution_x = torch.linalg.solve(ATA_x, ATb_x).squeeze(2)  # [B, 2]
+        cx, fx = solution_x[:, 0], solution_x[:, 1]
+
+        # Repeat for y-direction
+        A_y = torch.stack([ones, ratio_y], dim=2)  # [B, N, 2]
+        b_y = y_flat.unsqueeze(2)  # [B, N, 1]
+
+        ATA_y = torch.bmm(A_y.transpose(1, 2), A_y)  # [B, 2, 2]
+        ATb_y = torch.bmm(A_y.transpose(1, 2), b_y)  # [B, 2, 1]
+
+        solution_y = torch.linalg.solve(ATA_y, ATb_y).squeeze(2)  # [B, 2]
+        cy, fy = solution_y[:, 0], solution_y[:, 1]
+
+    # Create intrinsics matrices
+    batch_size = fx.shape[0]
+    intrinsics = torch.zeros(batch_size, 3, 3, device=ray_directions.device)
+
+    # Fill in the intrinsics matrices
+    intrinsics[:, 0, 0] = fx  # focal length x
+    intrinsics[:, 1, 1] = fy  # focal length y
+    intrinsics[:, 0, 2] = cx  # principal point x
+    intrinsics[:, 1, 2] = cy  # principal point y
+    intrinsics[:, 2, 2] = 1.0  # bottom-right element is always 1
+
+    # Remove batch dimension if it was added
+    if squeeze_batch_dim:
+        intrinsics = intrinsics.squeeze(0)
+
+    return intrinsics
+
+
+def transform_rays(ray_origins, ray_directions, transformation):
+    """
+    Transform 6D rays (ray origins and ray directions) using a 4x4 transformation matrix.
+
+    Args:
+    - ray_origins: HxWx3 or BxHxWx3 torch tensor
+    - ray_directions: HxWx3 or BxHxWx3 torch tensor
+    - transformation: 4x4 or Bx4x4 torch tensor
+
+    Returns:
+        transformed ray_origins (HxWx3 or BxHxWx3 tensor) and ray_directions (HxWx3 or BxHxWx3 tensor)
+    """
+    # Add batch dimension if not present
+    if ray_origins.dim() == 3:
+        ray_origins = ray_origins.unsqueeze(0)
+        ray_directions = ray_directions.unsqueeze(0)
+        transformation = transformation.unsqueeze(0)
+        squeeze_batch_dim = True
+    else:
+        squeeze_batch_dim = False
+
+    # Transform ray origins and directions
+    ray_origins_homo = torch.cat(
+        [ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1
+    )
+    ray_directions_homo = torch.cat(
+        [ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1
+    )
+    transformed_ray_origins = ein.einsum(
+        transformation, ray_origins_homo, "b i k, b h w k -> b h w i"
+    )
+    transformed_ray_directions = ein.einsum(
+        transformation, ray_directions_homo, "b i k, b h w k -> b h w i"
+    )
+    transformed_ray_origins = transformed_ray_origins[..., :3]
+    transformed_ray_directions = transformed_ray_directions[..., :3]
+
+    # Remove batch dimension if it was added
+    if squeeze_batch_dim:
+        transformed_ray_origins = transformed_ray_origins.squeeze(0)
+        transformed_ray_directions = transformed_ray_directions.squeeze(0)
+
+    return transformed_ray_origins, transformed_ray_directions
+
+
+def convert_z_depth_to_depth_along_ray(z_depth, intrinsics):
+    """
+    Convert z-depth image to depth along camera rays.
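+
+    For a pinhole camera with non-negative z-depth, this is equivalent to
+    depth_along_ray = z * sqrt(((u - cx) / fx)**2 + ((v - cy) / fy)**2 + 1).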
+ + Args: + - z_depth: HxW or BxHxW torch tensor + - intrinsics: 3x3 or Bx3x3 torch tensor + + Returns: + - depth_along_ray: HxW or BxHxW torch tensor + """ + # Add batch dimension if not present + if z_depth.dim() == 2: + z_depth = z_depth.unsqueeze(0) + intrinsics = intrinsics.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Get rays in camera frame + batch_size, height, width = z_depth.shape + _, ray_directions = get_rays_in_camera_frame( + intrinsics, height, width, normalize_to_unit_sphere=False + ) + + # Compute depth along ray + pts3d_cam = z_depth[..., None] * ray_directions + depth_along_ray = torch.norm(pts3d_cam, dim=-1) + + # Remove batch dimension if it was added + if squeeze_batch_dim: + depth_along_ray = depth_along_ray.squeeze(0) + + return depth_along_ray + + +def convert_raymap_z_depth_quats_to_pointmap(ray_origins, ray_directions, depth, quats): + """ + Convert raymap (ray origins + directions on unit plane), z-depth and + unit quaternions (representing rotation) to a pointmap in world frame. + + Args: + - ray_origins: (HxWx3 or BxHxWx3) torch tensor + - ray_directions: (HxWx3 or BxHxWx3) torch tensor + - depth: (HxWx1 or BxHxWx1) torch tensor + - quats: (HxWx4 or BxHxWx4) torch tensor (unit quaternions and notation is (x, y, z, w)) + + Returns: + - pointmap: (HxWx3 or BxHxWx3) torch tensor + """ + # Add batch dimension if not present + if ray_origins.dim() == 3: + ray_origins = ray_origins.unsqueeze(0) + ray_directions = ray_directions.unsqueeze(0) + depth = depth.unsqueeze(0) + quats = quats.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + batch_size, height, width, _ = depth.shape + device = depth.device + + # Normalize the quaternions to ensure they are unit quaternions + quats = quats / torch.norm(quats, dim=-1, keepdim=True) + + # Convert quaternions to pixel-wise rotation matrices + qx, qy, qz, qw = quats[..., 0], quats[..., 1], quats[..., 2], quats[..., 3] + rot_mat = ( + torch.stack( + [ + qw**2 + qx**2 - qy**2 - qz**2, + 2 * (qx * qy - qw * qz), + 2 * (qw * qy + qx * qz), + 2 * (qw * qz + qx * qy), + qw**2 - qx**2 + qy**2 - qz**2, + 2 * (qy * qz - qw * qx), + 2 * (qx * qz - qw * qy), + 2 * (qw * qx + qy * qz), + qw**2 - qx**2 - qy**2 + qz**2, + ], + dim=-1, + ) + .reshape(batch_size, height, width, 3, 3) + .to(device) + ) + + # Compute 3D points in local camera frame + pts3d_local = depth * ray_directions + + # Rotate the local points using the quaternions + rotated_pts3d_local = ein.einsum( + rot_mat, pts3d_local, "b h w i k, b h w k -> b h w i" + ) + + # Compute 3D point in world frame associated with each pixel + pts3d = ray_origins + rotated_pts3d_local + + # Remove batch dimension if it was added + if squeeze_batch_dim: + pts3d = pts3d.squeeze(0) + + return pts3d + + +def quaternion_to_rotation_matrix(quat): + """ + Convert a quaternion into a 3x3 rotation matrix. 
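+    Uses the scalar-last (x, y, z, w) convention; the input is normalized
+    first, so non-unit quaternions are accepted.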
+ + Args: + - quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + + Returns: + - rot_matrix: 3x3 or Bx3x3 torch tensor + """ + if quat.dim() == 1: + quat = quat.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Ensure the quaternion is normalized + quat = quat / quat.norm(dim=1, keepdim=True) + x, y, z, w = quat.unbind(dim=1) + + # Compute the rotation matrix elements + xx = x * x + yy = y * y + zz = z * z + xy = x * y + xz = x * z + yz = y * z + wx = w * x + wy = w * y + wz = w * z + + # Construct the rotation matrix + rot_matrix = torch.stack( + [ + 1 - 2 * (yy + zz), + 2 * (xy - wz), + 2 * (xz + wy), + 2 * (xy + wz), + 1 - 2 * (xx + zz), + 2 * (yz - wx), + 2 * (xz - wy), + 2 * (yz + wx), + 1 - 2 * (xx + yy), + ], + dim=1, + ).view(-1, 3, 3) + + # Squeeze batch dimension if it was unsqueezed + if squeeze_batch_dim: + rot_matrix = rot_matrix.squeeze(0) + + return rot_matrix + + +def rotation_matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). 
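+
+    Since q and -q encode the same rotation, this resolves the sign ambiguity
+    by negating any quaternion whose scalar (w) component is negative.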
+ """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) + + +def quaternion_inverse(quat): + """ + Compute the inverse of a quaternion. + + Args: + - quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + + Returns: + - inv_quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + """ + # Unsqueeze batch dimension if not present + if quat.dim() == 1: + quat = quat.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Compute the inverse + quat_conj = quat.clone() + quat_conj[:, :3] = -quat_conj[:, :3] + quat_norm = torch.sum(quat * quat, dim=1, keepdim=True) + inv_quat = quat_conj / quat_norm + + # Squeeze batch dimension if it was unsqueezed + if squeeze_batch_dim: + inv_quat = inv_quat.squeeze(0) + + return inv_quat + + +def quaternion_multiply(q1, q2): + """ + Multiply two quaternions. + + Args: + - q1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + - q2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + + Returns: + - qm: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + """ + # Unsqueeze batch dimension if not present + if q1.dim() == 1: + q1 = q1.unsqueeze(0) + q2 = q2.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Unbind the quaternions + x1, y1, z1, w1 = q1.unbind(dim=1) + x2, y2, z2, w2 = q2.unbind(dim=1) + + # Compute the product + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 + z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + + # Stack the components + qm = torch.stack([x, y, z, w], dim=1) + + # Squeeze batch dimension if it was unsqueezed + if squeeze_batch_dim: + qm = qm.squeeze(0) + + return qm + + +def transform_pose_using_quats_and_trans_2_to_1(quats1, trans1, quats2, trans2): + """ + Transform quats and translation of pose2 from absolute frame (pose2 to world) to relative frame (pose2 to pose1). + + Args: + - quats1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + - trans1: 3 or Bx3 torch tensor + - quats2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + - trans2: 3 or Bx3 torch tensor + + Returns: + - quats: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w)) + - trans: 3 or Bx3 torch tensor + """ + # Unsqueeze batch dimension if not present + if quats1.dim() == 1: + quats1 = quats1.unsqueeze(0) + trans1 = trans1.unsqueeze(0) + quats2 = quats2.unsqueeze(0) + trans2 = trans2.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + # Compute the inverse of view1's pose + inv_quats1 = quaternion_inverse(quats1) + R1_inv = quaternion_to_rotation_matrix(inv_quats1) + t1_inv = -1 * ein.einsum(R1_inv, trans1, "b i j, b j -> b i") + + # Transform view2's pose to view1's frame + quats = quaternion_multiply(inv_quats1, quats2) + trans = ein.einsum(R1_inv, trans2, "b i j, b j -> b i") + t1_inv + + # Squeeze batch dimension if it was unsqueezed + if squeeze_batch_dim: + quats = quats.squeeze(0) + trans = trans.squeeze(0) + + return quats, trans + + +def convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( + ray_directions, depth_along_ray, pose_trans, pose_quats +): + """ + Convert ray directions, depth along ray, pose translation, and + unit quaternions (representing pose rotation) to a pointmap in world frame. 
+ + Args: + - ray_directions: (HxWx3 or BxHxWx3) torch tensor + - depth_along_ray: (HxWx1 or BxHxWx1) torch tensor + - pose_trans: (3 or Bx3) torch tensor + - pose_quats: (4 or Bx4) torch tensor (unit quaternions and notation is (x, y, z, w)) + + Returns: + - pointmap: (HxWx3 or BxHxWx3) torch tensor + """ + # Add batch dimension if not present + if ray_directions.dim() == 3: + ray_directions = ray_directions.unsqueeze(0) + depth_along_ray = depth_along_ray.unsqueeze(0) + pose_trans = pose_trans.unsqueeze(0) + pose_quats = pose_quats.unsqueeze(0) + squeeze_batch_dim = True + else: + squeeze_batch_dim = False + + batch_size, height, width, _ = depth_along_ray.shape + device = depth_along_ray.device + + # Normalize the quaternions to ensure they are unit quaternions + pose_quats = pose_quats / torch.norm(pose_quats, dim=-1, keepdim=True) + + # Convert quaternions to rotation matrices (B x 3 x 3) + rot_mat = quaternion_to_rotation_matrix(pose_quats) + + # Get pose matrix (B x 4 x 4) + pose_mat = torch.eye(4, device=device).unsqueeze(0).repeat(batch_size, 1, 1) + pose_mat[:, :3, :3] = rot_mat + pose_mat[:, :3, 3] = pose_trans + + # Compute 3D points in local camera frame + pts3d_local = depth_along_ray * ray_directions + + # Compute 3D points in world frame + pts3d_homo = torch.cat([pts3d_local, torch.ones_like(pts3d_local[..., :1])], dim=-1) + pts3d_world = ein.einsum(pose_mat, pts3d_homo, "b i k, b h w k -> b h w i") + pts3d_world = pts3d_world[..., :3] + + # Remove batch dimension if it was added + if squeeze_batch_dim: + pts3d_world = pts3d_world.squeeze(0) + + return pts3d_world + + +def xy_grid( + W, + H, + device=None, + origin=(0, 0), + unsqueeze=None, + cat_dim=-1, + homogeneous=False, + **arange_kw, +): + """ + Generate a coordinate grid of shape (H,W,2) or (H,W,3) if homogeneous=True. + + Args: + W (int): Width of the grid + H (int): Height of the grid + device (torch.device, optional): Device to place the grid on. If None, uses numpy arrays + origin (tuple, optional): Origin coordinates (x,y) for the grid. Default is (0,0) + unsqueeze (int, optional): Dimension to unsqueeze in the output tensors + cat_dim (int, optional): Dimension to concatenate the x,y coordinates. If None, returns tuple + homogeneous (bool, optional): If True, adds a third dimension of ones to make homogeneous coordinates + **arange_kw: Additional keyword arguments passed to np.arange or torch.arange + + Returns: + numpy.ndarray or torch.Tensor: Coordinate grid where: + - output[j,i,0] = i + origin[0] (x-coordinate) + - output[j,i,1] = j + origin[1] (y-coordinate) + - output[j,i,2] = 1 (if homogeneous=True) + """ + if device is None: + # numpy + arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones + else: + # torch + def arange(*a, **kw): + return torch.arange(*a, device=device, **kw) + + meshgrid, stack = torch.meshgrid, torch.stack + + def ones(*a): + return torch.ones(*a, device=device) + + tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)] + grid = meshgrid(tw, th, indexing="xy") + if homogeneous: + grid = grid + (ones((H, W)),) + if unsqueeze is not None: + grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze)) + if cat_dim is not None: + grid = stack(grid, cat_dim) + + return grid + + +def geotrf(Trf, pts, ncol=None, norm=False): + """ + Apply a geometric transformation to a set of 3-D points. 
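+
+    A minimal illustrative check: the identity transform leaves points unchanged.
+
+        >>> pts = torch.rand(5, 3)
+        >>> torch.allclose(geotrf(torch.eye(4), pts), pts)
+        True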
+
+    Args:
+        Trf: 3x3 or 4x4 projection matrix (typically a Homography) or batch of matrices
+            with shape (B, 3, 3) or (B, 4, 4)
+        pts: numpy/torch/tuple of coordinates with shape (..., 2) or (..., 3)
+        ncol: int, number of columns of the result (2 or 3)
+        norm: float, if not 0, the result is projected on the z=norm plane
+            (homogeneous normalization)
+
+    Returns:
+        Array or tensor of projected points with the same type as input and shape (..., ncol)
+    """
+    assert Trf.ndim >= 2
+    if isinstance(Trf, np.ndarray):
+        pts = np.asarray(pts)
+    elif isinstance(Trf, torch.Tensor):
+        pts = torch.as_tensor(pts, dtype=Trf.dtype)
+
+    # Adapt shape if necessary
+    output_reshape = pts.shape[:-1]
+    ncol = ncol or pts.shape[-1]
+
+    # Optimized code
+    if (
+        isinstance(Trf, torch.Tensor)
+        and isinstance(pts, torch.Tensor)
+        and Trf.ndim == 3
+        and pts.ndim == 4
+    ):
+        d = pts.shape[3]
+        if Trf.shape[-1] == d:
+            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
+        elif Trf.shape[-1] == d + 1:
+            pts = (
+                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
+                + Trf[:, None, None, :d, d]
+            )
+        else:
+            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
+    else:
+        if Trf.ndim >= 3:
+            n = Trf.ndim - 2
+            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
+            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
+
+            if pts.ndim > Trf.ndim:
+                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
+                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
+            elif pts.ndim == 2:
+                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
+                pts = pts[:, None, :]
+
+        if pts.shape[-1] + 1 == Trf.shape[-1]:
+            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
+            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
+        elif pts.shape[-1] == Trf.shape[-1]:
+            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
+            pts = pts @ Trf
+        else:
+            pts = Trf @ pts.T
+            if pts.ndim >= 2:
+                pts = pts.swapaxes(-1, -2)
+
+    if norm:
+        pts = pts / pts[..., -1:]  # DON'T use /=, it will lead to a bug
+        if norm != 1:
+            pts *= norm
+
+    res = pts[..., :ncol].reshape(*output_reshape, ncol)
+
+    return res
+
+
+def inv(mat):
+    """
+    Invert a torch or numpy matrix
+    """
+    if isinstance(mat, torch.Tensor):
+        return torch.linalg.inv(mat)
+    if isinstance(mat, np.ndarray):
+        return np.linalg.inv(mat)
+    raise ValueError(f"bad matrix type = {type(mat)}")
+
+
+def closed_form_pose_inverse(
+    pose_matrices, rotation_matrices=None, translation_vectors=None
+):
+    """
+    Compute the inverse of each 4x4 (or 3x4) SE3 pose matrix in a batch.
+
+    If `rotation_matrices` and `translation_vectors` are provided, they must correspond to the rotation and translation
+    components of `pose_matrices`. Otherwise, they will be extracted from `pose_matrices`.
+
+    Args:
+        pose_matrices: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+        rotation_matrices (optional): Nx3x3 array or tensor of rotation matrices.
+        translation_vectors (optional): Nx3x1 array or tensor of translation vectors.
+
+    Returns:
+        Inverted SE3 matrices with the same type and device as input `pose_matrices`.
+
+    Shapes:
+        pose_matrices: (N, 4, 4)
+        rotation_matrices: (N, 3, 3)
+        translation_vectors: (N, 3, 1)
+    """
+    # Check if pose_matrices is a numpy array or a torch tensor
+    is_numpy = isinstance(pose_matrices, np.ndarray)
+
+    # Validate shapes
+    if pose_matrices.shape[-2:] != (4, 4) and pose_matrices.shape[-2:] != (3, 4):
+        raise ValueError(
+            f"pose_matrices must be of shape (N,4,4) or (N,3,4), got {pose_matrices.shape}."
+        )
+
+    # Extract rotation_matrices and translation_vectors if not provided
+    if rotation_matrices is None:
+        rotation_matrices = pose_matrices[:, :3, :3]
+    if translation_vectors is None:
+        translation_vectors = pose_matrices[:, :3, 3:]
+
+    # Compute the inverse of input SE3 matrices
+    if is_numpy:
+        rotation_transposed = np.transpose(rotation_matrices, (0, 2, 1))
+        new_translation = -np.matmul(rotation_transposed, translation_vectors)
+        inverted_matrix = np.tile(np.eye(4), (len(rotation_matrices), 1, 1))
+    else:
+        rotation_transposed = rotation_matrices.transpose(1, 2)
+        new_translation = -torch.bmm(rotation_transposed, translation_vectors)
+        inverted_matrix = torch.eye(4, 4)[None].repeat(len(rotation_matrices), 1, 1)
+        inverted_matrix = inverted_matrix.to(rotation_matrices.dtype).to(
+            rotation_matrices.device
+        )
+    inverted_matrix[:, :3, :3] = rotation_transposed
+    inverted_matrix[:, :3, 3:] = new_translation
+
+    return inverted_matrix
+
+
+def relative_pose_transformation(trans_01, trans_02):
+    r"""
+    Function that computes the relative homogeneous transformation from a
+    reference transformation :math:`T_0^{1} = \begin{bmatrix} R_1 & t_1 \\
+    \mathbf{0} & 1 \end{bmatrix}` to destination :math:`T_0^{2} =
+    \begin{bmatrix} R_2 & t_2 \\ \mathbf{0} & 1 \end{bmatrix}`.
+
+    The relative transformation is computed as follows:
+
+    .. math::
+
+        T_1^{2} = (T_0^{1})^{-1} \cdot T_0^{2}
+
+    Arguments:
+        trans_01 (torch.Tensor): reference transformation tensor of shape
+            :math:`(N, 4, 4)` or :math:`(4, 4)`.
+        trans_02 (torch.Tensor): destination transformation tensor of shape
+            :math:`(N, 4, 4)` or :math:`(4, 4)`.
+
+    Shape:
+        - Output: :math:`(N, 4, 4)` or :math:`(4, 4)`.
+
+    Returns:
+        torch.Tensor: the relative transformation between the transformations.
+
+    Example::
+        >>> trans_01 = torch.eye(4)  # 4x4
+        >>> trans_02 = torch.eye(4)  # 4x4
+        >>> trans_12 = relative_pose_transformation(trans_01, trans_02)  # 4x4
+    """
+    if not torch.is_tensor(trans_01):
+        raise TypeError(
+            "Input trans_01 type is not a torch.Tensor. Got {}".format(type(trans_01))
+        )
+    if not torch.is_tensor(trans_02):
+        raise TypeError(
+            "Input trans_02 type is not a torch.Tensor. Got {}".format(type(trans_02))
+        )
+    if not (trans_01.dim() in (2, 3) and trans_01.shape[-2:] == (4, 4)):
+        raise ValueError(
+            "Input must be of the shape Nx4x4 or 4x4. Got {}".format(trans_01.shape)
+        )
+    if not (trans_02.dim() in (2, 3) and trans_02.shape[-2:] == (4, 4)):
+        raise ValueError(
+            "Input must be of the shape Nx4x4 or 4x4. Got {}".format(trans_02.shape)
+        )
+    if not trans_01.dim() == trans_02.dim():
+        raise ValueError(
+            "Input number of dims must match. Got {} and {}".format(
+                trans_01.dim(), trans_02.dim()
+            )
+        )
+
+    # Convert to Nx4x4 if inputs are 4x4
+    squeeze_batch_dim = False
+    if trans_01.dim() == 2:
+        trans_01 = trans_01.unsqueeze(0)
+        trans_02 = trans_02.unsqueeze(0)
+        squeeze_batch_dim = True
+
+    # Compute inverse of trans_01 using closed form
+    trans_10 = closed_form_pose_inverse(trans_01)
+
+    # Compose transformations using matrix multiplication
+    trans_12 = torch.matmul(trans_10, trans_02)
+
+    # Remove batch dimension if it was added
+    if squeeze_batch_dim:
+        trans_12 = trans_12.squeeze(0)
+
+    return trans_12
+
+
+def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
+    """
+    Args:
+        - depth (BxHxW array): depth map
+        - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
+    Returns:
+        pointmap of absolute coordinates (BxHxWx3 array)
+    """
+
+    if len(depth.shape) == 4:
+        B, H, W, n = depth.shape
+    else:
+        B, H, W = depth.shape
+        n = None
+
+    if len(pseudo_focal.shape) == 3:  # [B,H,W]
+        pseudo_focalx = pseudo_focaly = pseudo_focal
+    elif len(pseudo_focal.shape) == 4:  # [B,2,H,W] or [B,1,H,W]
+        pseudo_focalx = pseudo_focal[:, 0]
+        if pseudo_focal.shape[1] == 2:
+            pseudo_focaly = pseudo_focal[:, 1]
+        else:
+            pseudo_focaly = pseudo_focalx
+    else:
+        raise NotImplementedError("Error, unknown input focal shape format.")
+
+    assert pseudo_focalx.shape == depth.shape[:3]
+    assert pseudo_focaly.shape == depth.shape[:3]
+    grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
+
+    # set principal point
+    if pp is None:
+        grid_x = grid_x - (W - 1) / 2
+        grid_y = grid_y - (H - 1) / 2
+    else:
+        grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
+        grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
+
+    if n is None:
+        pts3d = torch.empty((B, H, W, 3), device=depth.device)
+        pts3d[..., 0] = depth * grid_x / pseudo_focalx
+        pts3d[..., 1] = depth * grid_y / pseudo_focaly
+        pts3d[..., 2] = depth
+    else:
+        pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
+        pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
+        pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
+        pts3d[..., 2, :] = depth
+    return pts3d
+
+
+def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
+    """
+    Args:
+        - depthmap (HxW array):
+        - camera_intrinsics: a 3x3 matrix
+    Returns:
+        pointmap of camera-frame coordinates (HxWx3 array), and a mask specifying valid pixels.
+    """
+    camera_intrinsics = np.float32(camera_intrinsics)
+    H, W = depthmap.shape
+
+    # Compute 3D ray associated with each pixel
+    # Strong assumption: there are no skew terms
+    assert camera_intrinsics[0, 1] == 0.0
+    assert camera_intrinsics[1, 0] == 0.0
+    if pseudo_focal is None:
+        fu = camera_intrinsics[0, 0]
+        fv = camera_intrinsics[1, 1]
+    else:
+        assert pseudo_focal.shape == (H, W)
+        fu = fv = pseudo_focal
+    cu = camera_intrinsics[0, 2]
+    cv = camera_intrinsics[1, 2]
+
+    u, v = np.meshgrid(np.arange(W), np.arange(H))
+    z_cam = depthmap
+    x_cam = (u - cu) * z_cam / fu
+    y_cam = (v - cv) * z_cam / fv
+    X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+    # Mask for valid coordinates
+    valid_mask = depthmap > 0.0
+
+    return X_cam, valid_mask
+
+
+def depthmap_to_absolute_camera_coordinates(
+    depthmap, camera_intrinsics, camera_pose, **kw
+):
+    """
+    Args:
+        - depthmap (HxW array):
+        - camera_intrinsics: a 3x3 matrix
+        - camera_pose: a 3x4 or 4x4 cam2world matrix
+    Returns:
+        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
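+
+    Example (illustrative only, with made-up intrinsics): an identity cam2world
+    pose leaves the camera-frame points unchanged.
+
+        >>> depth = np.ones((4, 4), dtype=np.float32)
+        >>> K = np.array([[2.0, 0.0, 1.5], [0.0, 2.0, 1.5], [0.0, 0.0, 1.0]])
+        >>> X_world, valid = depthmap_to_absolute_camera_coordinates(depth, K, np.eye(4))
+        >>> X_world.shape
+        (4, 4, 3)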
+ """ + X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics) + + X_world = X_cam # default + if camera_pose is not None: + # R_cam2world = np.float32(camera_params["R_cam2world"]) + # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze() + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + # Express in absolute coordinates (invalid depth values) + X_world = ( + np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :] + ) + + return X_world, valid_mask + + +def get_absolute_pointmaps_and_rays_info( + depthmap, camera_intrinsics, camera_pose, **kw +): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), + a mask specifying valid pixels, + ray origins of absolute coordinates (HxWx3 array), + ray directions of absolute coordinates (HxWx3 array), + depth along ray (HxWx1 array), + ray directions of camera/local coordinates (HxWx3 array), + pointmap of camera/local coordinates (HxWx3 array). + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: pinhole & there are no skew terms + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + # Get the rays on the unit plane + u, v = np.meshgrid(np.arange(W), np.arange(H)) + x_cam = (u - cu) / fu + y_cam = (v - cv) / fv + z_cam = np.ones_like(x_cam) + ray_dirs_cam_on_unit_plane = np.stack((x_cam, y_cam, z_cam), axis=-1).astype( + np.float32 + ) + + # Compute the 3d points in the local camera coordinate system + pts_cam = depthmap[..., None] * ray_dirs_cam_on_unit_plane + + # Get the depth along the ray and compute the ray directions on the unit sphere + depth_along_ray = np.linalg.norm(pts_cam, axis=-1, keepdims=True) + ray_directions_cam = ray_dirs_cam_on_unit_plane / np.linalg.norm( + ray_dirs_cam_on_unit_plane, axis=-1, keepdims=True + ) + + # Mask for valid coordinates + valid_mask = depthmap > 0.0 + + # Get the ray origins in absolute coordinates and the ray directions in absolute coordinates + ray_origins_world = np.zeros_like(ray_directions_cam) + ray_directions_world = ray_directions_cam + pts_world = pts_cam + if camera_pose is not None: + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + # Express in absolute coordinates + ray_origins_world = ray_origins_world + t_cam2world[None, None, :] + ray_directions_world = np.einsum( + "ik, vuk -> vui", R_cam2world, ray_directions_cam + ) + pts_world = ray_origins_world + ray_directions_world * depth_along_ray + + return ( + pts_world, + valid_mask, + ray_origins_world, + ray_directions_world, + depth_along_ray, + ray_directions_cam, + pts_cam, + ) + + +def adjust_camera_params_for_rotation(camera_params, original_size, k): + """ + Adjust camera parameters for rotation. + + Args: + camera_params: Camera parameters [fx, fy, cx, cy, ...] 
+        original_size: Original image size as (width, height)
+        k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise)
+
+    Returns:
+        Adjusted camera parameters
+    """
+    fx, fy, cx, cy = camera_params[:4]
+    width, height = original_size
+
+    if k % 4 == 1:  # 90 degrees counter-clockwise
+        new_fx, new_fy = fy, fx
+        new_cx, new_cy = height - cy, cx
+    elif k % 4 == 2:  # 180 degrees
+        new_fx, new_fy = fx, fy
+        new_cx, new_cy = width - cx, height - cy
+    elif k % 4 == 3:  # 90 degrees clockwise (270 counter-clockwise)
+        new_fx, new_fy = fy, fx
+        new_cx, new_cy = cy, width - cx
+    else:  # No rotation
+        return camera_params
+
+    adjusted_params = [new_fx, new_fy, new_cx, new_cy]
+    if len(camera_params) > 4:
+        adjusted_params.extend(camera_params[4:])
+
+    return adjusted_params
+
+
+def adjust_pose_for_rotation(pose, k):
+    """
+    Adjust camera pose for rotation.
+
+    Args:
+        pose: 4x4 camera pose matrix (camera-to-world, OpenCV convention - X right, Y down, Z forward)
+        k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise)
+
+    Returns:
+        Adjusted 4x4 camera pose matrix
+    """
+    # Create rotation matrices for different rotations
+    if k % 4 == 1:  # 90 degrees counter-clockwise
+        rot_transform = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
+    elif k % 4 == 2:  # 180 degrees
+        rot_transform = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]])
+    elif k % 4 == 3:  # 90 degrees clockwise (270 counter-clockwise)
+        rot_transform = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
+    else:  # No rotation
+        return pose
+
+    # Apply the transformation to a copy of the pose, so the input is not mutated in place
+    adjusted_pose = pose.copy()
+    adjusted_pose[:3, :3] = adjusted_pose[:3, :3] @ rot_transform.T
+
+    return adjusted_pose
+
+
+def crop_to_aspect_ratio(image, depth, camera_params, target_ratio=1.5):
+    """
+    Crop image and depth to the largest possible crop with the target aspect
+    ratio, keeping the left side if the image is wider than the target and the
+    bottom if it is taller.
+
+    Args:
+        image: PIL image
+        depth: Depth map as numpy array
+        camera_params: Camera parameters [fx, fy, cx, cy, ...]
+        target_ratio: Target width/height ratio
+
+    Returns:
+        Cropped image, cropped depth, adjusted camera parameters
+    """
+    width, height = image.size
+    fx, fy, cx, cy = camera_params[:4]
+    current_ratio = width / height
+
+    if abs(current_ratio - target_ratio) < 1e-6:
+        # Already at target ratio
+        return image, depth, camera_params
+
+    if current_ratio > target_ratio:
+        # Image is wider than target ratio, crop width
+        new_width = int(height * target_ratio)
+        left = 0
+        right = new_width
+
+        # Crop image
+        cropped_image = image.crop((left, 0, right, height))
+
+        # Crop depth
+        if len(depth.shape) == 3:
+            cropped_depth = depth[:, left:right, :]
+        else:
+            cropped_depth = depth[:, left:right]
+
+        # Adjust camera parameters
+        new_cx = cx - left
+        adjusted_params = [fx, fy, new_cx, cy] + list(camera_params[4:])
+
+    else:
+        # Image is taller than target ratio, crop height
+        new_height = int(width / target_ratio)
+        top = max(0, height - new_height)
+        bottom = height
+
+        # Crop image
+        cropped_image = image.crop((0, top, width, bottom))
+
+        # Crop depth
+        if len(depth.shape) == 3:
+            cropped_depth = depth[top:bottom, :, :]
+        else:
+            cropped_depth = depth[top:bottom, :]
+
+        # Adjust camera parameters
+        new_cy = cy - top
+        adjusted_params = [fx, fy, cx, new_cy] + list(camera_params[4:])
+
+    return cropped_image, cropped_depth, adjusted_params
+
+
+def colmap_to_opencv_intrinsics(K):
+    """
+    Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] -= 0.5 + K[1, 2] -= 0.5 + + return K + + +def opencv_to_colmap_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + + return K + + +def normalize_depth_using_non_zero_pixels(depth, return_norm_factor=False): + """ + Normalize the depth by the average depth of non-zero depth pixels. + + Args: + depth (torch.Tensor): Depth tensor of size [B, H, W, 1]. + Returns: + normalized_depth (torch.Tensor): Normalized depth tensor. + norm_factor (torch.Tensor): Norm factor tensor of size B. + """ + assert depth.ndim == 4 and depth.shape[3] == 1 + # Calculate the sum and count of non-zero depth pixels for each batch + valid_depth_mask = depth > 0 + valid_sum = torch.sum(depth * valid_depth_mask, dim=(1, 2, 3)) + valid_count = torch.sum(valid_depth_mask, dim=(1, 2, 3)) + + # Calculate the norm factor + norm_factor = valid_sum / (valid_count + 1e-8) + while norm_factor.ndim < depth.ndim: + norm_factor.unsqueeze_(-1) + + # Normalize the depth by the norm factor + norm_factor = norm_factor.clip(min=1e-8) + normalized_depth = depth / norm_factor + + # Create the output tuple + output = ( + (normalized_depth, norm_factor.squeeze(-1).squeeze(-1).squeeze(-1)) + if return_norm_factor + else normalized_depth + ) + + return output + + +def normalize_pose_translations(pose_translations, return_norm_factor=False): + """ + Normalize the pose translations by the average norm of the non-zero pose translations. + + Args: + pose_translations (torch.Tensor): Pose translations tensor of size [B, V, 3]. B is the batch size, V is the number of views. + Returns: + normalized_pose_translations (torch.Tensor): Normalized pose translations tensor of size [B, V, 3]. + norm_factor (torch.Tensor): Norm factor tensor of size B. + """ + assert pose_translations.ndim == 3 and pose_translations.shape[2] == 3 + # Compute distance of all pose translations to origin + pose_translations_dis = pose_translations.norm(dim=-1) # [B, V] + non_zero_pose_translations_dis = pose_translations_dis > 0 # [B, V] + + # Calculate the average norm of the translations across all views (considering only views with non-zero translations) + sum_of_all_views_pose_translations = pose_translations_dis.sum(dim=1) # [B] + count_of_all_views_with_non_zero_pose_translations = ( + non_zero_pose_translations_dis.sum(dim=1) + ) # [B] + norm_factor = sum_of_all_views_pose_translations / ( + count_of_all_views_with_non_zero_pose_translations + 1e-8 + ) # [B] + + # Normalize the pose translations by the norm factor + norm_factor = norm_factor.clip(min=1e-8) + normalized_pose_translations = pose_translations / norm_factor.unsqueeze( + -1 + ).unsqueeze(-1) + + # Create the output tuple + output = ( + (normalized_pose_translations, norm_factor) + if return_norm_factor + else normalized_pose_translations + ) + + return output + + +def normalize_multiple_pointclouds( + pts_list, valid_masks=None, norm_mode="avg_dis", ret_factor=False +): + """ + Normalize multiple point clouds using a joint normalization strategy. 
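+
+    With the default "avg_dis" mode, a single scale factor is shared by all
+    point clouds: each one is divided by the mean distance-to-origin of the
+    valid points pooled across every cloud, i.e. pts_i <- pts_i / s with
+    s = mean(||p||) over all valid p.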
+ + Args: + pts_list: List of point clouds, each with shape (..., H, W, 3) or (B, H, W, 3) + valid_masks: Optional list of masks indicating valid points in each point cloud + norm_mode: String in format "{norm}_{dis}" where: + - norm: Normalization strategy (currently only "avg" is supported) + - dis: Distance transformation ("dis" for raw distance, "log1p" for log(1+distance), + "warp-log1p" to warp points using log distance) + ret_factor: If True, return the normalization factor as the last element in the result list + + Returns: + List of normalized point clouds with the same shapes as inputs. + If ret_factor is True, the last element is the normalization factor. + """ + assert all(pts.ndim >= 3 and pts.shape[-1] == 3 for pts in pts_list) + if valid_masks is not None: + assert len(pts_list) == len(valid_masks) + + norm_mode, dis_mode = norm_mode.split("_") + + # Gather all points together (joint normalization) + nan_pts_list = [ + invalid_to_zeros(pts, valid_masks[i], ndim=3) + if valid_masks + else invalid_to_zeros(pts, None, ndim=3) + for i, pts in enumerate(pts_list) + ] + all_pts = torch.cat([nan_pts for nan_pts, _ in nan_pts_list], dim=1) + nnz_list = [nnz for _, nnz in nan_pts_list] + + # Compute distance to origin + all_dis = all_pts.norm(dim=-1) + if dis_mode == "dis": + pass # do nothing + elif dis_mode == "log1p": + all_dis = torch.log1p(all_dis) + elif dis_mode == "warp-log1p": + # Warp input points before normalizing them + log_dis = torch.log1p(all_dis) + warp_factor = log_dis / all_dis.clip(min=1e-8) + for i, pts in enumerate(pts_list): + H, W = pts.shape[1:-1] + pts_list[i] = pts * warp_factor[:, i * (H * W) : (i + 1) * (H * W)].view( + -1, H, W, 1 + ) + all_dis = log_dis + else: + raise ValueError(f"bad {dis_mode=}") + + # Compute normalization factor + norm_factor = all_dis.sum(dim=1) / (sum(nnz_list) + 1e-8) + norm_factor = norm_factor.clip(min=1e-8) + while norm_factor.ndim < pts_list[0].ndim: + norm_factor.unsqueeze_(-1) + + # Normalize points + res = [pts / norm_factor for pts in pts_list] + if ret_factor: + res.append(norm_factor) + + return res + + +def apply_log_to_norm(input_data): + """ + Normalize the input data and apply a logarithmic transformation based on the normalization factor. + + Args: + input_data (torch.Tensor): The input tensor to be normalized and transformed. + + Returns: + torch.Tensor: The transformed tensor after normalization and logarithmic scaling. + """ + org_d = input_data.norm(dim=-1, keepdim=True) + input_data = input_data / org_d.clip(min=1e-8) + input_data = input_data * torch.log1p(org_d) + return input_data + + +def angle_diff_vec3(v1, v2, eps=1e-12): + """ + Compute angle difference between 3D vectors. + + Args: + v1: torch.Tensor of shape (..., 3) + v2: torch.Tensor of shape (..., 3) + eps: Small epsilon value for numerical stability + + Returns: + torch.Tensor: Angle differences in radians + """ + cross_norm = torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps + dot_prod = (v1 * v2).sum(dim=-1) + return torch.atan2(cross_norm, dot_prod) + + +def angle_diff_vec3_numpy(v1: np.ndarray, v2: np.ndarray, eps: float = 1e-12): + """ + Compute angle difference between 3D vectors using NumPy. + + Args: + v1 (np.ndarray): First vector of shape (..., 3) + v2 (np.ndarray): Second vector of shape (..., 3) + eps (float, optional): Small epsilon value for numerical stability. Defaults to 1e-12. 
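+
+    Note:
+        Like the torch variant above, this uses atan2(||v1 x v2||, v1 . v2),
+        which is better conditioned for nearly parallel vectors than taking
+        arccos of the normalized dot product.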
+
+    Returns:
+        np.ndarray: Angle differences in radians
+    """
+    return np.arctan2(
+        np.linalg.norm(np.cross(v1, v2, axis=-1), axis=-1) + eps, (v1 * v2).sum(axis=-1)
+    )
+
+
+@no_warnings(category=RuntimeWarning)
+def points_to_normals(
+    point: np.ndarray, mask: np.ndarray = None, edge_threshold: float = None
+) -> np.ndarray:
+    """
+    Calculate normal map from point map. Value range is [-1, 1].
+
+    Args:
+        point (np.ndarray): shape (height, width, 3), point map
+        mask (optional, np.ndarray): shape (height, width), dtype=bool. Mask of valid depth pixels. Defaults to None.
+        edge_threshold (optional, float): threshold for the angle (in degrees) between the normal and the view direction. Defaults to None.
+
+    Returns:
+        normal (np.ndarray): shape (height, width, 3), normal map.
+        If a mask is provided, a tuple (normal, normal_mask) is returned instead,
+        where normal_mask marks pixels with at least one valid normal estimate.
+    """
+    height, width = point.shape[-3:-1]
+    has_mask = mask is not None
+
+    if mask is None:
+        mask = np.ones_like(point[..., 0], dtype=bool)
+    mask_pad = np.zeros((height + 2, width + 2), dtype=bool)
+    mask_pad[1:-1, 1:-1] = mask
+    mask = mask_pad
+
+    pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype)
+    pts[1:-1, 1:-1, :] = point
+    up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :]
+    left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :]
+    down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :]
+    right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :]
+    normal = np.stack(
+        [
+            np.cross(up, left, axis=-1),
+            np.cross(left, down, axis=-1),
+            np.cross(down, right, axis=-1),
+            np.cross(right, up, axis=-1),
+        ]
+    )
+    normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
+
+    valid = (
+        np.stack(
+            [
+                mask[:-2, 1:-1] & mask[1:-1, :-2],
+                mask[1:-1, :-2] & mask[2:, 1:-1],
+                mask[2:, 1:-1] & mask[1:-1, 2:],
+                mask[1:-1, 2:] & mask[:-2, 1:-1],
+            ]
+        )
+        & mask[None, 1:-1, 1:-1]
+    )
+    if edge_threshold is not None:
+        view_angle = angle_diff_vec3_numpy(pts[None, 1:-1, 1:-1, :], normal)
+        view_angle = np.minimum(view_angle, np.pi - view_angle)
+        valid = valid & (view_angle < np.deg2rad(edge_threshold))
+
+    normal = (normal * valid[..., None]).sum(axis=0)
+    normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
+
+    if has_mask:
+        normal_mask = valid.any(axis=0)
+        normal = np.where(normal_mask[..., None], normal, 0)
+        return normal, normal_mask
+    else:
+        return normal
+
+
+def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1):
+    """
+    Create a sliding window view of the input array along a specified axis.
+
+    This function creates a memory-efficient view of the input array with sliding windows
+    of the specified size and stride. The window dimension is appended to the end of the
+    output array's shape. This is useful for operations like convolution, pooling, or
+    any analysis that requires examining local neighborhoods in the data.
+
+    Args:
+        x (np.ndarray): Input array with shape (..., axis_size, ...)
+        window_size (int): Size of the sliding window
+        stride (int): Stride of the sliding window (step size between consecutive windows)
+        axis (int, optional): Axis to perform sliding window over.
Defaults to -1 (last axis) + + Returns: + np.ndarray: View of the input array with shape (..., n_windows, ..., window_size), + where n_windows = (axis_size - window_size + 1) // stride + + Raises: + AssertionError: If window_size is larger than the size of the specified axis + + Example: + >>> x = np.array([1, 2, 3, 4, 5, 6]) + >>> sliding_window_1d(x, window_size=3, stride=2) + array([[1, 2, 3], + [3, 4, 5]]) + """ + assert x.shape[axis] >= window_size, ( + f"kernel_size ({window_size}) is larger than axis_size ({x.shape[axis]})" + ) + axis = axis % x.ndim + shape = ( + *x.shape[:axis], + (x.shape[axis] - window_size + 1) // stride, + *x.shape[axis + 1 :], + window_size, + ) + strides = ( + *x.strides[:axis], + stride * x.strides[axis], + *x.strides[axis + 1 :], + x.strides[axis], + ) + x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides) + return x_sliding + + +def sliding_window_nd( + x: np.ndarray, + window_size: Tuple[int, ...], + stride: Tuple[int, ...], + axis: Tuple[int, ...], +) -> np.ndarray: + """ + Create sliding windows along multiple dimensions of the input array. + + This function applies sliding_window_1d sequentially along multiple axes to create + N-dimensional sliding windows. This is useful for operations that need to examine + local neighborhoods in multiple dimensions simultaneously. + + Args: + x (np.ndarray): Input array + window_size (Tuple[int, ...]): Size of the sliding window for each axis + stride (Tuple[int, ...]): Stride of the sliding window for each axis + axis (Tuple[int, ...]): Axes to perform sliding window over + + Returns: + np.ndarray: Array with sliding windows along the specified dimensions. + The window dimensions are appended to the end of the shape. + + Note: + The length of window_size, stride, and axis tuples must be equal. + + Example: + >>> x = np.random.rand(10, 10) + >>> windows = sliding_window_nd(x, window_size=(3, 3), stride=(2, 2), axis=(-2, -1)) + >>> # Creates 3x3 sliding windows with stride 2 in both dimensions + """ + axis = [axis[i] % x.ndim for i in range(len(axis))] + for i in range(len(axis)): + x = sliding_window_1d(x, window_size[i], stride[i], axis[i]) + return x + + +def sliding_window_2d( + x: np.ndarray, + window_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + axis: Tuple[int, int] = (-2, -1), +) -> np.ndarray: + """ + Create 2D sliding windows over the input array. + + Convenience function for creating 2D sliding windows, commonly used for image + processing operations like convolution, pooling, or patch extraction. + + Args: + x (np.ndarray): Input array + window_size (Union[int, Tuple[int, int]]): Size of the 2D sliding window. + If int, same size is used for both dimensions. + stride (Union[int, Tuple[int, int]]): Stride of the 2D sliding window. + If int, same stride is used for both dimensions. + axis (Tuple[int, int], optional): Two axes to perform sliding window over. + Defaults to (-2, -1) (last two dimensions). + + Returns: + np.ndarray: Array with 2D sliding windows. The window dimensions (height, width) + are appended to the end of the shape. 
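+
+    Note:
+        As with sliding_window_1d, the result is a strided view into the input
+        (no copy), so writing to it will modify the original array.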
+
+    Example:
+        >>> image = np.random.rand(100, 100)
+        >>> patches = sliding_window_2d(image, window_size=8, stride=4)
+        >>> # Creates 8x8 patches with stride 4 from the image
+    """
+    if isinstance(window_size, int):
+        window_size = (window_size, window_size)
+    if isinstance(stride, int):
+        stride = (stride, stride)
+    return sliding_window_nd(x, window_size, stride, axis)
+
+
+def max_pool_1d(
+    x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1
+):
+    """
+    Perform 1D max pooling on the input array.
+
+    Max pooling reduces the dimensionality of the input by taking the maximum value
+    within each sliding window. This is commonly used in neural networks and signal
+    processing for downsampling and feature extraction.
+
+    Args:
+        x (np.ndarray): Input array
+        kernel_size (int): Size of the pooling kernel
+        stride (int): Stride of the pooling operation
+        padding (int, optional): Amount of padding to add on both sides. Defaults to 0.
+        axis (int, optional): Axis to perform max pooling over. Defaults to -1.
+
+    Returns:
+        np.ndarray: Max pooled array with reduced size along the specified axis
+
+    Note:
+        - For floating point arrays, padding is done with np.nan values
+        - For integer arrays, padding is done with the minimum value of the dtype
+        - np.nanmax is used to handle NaN values in the computation
+
+    Example:
+        >>> x = np.array([1, 3, 2, 4, 5, 1, 2])
+        >>> max_pool_1d(x, kernel_size=3, stride=2)
+        array([3, 5])
+    """
+    axis = axis % x.ndim
+    if padding > 0:
+        fill_value = np.nan if x.dtype.kind == "f" else np.iinfo(x.dtype).min
+        padding_arr = np.full(
+            (*x.shape[:axis], padding, *x.shape[axis + 1 :]),
+            fill_value=fill_value,
+            dtype=x.dtype,
+        )
+        x = np.concatenate([padding_arr, x, padding_arr], axis=axis)
+    a_sliding = sliding_window_1d(x, kernel_size, stride, axis)
+    max_pool = np.nanmax(a_sliding, axis=-1)
+    return max_pool
+
+
+def max_pool_nd(
+    x: np.ndarray,
+    kernel_size: Tuple[int, ...],
+    stride: Tuple[int, ...],
+    padding: Tuple[int, ...],
+    axis: Tuple[int, ...],
+) -> np.ndarray:
+    """
+    Perform N-dimensional max pooling on the input array.
+
+    This function applies max_pool_1d sequentially along multiple axes to perform
+    multi-dimensional max pooling. This is useful for downsampling multi-dimensional
+    data while preserving the most important features.
+
+    Args:
+        x (np.ndarray): Input array
+        kernel_size (Tuple[int, ...]): Size of the pooling kernel for each axis
+        stride (Tuple[int, ...]): Stride of the pooling operation for each axis
+        padding (Tuple[int, ...]): Amount of padding for each axis
+        axis (Tuple[int, ...]): Axes to perform max pooling over
+
+    Returns:
+        np.ndarray: Max pooled array with reduced size along the specified axes
+
+    Note:
+        The length of kernel_size, stride, padding, and axis tuples must be equal.
+        Max pooling is applied sequentially along each axis in the order specified.
+
+    Example:
+        >>> x = np.random.rand(10, 10, 10)
+        >>> pooled = max_pool_nd(x, kernel_size=(2, 2, 2), stride=(2, 2, 2),
+        ...                      padding=(0, 0, 0), axis=(-3, -2, -1))
+        >>> # Reduces each dimension by half with 2x2x2 max pooling
+    """
+    for i in range(len(axis)):
+        x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i])
+    return x
+
+
+def max_pool_2d(
+    x: np.ndarray,
+    kernel_size: Union[int, Tuple[int, int]],
+    stride: Union[int, Tuple[int, int]],
+    padding: Union[int, Tuple[int, int]],
+    axis: Tuple[int, int] = (-2, -1),
+):
+    """
+    Perform 2D max pooling on the input array.
+
+    Convenience function for 2D max pooling, commonly used in computer vision
+    and image processing for downsampling images while preserving important features.
+
+    Args:
+        x (np.ndarray): Input array
+        kernel_size (Union[int, Tuple[int, int]]): Size of the 2D pooling kernel.
+            If int, same size is used for both dimensions.
+        stride (Union[int, Tuple[int, int]]): Stride of the 2D pooling operation.
+            If int, same stride is used for both dimensions.
+        padding (Union[int, Tuple[int, int]]): Amount of padding for both dimensions.
+            If int, same padding is used for both dimensions.
+        axis (Tuple[int, int], optional): Two axes to perform max pooling over.
+            Defaults to (-2, -1) (last two dimensions).
+
+    Returns:
+        np.ndarray: 2D max pooled array with reduced size along the specified axes
+
+    Example:
+        >>> image = np.random.rand(64, 64)
+        >>> pooled = max_pool_2d(image, kernel_size=2, stride=2, padding=0)
+        >>> # Reduces image size from 64x64 to 32x32 with 2x2 max pooling
+    """
+    if isinstance(kernel_size, Number):
+        kernel_size = (kernel_size, kernel_size)
+    if isinstance(stride, Number):
+        stride = (stride, stride)
+    if isinstance(padding, Number):
+        padding = (padding, padding)
+    axis = tuple(axis)
+    return max_pool_nd(x, kernel_size, stride, padding, axis)
+
+
+@no_warnings(category=RuntimeWarning)
+def depth_edge(
+    depth: np.ndarray,
+    atol: float = None,
+    rtol: float = None,
+    kernel_size: int = 3,
+    mask: np.ndarray = None,
+) -> np.ndarray:
+    """
+    Compute the edge mask from a depth map. An edge pixel is one whose neighborhood has a large difference in depth.
+
+    Args:
+        depth (np.ndarray): shape (..., height, width), linear depth map
+        atol (float): absolute tolerance
+        rtol (float): relative tolerance
+        kernel_size (int): size of the neighborhood window
+        mask (np.ndarray, optional): shape (..., height, width), mask of valid depth pixels
+
+    Returns:
+        edge (np.ndarray): shape (..., height, width) of dtype bool
+    """
+    if mask is None:
+        diff = max_pool_2d(
+            depth, kernel_size, stride=1, padding=kernel_size // 2
+        ) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)
+    else:
+        diff = max_pool_2d(
+            np.where(mask, depth, -np.inf),
+            kernel_size,
+            stride=1,
+            padding=kernel_size // 2,
+        ) + max_pool_2d(
+            np.where(mask, -depth, -np.inf),
+            kernel_size,
+            stride=1,
+            padding=kernel_size // 2,
+        )
+
+    edge = np.zeros_like(depth, dtype=bool)
+    if atol is not None:
+        edge |= diff > atol
+
+    if rtol is not None:
+        edge |= diff / depth > rtol
+    return edge
+
+
+def depth_aliasing(
+    depth: np.ndarray,
+    atol: float = None,
+    rtol: float = None,
+    kernel_size: int = 3,
+    mask: np.ndarray = None,
+) -> np.ndarray:
+    """
+    Compute a map that indicates aliasing in a depth map. Aliased pixels are those that are close to neither the maximum nor the minimum of their neighbors.
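+
+    Concretely, a pixel is flagged when min(max_neighbor - d, d - min_neighbor)
+    exceeds the given absolute and/or relative tolerance.
+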
+    Args:
+        depth (np.ndarray): shape (..., height, width), linear depth map
+        atol (float): absolute tolerance
+        rtol (float): relative tolerance
+        kernel_size (int): size of the neighborhood window
+        mask (np.ndarray, optional): shape (..., height, width), mask of valid depth pixels
+
+    Returns:
+        edge (np.ndarray): shape (..., height, width) of dtype bool
+    """
+    if mask is None:
+        diff_max = (
+            max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth
+        )
+        diff_min = (
+            max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth
+        )
+    else:
+        diff_max = (
+            max_pool_2d(
+                np.where(mask, depth, -np.inf),
+                kernel_size,
+                stride=1,
+                padding=kernel_size // 2,
+            )
+            - depth
+        )
+        diff_min = (
+            max_pool_2d(
+                np.where(mask, -depth, -np.inf),
+                kernel_size,
+                stride=1,
+                padding=kernel_size // 2,
+            )
+            + depth
+        )
+    diff = np.minimum(diff_max, diff_min)
+
+    edge = np.zeros_like(depth, dtype=bool)
+    if atol is not None:
+        edge |= diff > atol
+    if rtol is not None:
+        edge |= diff / depth > rtol
+    return edge
+
+
+@no_warnings(category=RuntimeWarning)
+def normals_edge(
+    normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None
+) -> np.ndarray:
+    """
+    Compute the edge mask from a normal map.
+
+    Args:
+        normals (np.ndarray): shape (..., height, width, 3), normal map
+        tol (float): tolerance in degrees
+        kernel_size (int): size of the neighborhood window
+        mask (np.ndarray, optional): shape (..., height, width), mask of valid pixels
+
+    Returns:
+        edge (np.ndarray): shape (..., height, width) of dtype bool
+    """
+    assert normals.ndim >= 3 and normals.shape[-1] == 3, (
+        "normals should be of shape (..., height, width, 3)"
+    )
+    normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12)
+
+    padding = kernel_size // 2
+    normals_window = sliding_window_2d(
+        np.pad(
+            normals,
+            (
+                *([(0, 0)] * (normals.ndim - 3)),
+                (padding, padding),
+                (padding, padding),
+                (0, 0),
+            ),
+            mode="edge",
+        ),
+        window_size=kernel_size,
+        stride=1,
+        axis=(-3, -2),
+    )
+    if mask is None:
+        angle_diff = np.arccos(
+            (normals[..., None, None] * normals_window).sum(axis=-3)
+        ).max(axis=(-2, -1))
+    else:
+        mask_window = sliding_window_2d(
+            np.pad(
+                mask,
+                (*([(0, 0)] * (mask.ndim - 3)), (padding, padding), (padding, padding)),
+                mode="edge",
+            ),
+            window_size=kernel_size,
+            stride=1,
+            axis=(-3, -2),
+        )
+        angle_diff = np.where(
+            mask_window,
+            np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)),
+            0,
+        ).max(axis=(-2, -1))
+
+    angle_diff = max_pool_2d(
+        angle_diff, kernel_size, stride=1, padding=kernel_size // 2
+    )
+    edge = angle_diff > np.deg2rad(tol)
+    return edge
diff --git a/mapanything/utils/hf_utils/__init__.py b/mapanything/utils/hf_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/utils/hf_utils/__pycache__/__init__.cpython-312.pyc b/mapanything/utils/hf_utils/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e839fbfdc2137fc647b044105a4e56a4965cec20
Binary files /dev/null and b/mapanything/utils/hf_utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/mapanything/utils/hf_utils/__pycache__/css_and_html.cpython-312.pyc b/mapanything/utils/hf_utils/__pycache__/css_and_html.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74dee2b9b9ccad0f64d79dd3fed215203f2efd90
Binary files /dev/null and b/mapanything/utils/hf_utils/__pycache__/css_and_html.cpython-312.pyc differ
diff --git a/mapanything/utils/hf_utils/__pycache__/hf_helpers.cpython-312.pyc b/mapanything/utils/hf_utils/__pycache__/hf_helpers.cpython-312.pyc
new file mode 100644
index
0000000000000000000000000000000000000000..015b2d2f6c20f851c92b868b5f399bfdfd13d7ee Binary files /dev/null and b/mapanything/utils/hf_utils/__pycache__/hf_helpers.cpython-312.pyc differ diff --git a/mapanything/utils/hf_utils/__pycache__/viz.cpython-312.pyc b/mapanything/utils/hf_utils/__pycache__/viz.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4475b84add90b82d5fd678fef9d968906d2f4caf Binary files /dev/null and b/mapanything/utils/hf_utils/__pycache__/viz.cpython-312.pyc differ diff --git a/mapanything/utils/hf_utils/css_and_html.py b/mapanything/utils/hf_utils/css_and_html.py new file mode 100644 index 0000000000000000000000000000000000000000..a303fb1a6061800dce0677bddc2eca4e8516dd35 --- /dev/null +++ b/mapanything/utils/hf_utils/css_and_html.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +CSS and HTML content for the MapAnything Gradio application. +This module contains all the CSS styles and HTML content blocks +used in the Gradio interface. +""" + +# CSS Styles for the Gradio interface +GRADIO_CSS = """ +.custom-log * { + font-style: italic; + font-size: 22px !important; + background-image: linear-gradient(120deg, #ffb366 0%, #ffa366 60%, #ff9966 100%); + -webkit-background-clip: text; + background-clip: text; + font-weight: bold !important; + color: transparent !important; + text-align: center !important; +} + +.example-log * { + font-style: italic; + font-size: 16px !important; + background-image: linear-gradient(120deg, #ffb366 0%, #ffa366 60%, #ff9966 100%); + -webkit-background-clip: text; + background-clip: text; + color: transparent !important; +} + +#my_radio .wrap { + display: flex; + flex-wrap: nowrap; + justify-content: center; + align-items: center; +} + +#my_radio .wrap label { + display: flex; + width: 50%; + justify-content: center; + align-items: center; + margin: 0; + padding: 10px 0; + box-sizing: border-box; +} + +/* Align navigation buttons with dropdown bottom */ +.navigation-row { + display: flex !important; + align-items: flex-end !important; + gap: 8px !important; +} + +.navigation-row > div:nth-child(1), +.navigation-row > div:nth-child(3) { + align-self: flex-end !important; +} + +.navigation-row > div:nth-child(2) { + flex: 1 !important; +} + +/* Make thumbnails clickable with pointer cursor */ +.clickable-thumbnail img { + cursor: pointer !important; +} + +.clickable-thumbnail:hover img { + cursor: pointer !important; + opacity: 0.8; + transition: opacity 0.3s ease; +} + +/* Make thumbnail containers narrower horizontally */ +.clickable-thumbnail { + padding: 5px 2px !important; + margin: 0 2px !important; +} + +.clickable-thumbnail .image-container { + margin: 0 !important; + padding: 0 !important; +} + +.scene-info { + text-align: center !important; + padding: 5px 2px !important; + margin: 0 !important; +} +""" + + +def get_header_html(logo_base64=None): + """ + Generate the main header HTML with logo and title. + + Args: + logo_base64 (str, optional): Base64 encoded logo image + + Returns: + str: HTML string for the header + """ + logo_style = "display: none;" if not logo_base64 else "" + logo_src = logo_base64 or "" + + return f""" +
+ 🌟 GitHub Repository | + 🚀 Project Page +
+ """ + + +def get_description_html(): + """ + Generate the main description and getting started HTML. + + Returns: + str: HTML string for the description + """ + return """ +Upload a video or a set of images to create a 3D reconstruction of a scene or object. MapAnything takes these images and generates 3D point clouds directly from multi-view images.
+This demo demonstrates the use of image inputs only. However, MapAnything is extremely flexible and supports any combination of inputs (images, calibration, poses & depth). For trying out memory efficient inference or additional inputs like cameras & depth, please check out the code in our Github repo.
+ +Please note: The inference time changes based on the amount of input images, for e.g., less than 1 second for up to 50 views. However, downloading model weights and visualizing 3D points may take tens of seconds. Please be patient or, for faster visualization, use a local machine to run our demo from our GitHub repository.
+This site builds upon code from:
+ +We extend our gratitude to these projects for their valuable contributions to the research community.
+基于DBSCAN聚类的智能物体识别 | 多视图融合 | 自适应参数调整
+