# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
"""
Inference wrapper for DUSt3R
"""

import warnings

import torch

from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.image_pairs import make_pairs
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo  # noqa
from mapanything.models.external.vggt.utils.rotation import mat_to_quat
from mapanything.utils.geometry import (
    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
    convert_z_depth_to_depth_along_ray,
    depthmap_to_camera_frame,
    get_rays_in_camera_frame,
)

inf = float("inf")


def load_model(model_path, device, verbose=True):
    """Load a DUSt3R checkpoint and instantiate the network with landscape_only=False."""
    if verbose:
        print("Loading model from", model_path)
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    # Patch the checkpoint's stored instantiation string: use PatchEmbedDust3R
    # and force landscape_only=False
    args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if "landscape_only" not in args:
        args = args[:-1] + ", landscape_only=False)"
    else:
        args = args.replace(" ", "").replace(
            "landscape_only=True", "landscape_only=False"
        )
    assert "landscape_only=False" in args
    if verbose:
        print(f"Instantiating: {args}")
    try:
        net = eval(args)
    except NameError:
        # Fall back to the default ViT-L encoder / DPT head configuration if the
        # stored instantiation string references names that are not in scope
        net = AsymmetricCroCo3DStereo(
            enc_depth=24,
            dec_depth=12,
            enc_embed_dim=1024,
            dec_embed_dim=768,
            enc_num_heads=16,
            dec_num_heads=12,
            pos_embed="RoPE100",
            patch_embed_cls="PatchEmbedDust3R",
            img_size=(512, 512),
            head_type="dpt",
            output_mode="pts3d",
            depth_mode=("exp", -inf, inf),
            conf_mode=("exp", 1, inf),
            landscape_only=False,
        )
    s = net.load_state_dict(ckpt["model"], strict=False)
    if verbose:
        print(s)
    return net.to(device)


class DUSt3RBAWrapper(torch.nn.Module):
    """Wrapper that runs pair-wise DUSt3R inference followed by global alignment
    and converts the optimized scene to MapAnything-format per-view outputs."""

    def __init__(
        self,
        name,
        ckpt_path,
        scene_graph="complete",
        inference_batch_size=32,
        global_optim_schedule="cosine",
        global_optim_lr=0.01,
        global_optim_niter=300,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.ckpt_path = ckpt_path
        self.scene_graph = scene_graph
        self.inference_batch_size = inference_batch_size
        self.global_optim_schedule = global_optim_schedule
        self.global_optim_lr = global_optim_lr
        self.global_optim_niter = global_optim_niter

        # Init the model and load the checkpoint
        self.model = load_model(self.ckpt_path, device="cpu")

        # Init the global aligner mode
        self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer

    def forward(self, views):
        """
        Forward pass wrapper for DUSt3R using the global aligner.

        Assumption:
        - The batch size of input views is 1.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                                Each dictionary should contain the following keys, where the batch size B is 1:
                                    "img" (tensor): Image tensor of shape (B, C, H, W).
                                    "data_norm_type" (list): ["dust3r"]

        Returns:
            List[dict]: A list containing the final outputs for the input views.
        """
        # Check the batch size of input views
        batch_size_per_view, _, height, width = views[0]["img"].shape
        device = views[0]["img"].device
        num_views = len(views)
        assert batch_size_per_view == 1, (
            f"Batch size of input views should be 1, but got {batch_size_per_view}."
        )

        # Check the data norm type
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "dust3r", (
            "DUSt3R expects a normalized image with the DUSt3R normalization scheme applied"
        )

        # Convert the input views to the expected input format
        images = []
        for view in views:
            images.append(
                dict(
                    img=view["img"],
                    idx=len(images),
                    instance=str(len(images)),
                )
            )

        # Make image pairs and run inference pair-wise
        pairs = make_pairs(
            images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            output = inference(
                pairs,
                self.model,
                device,
                batch_size=self.inference_batch_size,
                verbose=False,
            )

        # Global optimization
        with torch.enable_grad():
            scene = global_aligner(
                output, device=device, mode=self.global_aligner_mode, verbose=False
            )
            _ = scene.compute_global_alignment(
                init="mst",
                niter=self.global_optim_niter,
                schedule=self.global_optim_schedule,
                lr=self.global_optim_lr,
            )

        # Make sure scene is not None
        if scene is None:
            raise RuntimeError("Global optimization failed.")

        # Get the predictions
        intrinsics = scene.get_intrinsics()
        c2w_poses = scene.get_im_poses()
        depths = scene.get_depthmaps()

        # Convert the output to the MapAnything format
        with torch.autocast("cuda", enabled=False):
            res = []
            for view_idx in range(num_views):
                # Get the current view predictions
                curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0)
                curr_view_pose = c2w_poses[view_idx].unsqueeze(0)
                curr_view_depth_z = depths[view_idx].unsqueeze(0)

                # Convert the pose to quaternions and translation
                curr_view_cam_translations = curr_view_pose[..., :3, 3]
                curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])

                # Get the camera frame pointmaps
                curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
                    curr_view_depth_z, curr_view_intrinsic
                )

                # Convert the z depth to depth along ray
                curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
                    curr_view_depth_z, curr_view_intrinsic
                )
                curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)

                # Get the ray directions on the unit sphere in the camera frame
                _, curr_view_ray_dirs = get_rays_in_camera_frame(
                    curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
                )

                # Get the pointmaps
                curr_view_pts3d = (
                    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
                        curr_view_ray_dirs,
                        curr_view_depth_along_ray,
                        curr_view_cam_translations,
                        curr_view_cam_quats,
                    )
                )

                # Append the outputs to the result list
                res.append(
                    {
                        "pts3d": curr_view_pts3d,
                        "pts3d_cam": curr_view_pts3d_cam,
                        "ray_directions": curr_view_ray_dirs,
                        "depth_along_ray": curr_view_depth_along_ray,
                        "cam_trans": curr_view_cam_translations,
                        "cam_quats": curr_view_cam_quats,
                    }
                )

        return res
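

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module API). The
# checkpoint path, image resolution, and random tensors below are assumptions:
# substitute a real DUSt3R checkpoint and images that have been resized and
# normalized with the DUSt3R scheme.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Hypothetical local checkpoint path
    wrapper = DUSt3RBAWrapper(
        name="dust3r_ba",
        ckpt_path="checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
    ).to(device)

    # Two dummy views with batch size 1, standing in for real normalized images
    views = [
        {
            "img": torch.randn(1, 3, 384, 512, device=device),
            "data_norm_type": ["dust3r"],
        }
        for _ in range(2)
    ]

    # forward() re-enables gradients internally for the global optimization,
    # so it can safely be called under no_grad
    with torch.no_grad():
        preds = wrapper(views)

    # Each entry holds world-frame and camera-frame pointmaps, ray directions,
    # depth along ray, and the camera pose as translation + quaternion
    print({k: tuple(v.shape) for k, v in preds[0].items()})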