# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
Inference wrapper for MUSt3R
"""

import datetime
import os

import numpy as np
import torch
from dust3r.viz import rgb
from must3r.demo.inference import SceneState
from must3r.engine.inference import inference_multi_ar, postprocess
from must3r.model import get_pointmaps_activation, load_model

from mapanything.models.external.vggt.utils.rotation import mat_to_quat


def must3r_inference(
    views,
    filelist,
    model,
    retrieval,
    device,
    amp,
    num_mem_images,
    max_bs,
    init_num_images=2,
    batch_num_views=1,
    render_once=False,
    is_sequence=False,
    viser_server=None,
    num_refinements_iterations=2,
    verbose=True,
):
    # Note: `retrieval` and `is_sequence` are kept for interface
    # compatibility and are not used in this code path.
    if amp == "fp16":
        dtype = torch.float16
    elif amp == "bf16":
        assert torch.cuda.is_bf16_supported()
        dtype = torch.bfloat16
    else:
        assert not amp
        dtype = torch.float32
    max_bs = None if max_bs == 0 else max_bs

    encoder, decoder = model
    pointmaps_activation = get_pointmaps_activation(decoder, verbose=verbose)

    def post_process_function(x):
        return postprocess(
            x, pointmaps_activation=pointmaps_activation, compute_cam=True
        )

    if verbose:
        print("loading images")
    time_start = datetime.datetime.now()
    nimgs = len(views)
    elapsed = datetime.datetime.now() - time_start
    if verbose:
        print(f"loaded in {elapsed}")
        print("running inference")
    time_start = datetime.datetime.now()
    if viser_server is not None:
        viser_server.reset(nimgs)

    imgs = [b["img"].to("cpu") for b in views]
    true_shape = [torch.from_numpy(b["true_shape"]).to("cpu") for b in views]
    true_shape = torch.stack(true_shape, dim=0)
    nimgs = true_shape.shape[0]

    # Use all images as keyframes
    keyframes = np.linspace(0, len(imgs) - 1, num_mem_images, dtype=int).tolist()
    encoder_precomputed_features = None

    not_keyframes = sorted(set(range(nimgs)).difference(set(keyframes)))
    assert (len(keyframes) + len(not_keyframes)) == nimgs

    # Reorder images so that keyframes come first
    views = [views[i] for i in keyframes] + [views[i] for i in not_keyframes]
    imgs = [b["img"].to(device) for b in views]
    true_shape = [torch.from_numpy(b["true_shape"]).to(device) for b in views]
    filenames = [filelist[i] for i in keyframes + not_keyframes]
    img_ids = [torch.tensor(v) for v in keyframes + not_keyframes]

    if encoder_precomputed_features is not None:
        x_start, pos_start = encoder_precomputed_features
        x = [x_start[i] for i in keyframes] + [x_start[i] for i in not_keyframes]
        pos = [pos_start[i] for i in keyframes] + [
            pos_start[i] for i in not_keyframes
        ]
        encoder_precomputed_features = (x, pos)

    # Build the memory-batch schedule: `init_num_images` keyframes first,
    # then `batch_num_views` keyframes per step until all keyframes are used
    mem_batches = [init_num_images]
    while (sum_b := sum(mem_batches)) != max(num_mem_images, init_num_images):
        size_b = min(batch_num_views, num_mem_images - sum_b)
        mem_batches.append(size_b)

    if render_once:
        to_render = list(range(num_mem_images, nimgs))
    else:
        to_render = None

    with torch.autocast("cuda", dtype=dtype):
        x_out_0, x_out = inference_multi_ar(
            encoder,
            decoder,
            imgs,
            img_ids,
            true_shape,
            mem_batches,
            max_bs=max_bs,
            verbose=verbose,
            to_render=to_render,
            encoder_precomputed_features=encoder_precomputed_features,
            device=device,
            preserve_gpu_mem=True,
            post_process_function=post_process_function,
            viser_server=viser_server,
            num_refinements_iterations=num_refinements_iterations,
        )
    if to_render is not None:
        x_out = x_out_0 + x_out

    elapsed = datetime.datetime.now() - time_start
    if verbose:
        print(f"inference in {elapsed}")
        try:
            print(str(int(torch.cuda.max_memory_reserved(device) / (1024**2))) + " MB")
        except Exception:
            pass
    if viser_server is not None:
        viser_server.reset_cam_visility()
        viser_server.send_message("Finished")

    if verbose:
        print("preparing pointcloud")
    time_start = datetime.datetime.now()
    focals = []
    cams2world = []
    for i in range(nimgs):
        focals.append(float(x_out[i]["focal"].cpu()))
        cams2world.append(x_out[i]["c2w"].cpu())

    # Move x_out to cpu
    for i in range(len(x_out)):
        for k in x_out[i].keys():
            x_out[i][k] = x_out[i][k].cpu()

    rgbimg = [rgb(imgs[i], true_shape[i]) for i in range(nimgs)]
    scene = SceneState(x_out, rgbimg, true_shape, focals, cams2world, filenames)
    elapsed = datetime.datetime.now() - time_start
    if verbose:
        print(f"pointcloud prepared in {elapsed}")
    return scene
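

# A worked example of the memory-batch schedule built in `must3r_inference`
# (illustrative only; it simply traces the `mem_batches` loop above): with
# init_num_images=2, batch_num_views=1 and num_mem_images=5, the loop yields
#
#   mem_batches == [2, 1, 1, 1]
#
# i.e. the first two keyframes initialize the MUSt3R memory and each
# remaining keyframe is appended to the memory in its own step.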


class MUSt3RWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        ckpt_path,
        retrieval_ckpt_path,
        img_size=512,
        amp="bf16",
        max_bs=1,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.ckpt_path = ckpt_path
        self.retrieval_ckpt_path = retrieval_ckpt_path
        self.amp = amp
        self.max_bs = max_bs

        # Init the model and load the checkpoint
        self.model = load_model(self.ckpt_path, img_size=img_size)

    def forward(self, views):
        """
        Forward pass wrapper for MUSt3R.

        Assumption:
        - The batch size of input views is 1.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Each dictionary should contain the following keys, where B is the batch size and is 1:
                    "img" (tensor): Image tensor of shape (B, C, H, W).
                    "data_norm_type" (list): ["dust3r"]
                    "label" (list): ["scene_name"]
                    "instance" (list): ["image_name"]

        Returns:
            List[dict]: A list containing the final outputs for the input views.
        """
        # Check the batch size of input views
        batch_size_per_view, _, height, width = views[0]["img"].shape
        device = views[0]["img"].device
        num_views = len(views)
        assert batch_size_per_view == 1, (
            f"Batch size of input views should be 1, but got {batch_size_per_view}."
        )

        # Check the data norm type
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "dust3r", (
            "MUSt3R expects a normalized image with the DUSt3R normalization scheme applied"
        )
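
        # Note on the check above: "dust3r" normalization corresponds to
        # DUSt3R's ImgNorm transform, i.e. images scaled to [0, 1] and then
        # normalized with per-channel mean 0.5 and std 0.5, so pixel values
        # land in [-1, 1]. A sketch of the equivalent torchvision transform
        # (for reference; this wrapper assumes the caller already applied it):
        #
        #   tvf.Compose([
        #       tvf.ToTensor(),
        #       tvf.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        #   ])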

        # Convert the input views to the expected input format
        images = []
        image_paths = []
        for view in views:
            images.append(
                dict(
                    img=view["img"][0].cpu(),
                    idx=len(images),
                    instance=str(len(images)),
                    true_shape=np.int32([view["img"].shape[-2], view["img"].shape[-1]]),
                )
            )
            view_name = os.path.join(view["label"][0], view["instance"][0])
            image_paths.append(view_name)

        # Run MUSt3R inference
        scene = must3r_inference(
            images,
            image_paths,
            self.model,
            self.retrieval_ckpt_path,
            device,
            self.amp,
            num_views,
            self.max_bs,
            verbose=False,
        )

        # Make sure scene is not None
        if scene is None:
            raise RuntimeError("MUSt3R failed.")

        # Get the predictions
        predictions = scene.x_out

        # Convert the output to the MapAnything format
        with torch.autocast("cuda", enabled=False):
            res = []
            for view_idx in range(num_views):
                # Get the current view predictions
                curr_view_prediction = predictions[view_idx]
                curr_view_conf = curr_view_prediction["conf"]
                curr_view_pose = curr_view_prediction["c2w"].unsqueeze(0)

                # Convert the pose to quaternions and translation
                curr_view_cam_translations = curr_view_pose[..., :3, 3]
                curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])

                # Get the camera frame pointmaps
                curr_view_pts3d_cam = curr_view_prediction["pts3d_local"].unsqueeze(0)

                # Get the depth along ray and ray directions
                curr_view_depth_along_ray = torch.norm(
                    curr_view_pts3d_cam, dim=-1, keepdim=True
                )
                curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray

                # Get the pointmaps
                curr_view_pts3d = curr_view_prediction["pts3d"].unsqueeze(0)

                # Append the outputs to the result list
                res.append(
                    {
                        "pts3d": curr_view_pts3d.to(device),
                        "pts3d_cam": curr_view_pts3d_cam.to(device),
                        "ray_directions": curr_view_ray_dirs.to(device),
                        "depth_along_ray": curr_view_depth_along_ray.to(device),
                        "cam_trans": curr_view_cam_translations.to(device),
                        "cam_quats": curr_view_cam_quats.to(device),
                        "conf": curr_view_conf.to(device),
                    }
                )

        return res
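

if __name__ == "__main__":
    # Minimal smoke-test sketch for the wrapper. The checkpoint path below is
    # a placeholder (no checkpoint ships with this file), the dummy views
    # stand in for real DUSt3R-normalized images, and a CUDA GPU with bf16
    # support is assumed (the default amp="bf16").
    wrapper = MUSt3RWrapper(
        name="must3r",
        ckpt_path="/path/to/must3r_checkpoint.pth",  # placeholder path
        retrieval_ckpt_path=None,
    ).to("cuda")

    dummy_views = [
        {
            "img": torch.zeros(1, 3, 384, 512, device="cuda"),
            "data_norm_type": ["dust3r"],
            "label": ["scene_name"],
            "instance": [f"{i:04d}"],
        }
        for i in range(2)
    ]
    with torch.no_grad():
        preds = wrapper(dummy_views)
    print(preds[0]["pts3d"].shape)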