Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # This source code is licensed under the Apache License, Version 2.0 | |
| # found in the LICENSE file in the root directory of this source tree. | |
| """ | |
| Inference wrapper for MASt3R + Sparse GA | |
| """ | |
| import os | |
| import tempfile | |
| import warnings | |
| import torch | |
| from dust3r.image_pairs import make_pairs | |
| from mast3r.cloud_opt.sparse_ga import sparse_global_alignment | |
| from mast3r.model import load_model | |
| from mapanything.models.external.vggt.utils.rotation import mat_to_quat | |
| from mapanything.utils.geometry import ( | |
| convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, | |
| convert_z_depth_to_depth_along_ray, | |
| depthmap_to_camera_frame, | |
| get_rays_in_camera_frame, | |
| ) | |
class MASt3RSGAWrapper(torch.nn.Module):
    """
    Inference wrapper for MASt3R + Sparse Global Alignment (Sparse GA).

    Runs MASt3R pairwise inference over all image pairs, optimizes the
    scene with the sparse global aligner, and converts the aligned scene
    (intrinsics, camera-to-world poses, dense depth) into the MapAnything
    prediction format (pointmaps, ray directions, depth along ray,
    camera translations and quaternions).
    """

    def __init__(
        self,
        name,
        ckpt_path,
        cache_dir,
        scene_graph="complete",
        sparse_ga_lr1=0.07,
        sparse_ga_niter1=300,
        sparse_ga_lr2=0.01,
        sparse_ga_niter2=300,
        sparse_ga_optim_level="refine+depth",
        sparse_ga_shared_intrinsics=False,
        sparse_ga_matching_conf_thr=5.0,
        **kwargs,
    ):
        """
        Args:
            name (str): Identifier for this wrapper instance.
            ckpt_path (str): Path to the MASt3R checkpoint loaded via
                ``mast3r.model.load_model`` (loaded on CPU here).
            cache_dir (str): Existing directory in which a fresh per-run
                temporary cache directory is created for Sparse GA.
            scene_graph (str): Pair-selection strategy for ``make_pairs``
                (e.g. "complete").
            sparse_ga_lr1 (float): Learning rate of the coarse 3D
                optimization stage.
            sparse_ga_niter1 (int): Iterations of the coarse 3D stage.
            sparse_ga_lr2 (float): Learning rate of the 2D refinement stage.
            sparse_ga_niter2 (int): Iterations of the 2D refinement stage.
            sparse_ga_optim_level (str): Optimization level string; depth is
                optimized iff it contains the substring "depth".
            sparse_ga_shared_intrinsics (bool): Whether all views share a
                single intrinsics estimate.
            sparse_ga_matching_conf_thr (float): Matching confidence
                threshold passed to Sparse GA.
            **kwargs: Ignored; accepted for config-driven construction.
        """
        super().__init__()
        self.name = name
        self.ckpt_path = ckpt_path
        self.cache_dir = cache_dir
        self.scene_graph = scene_graph
        self.sparse_ga_lr1 = sparse_ga_lr1
        self.sparse_ga_niter1 = sparse_ga_niter1
        self.sparse_ga_lr2 = sparse_ga_lr2
        self.sparse_ga_niter2 = sparse_ga_niter2
        self.sparse_ga_optim_level = sparse_ga_optim_level
        self.sparse_ga_shared_intrinsics = sparse_ga_shared_intrinsics
        self.sparse_ga_matching_conf_thr = sparse_ga_matching_conf_thr

        # Init the model and load the checkpoint (kept on CPU; moved to the
        # input device implicitly by the Sparse GA call in forward()).
        self.model = load_model(self.ckpt_path, device="cpu")

    def forward(self, views):
        """
        Forward pass wrapper for MASt3R using the sparse global aligner.

        Assumption:
        - The batch size of input views is 1.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                                Each dictionary should contain the following keys, where B is the batch size and is 1:
                                "img" (tensor): Image tensor of shape (B, C, H, W).
                                "data_norm_type" (list): ["dust3r"]
                                "label" (list): ["scene_name"]
                                "instance" (list): ["image_name"]

        Returns:
            List[dict]: A list containing the final outputs for all N input views, each with keys
                        "pts3d", "pts3d_cam", "ray_directions", "depth_along_ray",
                        "cam_trans", "cam_quats".
        """
        # Check the batch size of input views
        # NOTE(review): H, W are taken from the first view — assumes all
        # views share the same resolution; verify against callers.
        batch_size_per_view, _, height, width = views[0]["img"].shape
        device = views[0]["img"].device
        num_views = len(views)
        assert batch_size_per_view == 1, (
            f"Batch size of input views should be 1, but got {batch_size_per_view}."
        )

        # Check the data norm type
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "dust3r", (
            "MASt3R expects a normalized image with the DUSt3R normalization scheme applied"
        )

        # Convert the input views to the expected input format
        images = []
        image_paths = []
        for view in views:
            images.append(
                dict(
                    img=view["img"].cpu(),
                    idx=len(images),
                    instance=str(len(images)),
                    true_shape=torch.tensor(view["img"].shape[-2:])[None]
                    .repeat(batch_size_per_view, 1)
                    .numpy(),
                )
            )
            # Unique per-view identifier used by Sparse GA for caching.
            view_name = os.path.join(view["label"][0], view["instance"][0])
            image_paths.append(view_name)

        # Make image pairs and run inference
        # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
        pairs = make_pairs(
            images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True
        )
        with torch.enable_grad():
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=FutureWarning)
                # Create a fresh per-run cache directory inside self.cache_dir
                # and pass it to Sparse GA. (Previously the mkdtemp result was
                # discarded and the shared cache_dir was reused across runs,
                # leaking empty temp dirs and risking cross-run cache clashes.)
                run_cache_dir = tempfile.mkdtemp(dir=self.cache_dir)
                scene = sparse_global_alignment(
                    image_paths,
                    pairs,
                    run_cache_dir,
                    self.model,
                    lr1=self.sparse_ga_lr1,
                    niter1=self.sparse_ga_niter1,
                    lr2=self.sparse_ga_lr2,
                    niter2=self.sparse_ga_niter2,
                    device=device,
                    opt_depth="depth" in self.sparse_ga_optim_level,
                    shared_intrinsics=self.sparse_ga_shared_intrinsics,
                    matching_conf_thr=self.sparse_ga_matching_conf_thr,
                    verbose=False,
                )

        # Make sure scene is not None
        if scene is None:
            raise RuntimeError("Global optimization failed.")

        # Get the predictions
        intrinsics = scene.intrinsics
        c2w_poses = scene.get_im_poses()
        _, depths, _ = scene.get_dense_pts3d()

        # Convert the output to the MapAnything format.
        # Disable autocast so the geometry conversions run in full precision.
        with torch.autocast("cuda", enabled=False):
            res = []
            for view_idx in range(num_views):
                # Get the current view predictions (add a batch dim of 1)
                curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0)
                curr_view_pose = c2w_poses[view_idx].unsqueeze(0)
                curr_view_depth_z = (
                    depths[view_idx].reshape((height, width)).unsqueeze(0)
                )

                # Convert the pose to quaternions and translation
                curr_view_cam_translations = curr_view_pose[..., :3, 3]
                curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])

                # Get the camera frame pointmaps
                curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
                    curr_view_depth_z, curr_view_intrinsic
                )

                # Convert the z depth to depth along ray
                curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
                    curr_view_depth_z, curr_view_intrinsic
                )
                curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)

                # Get the ray directions on the unit sphere in the camera frame
                _, curr_view_ray_dirs = get_rays_in_camera_frame(
                    curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
                )

                # Get the world-frame pointmaps from rays + depth + pose
                curr_view_pts3d = (
                    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
                        curr_view_ray_dirs,
                        curr_view_depth_along_ray,
                        curr_view_cam_translations,
                        curr_view_cam_quats,
                    )
                )

                # Append the outputs to the result list
                res.append(
                    {
                        "pts3d": curr_view_pts3d,
                        "pts3d_cam": curr_view_pts3d_cam,
                        "ray_directions": curr_view_ray_dirs,
                        "depth_along_ray": curr_view_depth_along_ray,
                        "cam_trans": curr_view_cam_translations,
                        "cam_quats": curr_view_cam_quats,
                    }
                )

        return res