# Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. """ Inference wrapper for Pow3R """ import warnings from copy import deepcopy import pow3r.model.blocks # noqa import roma import torch import torch.nn as nn import tqdm from dust3r.cloud_opt import global_aligner, GlobalAlignerMode from dust3r.image_pairs import make_pairs from dust3r.inference import check_if_same_size from dust3r.model import CroCoNet from dust3r.patch_embed import get_patch_embed as dust3r_patch_embed from dust3r.utils.device import collate_with_cat, to_cpu from dust3r.utils.misc import ( fill_default_args, freeze_all_params, interleave, is_symmetrized, transpose_to_landscape, ) from pow3r.model.blocks import Block, BlockInject, DecoderBlock, DecoderBlockInject, Mlp from pow3r.model.heads import head_factory from pow3r.model.inference import ( add_depth, add_intrinsics, add_relpose, ) from pow3r.model.patch_embed import get_patch_embed from mapanything.models.external.vggt.utils.rotation import mat_to_quat from mapanything.utils.geometry import ( convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, convert_z_depth_to_depth_along_ray, depthmap_to_camera_frame, get_rays_in_camera_frame, ) class Pow3R(CroCoNet): """Two siamese encoders, followed by two decoders. The goal is to output 3d points directly, both images in view1's frame (hence the asymmetry). """ def __init__( self, mode="embed", head_type="linear", patch_embed_cls="PatchEmbedDust3R", freeze="none", landscape_only=True, **croco_kwargs, ): # retrieve all default arguments using python magic self.croco_args = fill_default_args(croco_kwargs, super().__init__) super().__init__(**croco_kwargs) del self.mask_token # useless del self.prediction_head dec_dim, enc_dim = self.decoder_embed.weight.shape self.enc_embed_dim = enc_dim self.dec_embed_dim = dec_dim self.mode = mode # additional parameters in the encoder img_size = self.patch_embed.img_size patch_size = self.patch_embed.patch_size[0] self.patch_embed = dust3r_patch_embed( patch_embed_cls, img_size, patch_size, self.enc_embed_dim ) self.patch_embed_rays = get_patch_embed( patch_embed_cls + "_Mlp", img_size, patch_size, self.enc_embed_dim, in_chans=3, ) self.patch_embed_depth = get_patch_embed( patch_embed_cls + "_Mlp", img_size, patch_size, self.enc_embed_dim, in_chans=2, ) self.pose_embed = Mlp(12, 4 * dec_dim, dec_dim) # additional parameters in the decoder self.dec_cls = "_cls" in self.mode self.dec_num_cls = 0 if self.dec_cls: # use a CLS token in the decoder only self.mode = self.mode.replace("_cls", "") self.cls_token1 = nn.Parameter(torch.zeros((dec_dim,))) self.cls_token2 = nn.Parameter(torch.zeros((dec_dim,))) self.dec_num_cls = 1 # affects all blocks use_ln = "_ln" in self.mode # TODO remove? self.patch_ln = nn.LayerNorm(enc_dim) if use_ln else nn.Identity() self.dec1_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity() self.dec2_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity() self.dec_blocks2 = deepcopy(self.dec_blocks) # here we modify some of the blocks self.replace_some_blocks() self.set_downstream_head(head_type, landscape_only, **croco_kwargs) self.set_freeze(freeze) def replace_some_blocks(self): assert self.mode.startswith("inject") # inject[0,0.5] NewBlock = BlockInject DecoderNewBlock = DecoderBlockInject all_layers = { i / n for i in range(len(self.enc_blocks)) for n in [len(self.enc_blocks), len(self.dec_blocks)] } which_layers = eval(self.mode[self.mode.find("[") :]) or all_layers assert isinstance(which_layers, (set, list)) n = 0 for i, block in enumerate(self.enc_blocks): if i / len(self.enc_blocks) in which_layers: block.__class__ = NewBlock block.init(self.enc_embed_dim) n += 1 else: block.__class__ = Block assert n == len(which_layers), breakpoint() n = 0 for i in range(len(self.dec_blocks)): for blocks in [self.dec_blocks, self.dec_blocks2]: block = blocks[i] if i / len(self.dec_blocks) in which_layers: block.__class__ = DecoderNewBlock block.init(self.dec_embed_dim) n += 1 else: block.__class__ = DecoderBlock assert n == 2 * len(which_layers), breakpoint() @classmethod def from_pretrained(cls, pretrained_model_path, **kw): return _load_model(pretrained_model_path, device="cpu") def load_state_dict(self, ckpt, **kw): # duplicate all weights for the second decoder if not present new_ckpt = dict(ckpt) if not any(k.startswith("dec_blocks2") for k in ckpt): for key, value in ckpt.items(): if key.startswith("dec_blocks"): new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value # remove layers that have different shapes cur_ckpt = self.state_dict() for key, val in ckpt.items(): if key.startswith("downstream_head2.proj"): if key in cur_ckpt and cur_ckpt[key].shape != val.shape: print(f" (removing ckpt[{key}] because wrong shape)") del new_ckpt[key] return super().load_state_dict(new_ckpt, **kw) def set_freeze(self, freeze): # this is for use by downstream models self.freeze = freeze to_be_frozen = { "none": [], "encoder": [self.patch_embed, self.enc_blocks], } freeze_all_params(to_be_frozen[freeze]) def set_prediction_head(self, *args, **kwargs): """No prediction head""" return def set_downstream_head( self, head_type, landscape_only, patch_size, img_size, mlp_ratio, dec_depth, **kw, ): assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, ( f"{img_size=} must be multiple of {patch_size=}" ) # split heads if different heads = head_type.split(";") assert len(heads) in (1, 2) head1_type, head2_type = (heads + heads)[:2] # allocate heads self.downstream_head1 = head_factory(head1_type, self) self.downstream_head2 = head_factory(head2_type, self) # magic wrapper self.head1 = transpose_to_landscape( self.downstream_head1, activate=landscape_only ) self.head2 = transpose_to_landscape( self.downstream_head2, activate=landscape_only ) def _encode_image(self, image, true_shape, rays=None, depth=None): # embed the image into patches (x has size B x Npatches x C) x, pos = self.patch_embed(image, true_shape=true_shape) if rays is not None: # B,3,H,W rays_emb, pos2 = self.patch_embed_rays(rays, true_shape=true_shape) assert (pos == pos2).all() if self.mode.startswith("embed"): x = x + rays_emb else: rays_emb = None if depth is not None: # B,2,H,W depth_emb, pos2 = self.patch_embed_depth(depth, true_shape=true_shape) assert (pos == pos2).all() if self.mode.startswith("embed"): x = x + depth_emb else: depth_emb = None x = self.patch_ln(x) # add positional embedding without cls token assert self.enc_pos_embed is None # now apply the transformer encoder and normalization for blk in self.enc_blocks: x = blk(x, pos, rays=rays_emb, depth=depth_emb) x = self.enc_norm(x) return x, pos def encode_symmetrized(self, view1, view2): img1 = view1["img"] img2 = view2["img"] B = img1.shape[0] # Recover true_shape when available, otherwise assume that the img shape is the true one shape1 = view1.get( "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1) ) shape2 = view2.get( "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1) ) # warning! maybe the images have different portrait/landscape orientations # privileged information rays1 = view1.get("known_rays", None) rays2 = view2.get("known_rays", None) depth1 = view1.get("known_depth", None) depth2 = view2.get("known_depth", None) if is_symmetrized(view1, view2): # computing half of forward pass!' def hsub(x): return None if x is None else x[::2] feat1, pos1 = self._encode_image( img1[::2], shape1[::2], rays=hsub(rays1), depth=hsub(depth1) ) feat2, pos2 = self._encode_image( img2[::2], shape2[::2], rays=hsub(rays2), depth=hsub(depth2) ) feat1, feat2 = interleave(feat1, feat2) pos1, pos2 = interleave(pos1, pos2) else: feat1, pos1 = self._encode_image(img1, shape1, rays=rays1, depth=depth1) feat2, pos2 = self._encode_image(img2, shape2, rays=rays2, depth=depth2) return (shape1, shape2), (feat1, feat2), (pos1, pos2) def _decoder(self, f1, pos1, f2, pos2, relpose1=None, relpose2=None): final_output = [(f1, f2)] # before projection # project to decoder dim f1 = self.decoder_embed(f1) f2 = self.decoder_embed(f2) # add CLS token for the decoder if self.dec_cls: cls1 = self.cls_token1[None, None].expand(len(f1), 1, -1).clone() cls2 = self.cls_token2[None, None].expand(len(f2), 1, -1).clone() if relpose1 is not None: # shape = (B, 4, 4) pose_emb1 = self.pose_embed(relpose1[:, :3].flatten(1)).unsqueeze(1) if self.mode.startswith("embed"): if self.dec_cls: cls1 = cls1 + pose_emb1 else: f1 = f1 + pose_emb1 else: pose_emb1 = None if relpose2 is not None: # shape = (B, 4, 4) pose_emb2 = self.pose_embed(relpose2[:, :3].flatten(1)).unsqueeze(1) if self.mode.startswith("embed"): if self.dec_cls: cls2 = cls2 + pose_emb2 else: f2 = f2 + pose_emb2 else: pose_emb2 = None if self.dec_cls: f1, pos1 = cat_cls(cls1, f1, pos1) f2, pos2 = cat_cls(cls2, f2, pos2) f1 = self.dec1_pre_ln(f1) f2 = self.dec2_pre_ln(f2) final_output.append((f1, f2)) # to be removed later for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2): # img1 side f1, _ = blk1( *final_output[-1][::+1], pos1, pos2, relpose=pose_emb1, num_cls=self.dec_num_cls, ) # img2 side f2, _ = blk2( *final_output[-1][::-1], pos2, pos1, relpose=pose_emb2, num_cls=self.dec_num_cls, ) # store the result final_output.append((f1, f2)) del final_output[1] # duplicate with final_output[0] (after decoder proj) if self.dec_cls: # remove cls token for decoder layers final_output[1:] = [(f1[:, 1:], f2[:, 1:]) for f1, f2 in final_output[1:]] # normalize last output final_output[-1] = tuple(map(self.dec_norm, final_output[-1])) return zip(*final_output) def _downstream_head(self, head_num, decout, img_shape): B, S, D = decout[-1].shape head = getattr(self, f"head{head_num}") return head(decout, img_shape) def forward(self, view1, view2): # encode the two images --> B,S,D (shape1, shape2), (feat1, feat2), (pos1, pos2) = self.encode_symmetrized( view1, view2 ) # combine all ref images into object-centric representation dec1, dec2 = self._decoder( feat1, pos1, feat2, pos2, relpose1=view1.get("known_pose"), relpose2=view2.get("known_pose"), ) with torch.autocast("cuda", enabled=False): res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1) res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2) res2["pts3d_in_other_view"] = res2.pop( "pts3d" ) # predict view2's pts3d in view1's frame return res1, res2 def convert_release_dust3r_args(args): args.model = ( args.model.replace("patch_embed_cls", "patch_embed") .replace("AsymmetricMASt3R", "AsymmetricCroCo3DStereo") .replace("PatchEmbedDust3R", "convManyAR") .replace( "pos_embed='RoPE100'", "enc_pos_embed='cuRoPE100', dec_pos_embed='cuRoPE100'", ) ) return args def _load_model(model_path, device): print("... loading model from", model_path) ckpt = torch.load(model_path, map_location="cpu") try: net = eval( ckpt["args"].model[:-1].replace("convManyAR", "convP") + ", landscape_only=False)" ) except Exception: args = convert_release_dust3r_args(ckpt["args"]) net = eval( args.model[:-1].replace("convManyAR", "convP") + ", landscape_only=False)" ) ckpt["model"] = { k.replace("_downstream_head", "downstream_head"): v for k, v in ckpt["model"].items() } print(net.load_state_dict(ckpt["model"], strict=False)) return net.to(device) def cat_cls(cls, tokens, pos): tokens = torch.cat((cls, tokens), dim=1) pos = torch.cat((-pos.new_ones(len(cls), 1, 2), pos), dim=1) return tokens, pos class Pow3RWrapper(torch.nn.Module): def __init__( self, name, ckpt_path, geometric_input_config, **kwargs, ): super().__init__() self.name = name self.ckpt_path = ckpt_path self.geometric_input_config = geometric_input_config # Init the model and load the checkpoint print(f"Loading checkpoint from {self.ckpt_path} ...") ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False) model = ckpt["definition"] print(f"Creating model = {model}") self.model = eval(model) print(self.model.load_state_dict(ckpt["weights"])) def forward(self, views): """ Forward pass wrapper for Pow3R. Assumption: - The number of input views is 2. Args: views (List[dict]): List of dictionaries containing the input views' images and instance information. Length of the list should be 2. Each dictionary should contain the following keys: "img" (tensor): Image tensor of shape (B, C, H, W). "data_norm_type" (list): ["dust3r"] Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs: "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3). "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation. "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1). Returns: List[dict]: A list containing the final outputs for the two views. Length of the list will be 2. """ # Check that the number of input views is 2 assert len(views) == 2, "Pow3R requires 2 input views." # Check the data norm type data_norm_type = views[0]["data_norm_type"][0] assert data_norm_type == "dust3r", ( "Pow3R expects a normalized image with the DUSt3R normalization scheme applied" ) # Get the batch size per view, device and two views batch_size_per_view = views[0]["img"].shape[0] device = views[0]["img"].device view1, view2 = views # Decide if we need to use the geometric inputs if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]: # Decide if we need to use the camera intrinsics if ( torch.rand(1, device=device) < self.geometric_input_config["ray_dirs_prob"] ): add_intrinsics(view1, view1.get("camera_intrinsics")) add_intrinsics(view2, view2.get("camera_intrinsics")) # Decide if we need to use the depth map if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]: depthmap1 = view1.get("depthmap") depthmap2 = view2.get("depthmap") if depthmap1 is not None: depthmap1 = depthmap1.squeeze(-1).to(device) if depthmap2 is not None: depthmap2 = depthmap2.squeeze(-1).to(device) add_depth(view1, depthmap1) add_depth(view2, depthmap2) # Decide if we need to use the camera pose if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]: cam1 = view1.get("camera_pose") cam2 = view2.get("camera_pose") add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1) add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1) # Get the model predictions preds = self.model(view1, view2) # Convert the output to MapAnything format with torch.autocast("cuda", enabled=False): res = [] for view_idx in range(2): # Get the model predictions for the current view curr_view_pred = preds[view_idx] # For the first view if view_idx == 0: # Get the global frame and camera frame pointmaps global_pts = curr_view_pred["pts3d"] cam_pts = curr_view_pred["pts3d"] conf = curr_view_pred["conf"] # Get the ray directions and depth along ray depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True) ray_directions = cam_pts / depth_along_ray # Initalize identity camera pose cam_rot = torch.eye(3, device=device) cam_quat = mat_to_quat(cam_rot) cam_trans = torch.zeros(3, device=device) cam_quat = cam_quat.unsqueeze(0).repeat(batch_size_per_view, 1) cam_trans = cam_trans.unsqueeze(0).repeat(batch_size_per_view, 1) # For the second view elif view_idx == 1: # Get the global frame and camera frame pointmaps pred_global_pts = curr_view_pred["pts3d_in_other_view"] cam_pts = curr_view_pred["pts3d2"] conf = (curr_view_pred["conf"] * curr_view_pred["conf2"]).sqrt() # Get the ray directions and depth along ray depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True) ray_directions = cam_pts / depth_along_ray # Compute the camera pose using the pointmaps cam_rot, cam_trans, scale = roma.rigid_points_registration( cam_pts.reshape(batch_size_per_view, -1, 3), pred_global_pts.reshape(batch_size_per_view, -1, 3), weights=conf.reshape(batch_size_per_view, -1), compute_scaling=True, ) cam_quat = mat_to_quat(cam_rot) # Scale the predicted camera frame pointmap and compute the new global frame pointmap cam_pts = scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * cam_pts global_pts = cam_pts.reshape( batch_size_per_view, -1, 3 ) @ cam_rot.permute(0, 2, 1) + cam_trans.unsqueeze(1) global_pts = global_pts.view(pred_global_pts.shape) # Append the result in MapAnything format res.append( { "pts3d": global_pts, "pts3d_cam": cam_pts, "ray_directions": ray_directions, "depth_along_ray": depth_along_ray, "cam_trans": cam_trans, "cam_quats": cam_quat, "conf": conf, } ) return res class Pow3RBAWrapper(torch.nn.Module): def __init__( self, name, ckpt_path, geometric_input_config, scene_graph="complete", inference_batch_size=32, global_optim_schedule="cosine", global_optim_lr=0.01, global_optim_niter=300, **kwargs, ): super().__init__() self.name = name self.ckpt_path = ckpt_path self.geometric_input_config = geometric_input_config self.scene_graph = scene_graph self.inference_batch_size = inference_batch_size self.global_optim_schedule = global_optim_schedule self.global_optim_lr = global_optim_lr self.global_optim_niter = global_optim_niter # Init the model and load the checkpoint print(f"Loading checkpoint from {self.ckpt_path} ...") ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False) model = ckpt["definition"] print(f"Creating model = {model}") self.model = eval(model) print(self.model.load_state_dict(ckpt["weights"])) # Init the global aligner mode self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer def infer_two_views(self, views): """ Wrapper for Pow3R 2-View inference. Assumption: - The number of input views is 2. Args: views (List[dict]): List of dictionaries containing the input views' images and instance information. Length of the list should be 2. Each dictionary should contain the following keys: "img" (tensor): Image tensor of shape (B, C, H, W). "data_norm_type" (list): ["dust3r"] Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs: "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3). "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation. "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1). Returns: List[dict]: A list containing the final outputs for the two views. Length of the list will be 2. """ # Check that the number of input views is 2 assert len(views) == 2, "Pow3R requires 2 input views." # Check the data norm type data_norm_type = views[0]["data_norm_type"][0] assert data_norm_type == "dust3r", ( "Pow3R expects a normalized image with the DUSt3R normalization scheme applied" ) # Get the device and two views device = views[0]["img"].device view1, view2 = views # Decide if we need to use the geometric inputs if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]: # Decide if we need to use the camera intrinsics if ( torch.rand(1, device=device) < self.geometric_input_config["ray_dirs_prob"] ): add_intrinsics(view1, view1.get("camera_intrinsics")) add_intrinsics(view2, view2.get("camera_intrinsics")) # Decide if we need to use the depth map if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]: depthmap1 = view1.get("depthmap") depthmap2 = view2.get("depthmap") if depthmap1 is not None: depthmap1 = depthmap1.squeeze(-1).to(device) if depthmap2 is not None: depthmap2 = depthmap2.squeeze(-1).to(device) add_depth(view1, depthmap1) add_depth(view2, depthmap2) # Decide if we need to use the camera pose if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]: cam1 = view1.get("camera_pose") cam2 = view2.get("camera_pose") add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1) add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1) # Get the model predictions preds = self.model(view1, view2) return preds def loss_of_one_batch(self, batch, device): """ Compute prediction for two views. """ view1, view2 = batch ignore_keys = set( [ "dataset", "label", "instance", "idx", "true_shape", "rng", "name", "data_norm_type", ] ) for view in batch: for name in view.keys(): # pseudo_focal if name in ignore_keys: continue view[name] = view[name].to(device, non_blocking=True) pred1, pred2 = self.infer_two_views([view1, view2]) result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2) return result @torch.no_grad() def inference(self, pairs, device, verbose=False): """ Wrapper for multi-pair inference using Pow3R. """ if verbose: print(f">> Inference with model on {len(pairs)} image pairs") result = [] multiple_shapes = not (check_if_same_size(pairs)) if multiple_shapes: self.inference_batch_size = 1 for i in tqdm.trange( 0, len(pairs), self.inference_batch_size, disable=not verbose ): res = self.loss_of_one_batch( collate_with_cat(pairs[i : i + self.inference_batch_size]), device ) result.append(to_cpu(res)) result = collate_with_cat(result, lists=multiple_shapes) return result def forward(self, views): """ Forward pass wrapper for Pow3R using the global aligner. Assumption: - The batch size of input views is 1. Args: views (List[dict]): List of dictionaries containing the input views' images and instance information. Each dictionary should contain the following keys, where B is the batch size and is 1: "img" (tensor): Image tensor of shape (B, C, H, W). "data_norm_type" (list): ["dust3r"] Returns: List[dict]: A list containing the final outputs for the input views. """ # Check the batch size of input views batch_size_per_view, _, height, width = views[0]["img"].shape device = views[0]["img"].device num_views = len(views) assert batch_size_per_view == 1, ( f"Batch size of input views should be 1, but got {batch_size_per_view}." ) # Check the data norm type data_norm_type = views[0]["data_norm_type"][0] assert data_norm_type == "dust3r", ( "Pow3R-BA expects a normalized image with the DUSt3R normalization scheme applied" ) # Convert the input views to the expected input format images = [] for view in views: images.append( dict( img=view["img"], camera_intrinsics=view["camera_intrinsics"], depthmap=view["depthmap"], camera_pose=view["camera_pose"], data_norm_type=view["data_norm_type"], true_shape=view["true_shape"], idx=len(images), instance=str(len(images)), ) ) # Make image pairs and run inference pair-wise pairs = make_pairs( images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True ) with warnings.catch_warnings(): warnings.simplefilter("ignore", category=FutureWarning) output = self.inference( pairs, device, verbose=False, ) # Global optimization with torch.enable_grad(): scene = global_aligner( output, device=device, mode=self.global_aligner_mode, verbose=False ) _ = scene.compute_global_alignment( init="mst", niter=self.global_optim_niter, schedule=self.global_optim_schedule, lr=self.global_optim_lr, ) # Make sure scene is not None if scene is None: raise RuntimeError("Global optimization failed.") # Get the predictions intrinsics = scene.get_intrinsics() c2w_poses = scene.get_im_poses() depths = scene.get_depthmaps() # Convert the output to the MapAnything format with torch.autocast("cuda", enabled=False): res = [] for view_idx in range(num_views): # Get the current view predictions curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0) curr_view_pose = c2w_poses[view_idx].unsqueeze(0) curr_view_depth_z = depths[view_idx].unsqueeze(0) # Convert the pose to quaternions and translation curr_view_cam_translations = curr_view_pose[..., :3, 3] curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3]) # Get the camera frame pointmaps curr_view_pts3d_cam, _ = depthmap_to_camera_frame( curr_view_depth_z, curr_view_intrinsic ) # Convert the z depth to depth along ray curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray( curr_view_depth_z, curr_view_intrinsic ) curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1) # Get the ray directions on the unit sphere in the camera frame _, curr_view_ray_dirs = get_rays_in_camera_frame( curr_view_intrinsic, height, width, normalize_to_unit_sphere=True ) # Get the pointmaps curr_view_pts3d = ( convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( curr_view_ray_dirs, curr_view_depth_along_ray, curr_view_cam_translations, curr_view_cam_quats, ) ) # Append the outputs to the result list res.append( { "pts3d": curr_view_pts3d, "pts3d_cam": curr_view_pts3d_cam, "ray_directions": curr_view_ray_dirs, "depth_along_ray": curr_view_depth_along_ray, "cam_trans": curr_view_cam_translations, "cam_quats": curr_view_cam_quats, } ) return res