# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
"""
Inference wrapper for Pow3R
"""

import warnings
from copy import deepcopy

import pow3r.model.blocks  # noqa
import roma
import torch
import torch.nn as nn
import tqdm
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.image_pairs import make_pairs
from dust3r.inference import check_if_same_size
from dust3r.model import CroCoNet
from dust3r.patch_embed import get_patch_embed as dust3r_patch_embed
from dust3r.utils.device import collate_with_cat, to_cpu
from dust3r.utils.misc import (
    fill_default_args,
    freeze_all_params,
    interleave,
    is_symmetrized,
    transpose_to_landscape,
)
from pow3r.model.blocks import Block, BlockInject, DecoderBlock, DecoderBlockInject, Mlp
from pow3r.model.heads import head_factory
from pow3r.model.inference import (
    add_depth,
    add_intrinsics,
    add_relpose,
)
from pow3r.model.patch_embed import get_patch_embed

from mapanything.models.external.vggt.utils.rotation import mat_to_quat
from mapanything.utils.geometry import (
    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
    convert_z_depth_to_depth_along_ray,
    depthmap_to_camera_frame,
    get_rays_in_camera_frame,
)


class Pow3R(CroCoNet):
    """Two siamese encoders, followed by two decoders.
    The goal is to output 3d points directly, both images in view1's frame
    (hence the asymmetry).
    """

    def __init__(
        self,
        mode="embed",
        head_type="linear",
        patch_embed_cls="PatchEmbedDust3R",
        freeze="none",
        landscape_only=True,
        **croco_kwargs,
    ):
        # retrieve all default arguments using python magic
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)
        del self.mask_token  # useless
        del self.prediction_head

        dec_dim, enc_dim = self.decoder_embed.weight.shape
        self.enc_embed_dim = enc_dim
        self.dec_embed_dim = dec_dim
        self.mode = mode

        # additional parameters in the encoder
        img_size = self.patch_embed.img_size
        patch_size = self.patch_embed.patch_size[0]
        self.patch_embed = dust3r_patch_embed(
            patch_embed_cls, img_size, patch_size, self.enc_embed_dim
        )
        self.patch_embed_rays = get_patch_embed(
            patch_embed_cls + "_Mlp",
            img_size,
            patch_size,
            self.enc_embed_dim,
            in_chans=3,
        )
        self.patch_embed_depth = get_patch_embed(
            patch_embed_cls + "_Mlp",
            img_size,
            patch_size,
            self.enc_embed_dim,
            in_chans=2,
        )
        self.pose_embed = Mlp(12, 4 * dec_dim, dec_dim)

        # additional parameters in the decoder
        self.dec_cls = "_cls" in self.mode
        self.dec_num_cls = 0
        if self.dec_cls:
            # use a CLS token in the decoder only
            self.mode = self.mode.replace("_cls", "")
            self.cls_token1 = nn.Parameter(torch.zeros((dec_dim,)))
            self.cls_token2 = nn.Parameter(torch.zeros((dec_dim,)))
            self.dec_num_cls = 1  # affects all blocks

        use_ln = "_ln" in self.mode  # TODO remove?
        self.patch_ln = nn.LayerNorm(enc_dim) if use_ln else nn.Identity()
        self.dec1_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity()
        self.dec2_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity()

        self.dec_blocks2 = deepcopy(self.dec_blocks)
        # here we modify some of the blocks
        self.replace_some_blocks()

        self.set_downstream_head(head_type, landscape_only, **croco_kwargs)
        self.set_freeze(freeze)
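
    # The `mode` string controls how geometric priors enter the network. Going by
    # the comment in replace_some_blocks(), a value like "inject[0,0.5]" turns the
    # encoder/decoder blocks at those relative depths into *Inject variants that
    # consume the ray/depth/pose embeddings, while "inject[]" converts all blocks.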
    def replace_some_blocks(self):
        assert self.mode.startswith("inject")  # inject[0,0.5]
        NewBlock = BlockInject
        DecoderNewBlock = DecoderBlockInject

        # relative depths of the encoder/decoder layers
        all_layers = {
            i / n
            for i in range(len(self.enc_blocks))
            for n in [len(self.enc_blocks), len(self.dec_blocks)]
        }
        which_layers = eval(self.mode[self.mode.find("[") :]) or all_layers
        assert isinstance(which_layers, (set, list))

        n = 0
        for i, block in enumerate(self.enc_blocks):
            if i / len(self.enc_blocks) in which_layers:
                block.__class__ = NewBlock
                block.init(self.enc_embed_dim)
                n += 1
            else:
                block.__class__ = Block
        # breakpoint() only runs if the assert fails, dropping into the debugger
        assert n == len(which_layers), breakpoint()

        n = 0
        for i in range(len(self.dec_blocks)):
            for blocks in [self.dec_blocks, self.dec_blocks2]:
                block = blocks[i]
                if i / len(self.dec_blocks) in which_layers:
                    block.__class__ = DecoderNewBlock
                    block.init(self.dec_embed_dim)
                    n += 1
                else:
                    block.__class__ = DecoderBlock
        assert n == 2 * len(which_layers), breakpoint()

    @classmethod
    def from_pretrained(cls, pretrained_model_path, **kw):
        return _load_model(pretrained_model_path, device="cpu")

    def load_state_dict(self, ckpt, **kw):
        # duplicate all weights for the second decoder if not present
        new_ckpt = dict(ckpt)
        if not any(k.startswith("dec_blocks2") for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith("dec_blocks"):
                    new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value

        # remove layers that have different shapes
        cur_ckpt = self.state_dict()
        for key, val in ckpt.items():
            if key.startswith("downstream_head2.proj"):
                if key in cur_ckpt and cur_ckpt[key].shape != val.shape:
                    print(f" (removing ckpt[{key}] because wrong shape)")
                    del new_ckpt[key]

        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):  # this is for use by downstream models
        self.freeze = freeze
        to_be_frozen = {
            "none": [],
            "encoder": [self.patch_embed, self.enc_blocks],
        }
        freeze_all_params(to_be_frozen[freeze])

    def set_prediction_head(self, *args, **kwargs):
        """No prediction head"""
        return

    def set_downstream_head(
        self,
        head_type,
        landscape_only,
        patch_size,
        img_size,
        mlp_ratio,
        dec_depth,
        **kw,
    ):
        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, (
            f"{img_size=} must be multiple of {patch_size=}"
        )
        # split heads if different
        heads = head_type.split(";")
        assert len(heads) in (1, 2)
        head1_type, head2_type = (heads + heads)[:2]

        # allocate heads
        self.downstream_head1 = head_factory(head1_type, self)
        self.downstream_head2 = head_factory(head2_type, self)
        # magic wrapper
        self.head1 = transpose_to_landscape(
            self.downstream_head1, activate=landscape_only
        )
        self.head2 = transpose_to_landscape(
            self.downstream_head2, activate=landscape_only
        )

    def _encode_image(self, image, true_shape, rays=None, depth=None):
        # embed the image into patches (x has size B x Npatches x C)
        x, pos = self.patch_embed(image, true_shape=true_shape)

        if rays is not None:  # B,3,H,W
            rays_emb, pos2 = self.patch_embed_rays(rays, true_shape=true_shape)
            assert (pos == pos2).all()
            if self.mode.startswith("embed"):
                x = x + rays_emb
        else:
            rays_emb = None

        if depth is not None:  # B,2,H,W
            depth_emb, pos2 = self.patch_embed_depth(depth, true_shape=true_shape)
            assert (pos == pos2).all()
            if self.mode.startswith("embed"):
                x = x + depth_emb
        else:
            depth_emb = None

        x = self.patch_ln(x)

        # add positional embedding without cls token
        assert self.enc_pos_embed is None

        # now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            x = blk(x, pos, rays=rays_emb, depth=depth_emb)
        x = self.enc_norm(x)
        return x, pos
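
    # When the batch is symmetrized (pairs stored as (I, J) then (J, I)), every
    # image is present twice, so encoding x[::2] and interleaving the results
    # reproduces the full batch at half the encoder cost.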
    def encode_symmetrized(self, view1, view2):
        img1 = view1["img"]
        img2 = view2["img"]
        B = img1.shape[0]
        # Recover true_shape when available, otherwise assume that the img shape is the true one
        shape1 = view1.get(
            "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1)
        )
        shape2 = view2.get(
            "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1)
        )
        # warning! maybe the images have different portrait/landscape orientations

        # privileged information
        rays1 = view1.get("known_rays", None)
        rays2 = view2.get("known_rays", None)
        depth1 = view1.get("known_depth", None)
        depth2 = view2.get("known_depth", None)

        if is_symmetrized(view1, view2):
            # computing half of forward pass!
            def hsub(x):
                return None if x is None else x[::2]

            feat1, pos1 = self._encode_image(
                img1[::2], shape1[::2], rays=hsub(rays1), depth=hsub(depth1)
            )
            feat2, pos2 = self._encode_image(
                img2[::2], shape2[::2], rays=hsub(rays2), depth=hsub(depth2)
            )
            feat1, feat2 = interleave(feat1, feat2)
            pos1, pos2 = interleave(pos1, pos2)
        else:
            feat1, pos1 = self._encode_image(img1, shape1, rays=rays1, depth=depth1)
            feat2, pos2 = self._encode_image(img2, shape2, rays=rays2, depth=depth2)

        return (shape1, shape2), (feat1, feat2), (pos1, pos2)
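
    # Two parallel decoders (one per view) exchange information via cross-attention.
    # The relative-pose embedding is either added to the tokens / CLS token in
    # "embed" mode, or handed to the *Inject decoder blocks via `relpose=`.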
    def _decoder(self, f1, pos1, f2, pos2, relpose1=None, relpose2=None):
        final_output = [(f1, f2)]  # before projection

        # project to decoder dim
        f1 = self.decoder_embed(f1)
        f2 = self.decoder_embed(f2)

        # add CLS token for the decoder
        if self.dec_cls:
            cls1 = self.cls_token1[None, None].expand(len(f1), 1, -1).clone()
            cls2 = self.cls_token2[None, None].expand(len(f2), 1, -1).clone()

        if relpose1 is not None:  # shape = (B, 4, 4)
            pose_emb1 = self.pose_embed(relpose1[:, :3].flatten(1)).unsqueeze(1)
            if self.mode.startswith("embed"):
                if self.dec_cls:
                    cls1 = cls1 + pose_emb1
                else:
                    f1 = f1 + pose_emb1
        else:
            pose_emb1 = None

        if relpose2 is not None:  # shape = (B, 4, 4)
            pose_emb2 = self.pose_embed(relpose2[:, :3].flatten(1)).unsqueeze(1)
            if self.mode.startswith("embed"):
                if self.dec_cls:
                    cls2 = cls2 + pose_emb2
                else:
                    f2 = f2 + pose_emb2
        else:
            pose_emb2 = None

        if self.dec_cls:
            f1, pos1 = cat_cls(cls1, f1, pos1)
            f2, pos2 = cat_cls(cls2, f2, pos2)

        f1 = self.dec1_pre_ln(f1)
        f2 = self.dec2_pre_ln(f2)

        final_output.append((f1, f2))  # to be removed later
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            # img1 side
            f1, _ = blk1(
                *final_output[-1][::+1],
                pos1,
                pos2,
                relpose=pose_emb1,
                num_cls=self.dec_num_cls,
            )
            # img2 side
            f2, _ = blk2(
                *final_output[-1][::-1],
                pos2,
                pos1,
                relpose=pose_emb2,
                num_cls=self.dec_num_cls,
            )
            # store the result
            final_output.append((f1, f2))

        del final_output[1]  # duplicate with final_output[0] (after decoder proj)

        if self.dec_cls:  # remove cls token for decoder layers
            final_output[1:] = [(f1[:, 1:], f2[:, 1:]) for f1, f2 in final_output[1:]]

        # normalize last output
        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
        return zip(*final_output)

    def _downstream_head(self, head_num, decout, img_shape):
        B, S, D = decout[-1].shape
        head = getattr(self, f"head{head_num}")
        return head(decout, img_shape)

    def forward(self, view1, view2):
        # encode the two images --> B,S,D
        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self.encode_symmetrized(
            view1, view2
        )

        # combine all ref images into object-centric representation
        dec1, dec2 = self._decoder(
            feat1,
            pos1,
            feat2,
            pos2,
            relpose1=view1.get("known_pose"),
            relpose2=view2.get("known_pose"),
        )

        with torch.autocast("cuda", enabled=False):
            res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
            res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)

        res2["pts3d_in_other_view"] = res2.pop(
            "pts3d"
        )  # predict view2's pts3d in view1's frame
        return res1, res2


def convert_release_dust3r_args(args):
    args.model = (
        args.model.replace("patch_embed_cls", "patch_embed")
        .replace("AsymmetricMASt3R", "AsymmetricCroCo3DStereo")
        .replace("PatchEmbedDust3R", "convManyAR")
        .replace(
            "pos_embed='RoPE100'",
            "enc_pos_embed='cuRoPE100', dec_pos_embed='cuRoPE100'",
        )
    )
    return args


def _load_model(model_path, device):
    print("... loading model from", model_path)
    ckpt = torch.load(model_path, map_location="cpu")
    try:
        net = eval(
            ckpt["args"].model[:-1].replace("convManyAR", "convP")
            + ", landscape_only=False)"
        )
    except Exception:
        args = convert_release_dust3r_args(ckpt["args"])
        net = eval(
            args.model[:-1].replace("convManyAR", "convP") + ", landscape_only=False)"
        )
    ckpt["model"] = {
        k.replace("_downstream_head", "downstream_head"): v
        for k, v in ckpt["model"].items()
    }
    print(net.load_state_dict(ckpt["model"], strict=False))
    return net.to(device)


def cat_cls(cls, tokens, pos):
    tokens = torch.cat((cls, tokens), dim=1)
    pos = torch.cat((-pos.new_ones(len(cls), 1, 2), pos), dim=1)
    return tokens, pos


class Pow3RWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        ckpt_path,
        geometric_input_config,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.ckpt_path = ckpt_path
        self.geometric_input_config = geometric_input_config

        # Init the model and load the checkpoint
        print(f"Loading checkpoint from {self.ckpt_path} ...")
        ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False)
        model = ckpt["definition"]
        print(f"Creating model = {model}")
        self.model = eval(model)
        print(self.model.load_state_dict(ckpt["weights"]))

    def forward(self, views):
        """
        Forward pass wrapper for Pow3R.

        Assumption:
        - The number of input views is 2.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                                Length of the list should be 2.
                                Each dictionary should contain the following keys:
                                    "img" (tensor): Image tensor of shape (B, C, H, W).
                                    "data_norm_type" (list): ["dust3r"]
                                Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs:
                                    "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3).
                                    "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation.
                                    "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1).

        Returns:
            List[dict]: A list containing the final outputs for the two views. Length of the list will be 2.
        """

        # Check that the number of input views is 2
        assert len(views) == 2, "Pow3R requires 2 input views."

        # Check the data norm type
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "dust3r", (
            "Pow3R expects a normalized image with the DUSt3R normalization scheme applied"
        )

        # Get the batch size per view, device and two views
        batch_size_per_view = views[0]["img"].shape[0]
        device = views[0]["img"].device
        view1, view2 = views
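
        # Each geometric prior is gated by an independent Bernoulli draw, presumably
        # mirroring the prior dropout used when training Pow3R; set the probabilities
        # in geometric_input_config to 1.0 to always condition on available priors.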

        # Decide if we need to use the geometric inputs
        if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]:
            # Decide if we need to use the camera intrinsics
            if (
                torch.rand(1, device=device)
                < self.geometric_input_config["ray_dirs_prob"]
            ):
                add_intrinsics(view1, view1.get("camera_intrinsics"))
                add_intrinsics(view2, view2.get("camera_intrinsics"))

            # Decide if we need to use the depth map
            if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]:
                depthmap1 = view1.get("depthmap")
                depthmap2 = view2.get("depthmap")
                if depthmap1 is not None:
                    depthmap1 = depthmap1.squeeze(-1).to(device)
                if depthmap2 is not None:
                    depthmap2 = depthmap2.squeeze(-1).to(device)
                add_depth(view1, depthmap1)
                add_depth(view2, depthmap2)

            # Decide if we need to use the camera pose
            if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]:
                cam1 = view1.get("camera_pose")
                cam2 = view2.get("camera_pose")
                add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1)
                add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1)

        # Get the model predictions
        preds = self.model(view1, view2)

        # Convert the output to MapAnything format
        with torch.autocast("cuda", enabled=False):
            res = []
            for view_idx in range(2):
                # Get the model predictions for the current view
                curr_view_pred = preds[view_idx]

                # For the first view
                if view_idx == 0:
                    # Get the global frame and camera frame pointmaps
                    global_pts = curr_view_pred["pts3d"]
                    cam_pts = curr_view_pred["pts3d"]
                    conf = curr_view_pred["conf"]

                    # Get the ray directions and depth along ray
                    depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True)
                    ray_directions = cam_pts / depth_along_ray

                    # Initialize identity camera pose
                    cam_rot = torch.eye(3, device=device)
                    cam_quat = mat_to_quat(cam_rot)
                    cam_trans = torch.zeros(3, device=device)
                    cam_quat = cam_quat.unsqueeze(0).repeat(batch_size_per_view, 1)
                    cam_trans = cam_trans.unsqueeze(0).repeat(batch_size_per_view, 1)
                # For the second view
                elif view_idx == 1:
                    # Get the global frame and camera frame pointmaps
                    pred_global_pts = curr_view_pred["pts3d_in_other_view"]
                    cam_pts = curr_view_pred["pts3d2"]
                    conf = (curr_view_pred["conf"] * curr_view_pred["conf2"]).sqrt()

                    # Get the ray directions and depth along ray
                    depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True)
                    ray_directions = cam_pts / depth_along_ray

                    # Compute the camera pose using the pointmaps
                    cam_rot, cam_trans, scale = roma.rigid_points_registration(
                        cam_pts.reshape(batch_size_per_view, -1, 3),
                        pred_global_pts.reshape(batch_size_per_view, -1, 3),
                        weights=conf.reshape(batch_size_per_view, -1),
                        compute_scaling=True,
                    )
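                    # roma solves a confidence-weighted scaled Procrustes problem
                    # here, so (cam_rot, cam_trans, scale) maps view2's camera-frame
                    # points onto its points expressed in view1's frame, i.e. view2's
                    # cam2world pose up to the recovered scale.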
                    cam_quat = mat_to_quat(cam_rot)

                    # Scale the predicted camera frame pointmap and compute the new global frame pointmap
                    cam_pts = scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * cam_pts
                    global_pts = cam_pts.reshape(
                        batch_size_per_view, -1, 3
                    ) @ cam_rot.permute(0, 2, 1) + cam_trans.unsqueeze(1)
                    global_pts = global_pts.view(pred_global_pts.shape)

                # Append the result in MapAnything format
                res.append(
                    {
                        "pts3d": global_pts,
                        "pts3d_cam": cam_pts,
                        "ray_directions": ray_directions,
                        "depth_along_ray": depth_along_ray,
                        "cam_trans": cam_trans,
                        "cam_quats": cam_quat,
                        "conf": conf,
                    }
                )

        return res
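

# A minimal usage sketch for the two-view wrapper (hypothetical checkpoint path;
# all gating probabilities set to 1.0 so available priors are always used):
#
#   wrapper = Pow3RWrapper(
#       name="pow3r",
#       ckpt_path="checkpoints/pow3r.pth",  # hypothetical path
#       geometric_input_config={
#           "overall_prob": 1.0,
#           "ray_dirs_prob": 1.0,
#           "depth_prob": 1.0,
#           "cam_prob": 1.0,
#       },
#   )
#   res1, res2 = wrapper([view1, view2])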


class Pow3RBAWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        ckpt_path,
        geometric_input_config,
        scene_graph="complete",
        inference_batch_size=32,
        global_optim_schedule="cosine",
        global_optim_lr=0.01,
        global_optim_niter=300,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.ckpt_path = ckpt_path
        self.geometric_input_config = geometric_input_config
        self.scene_graph = scene_graph
        self.inference_batch_size = inference_batch_size
        self.global_optim_schedule = global_optim_schedule
        self.global_optim_lr = global_optim_lr
        self.global_optim_niter = global_optim_niter

        # Init the model and load the checkpoint
        print(f"Loading checkpoint from {self.ckpt_path} ...")
        ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False)
        model = ckpt["definition"]
        print(f"Creating model = {model}")
        self.model = eval(model)
        print(self.model.load_state_dict(ckpt["weights"]))

        # Init the global aligner mode
        self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer

    def infer_two_views(self, views):
        """
        Wrapper for Pow3R 2-View inference.

        Assumption:
        - The number of input views is 2.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                                Length of the list should be 2.
                                Each dictionary should contain the following keys:
                                    "img" (tensor): Image tensor of shape (B, C, H, W).
                                    "data_norm_type" (list): ["dust3r"]
                                Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs:
                                    "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3).
                                    "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation.
                                    "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1).

        Returns:
            List[dict]: A list containing the final outputs for the two views. Length of the list will be 2.
        """
        # Check that the number of input views is 2
        assert len(views) == 2, "Pow3R requires 2 input views."

        # Check the data norm type
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "dust3r", (
            "Pow3R expects a normalized image with the DUSt3R normalization scheme applied"
        )

        # Get the device and two views
        device = views[0]["img"].device
        view1, view2 = views
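
        # Same probability-gated prior conditioning as in Pow3RWrapper.forward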
        # Decide if we need to use the geometric inputs
        if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]:
            # Decide if we need to use the camera intrinsics
            if (
                torch.rand(1, device=device)
                < self.geometric_input_config["ray_dirs_prob"]
            ):
                add_intrinsics(view1, view1.get("camera_intrinsics"))
                add_intrinsics(view2, view2.get("camera_intrinsics"))

            # Decide if we need to use the depth map
            if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]:
                depthmap1 = view1.get("depthmap")
                depthmap2 = view2.get("depthmap")
                if depthmap1 is not None:
                    depthmap1 = depthmap1.squeeze(-1).to(device)
                if depthmap2 is not None:
                    depthmap2 = depthmap2.squeeze(-1).to(device)
                add_depth(view1, depthmap1)
                add_depth(view2, depthmap2)

            # Decide if we need to use the camera pose
            if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]:
                cam1 = view1.get("camera_pose")
                cam2 = view2.get("camera_pose")
                add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1)
                add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1)

        # Get the model predictions
        preds = self.model(view1, view2)

        return preds

    def loss_of_one_batch(self, batch, device):
        """
        Compute prediction for two views.
        """
        view1, view2 = batch
        ignore_keys = set(
            [
                "dataset",
                "label",
                "instance",
                "idx",
                "true_shape",
                "rng",
                "name",
                "data_norm_type",
            ]
        )
        for view in batch:
            for name in view.keys():  # pseudo_focal
                if name in ignore_keys:
                    continue
                view[name] = view[name].to(device, non_blocking=True)

        pred1, pred2 = self.infer_two_views([view1, view2])
        result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2)
        return result

    def inference(self, pairs, device, verbose=False):
        """
        Wrapper for multi-pair inference using Pow3R.
        """
        if verbose:
            print(f">> Inference with model on {len(pairs)} image pairs")
        result = []

        # Fall back to one pair at a time when image shapes differ; use a local
        # batch size so self.inference_batch_size is not permanently overwritten
        multiple_shapes = not check_if_same_size(pairs)
        batch_size = 1 if multiple_shapes else self.inference_batch_size

        for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose):
            res = self.loss_of_one_batch(
                collate_with_cat(pairs[i : i + batch_size]), device
            )
            result.append(to_cpu(res))

        result = collate_with_cat(result, lists=multiple_shapes)
        return result
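
    # N-view reconstruction: run the two-view model over all pairs from the scene
    # graph, then recover per-view intrinsics, poses and depths with DUSt3R's
    # global alignment optimizer.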
    def forward(self, views):
        """
        Forward pass wrapper for Pow3R using the global aligner.

        Assumption:
        - The batch size of input views is 1.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                                Each dictionary should contain the following keys, where B is the batch size and is 1:
                                    "img" (tensor): Image tensor of shape (B, C, H, W).
                                    "data_norm_type" (list): ["dust3r"]
                                    "true_shape" (tensor): True shape of the image. Tensor of shape (B, 2).
                                    "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3).
                                    "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation.
                                    "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1).
                                The geometric keys are read unconditionally below; their use as
                                priors is gated by geometric_input_config in infer_two_views.

        Returns:
            List[dict]: A list containing the final outputs for the input views.
        """
        # Check the batch size of input views
        batch_size_per_view, _, height, width = views[0]["img"].shape
        device = views[0]["img"].device
        num_views = len(views)
        assert batch_size_per_view == 1, (
            f"Batch size of input views should be 1, but got {batch_size_per_view}."
        )

        # Check the data norm type
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "dust3r", (
            "Pow3R-BA expects a normalized image with the DUSt3R normalization scheme applied"
        )

        # Convert the input views to the expected input format
        images = []
        for view in views:
            images.append(
                dict(
                    img=view["img"],
                    camera_intrinsics=view["camera_intrinsics"],
                    depthmap=view["depthmap"],
                    camera_pose=view["camera_pose"],
                    data_norm_type=view["data_norm_type"],
                    true_shape=view["true_shape"],
                    idx=len(images),
                    instance=str(len(images)),
                )
            )

        # Make image pairs and run inference pair-wise
        pairs = make_pairs(
            images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            output = self.inference(
                pairs,
                device,
                verbose=False,
            )

        # Global optimization
        with torch.enable_grad():
            scene = global_aligner(
                output, device=device, mode=self.global_aligner_mode, verbose=False
            )
            _ = scene.compute_global_alignment(
                init="mst",
                niter=self.global_optim_niter,
                schedule=self.global_optim_schedule,
                lr=self.global_optim_lr,
            )
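
        # "mst" initializes camera poses from a spanning tree over the pairwise
        # graph; the optimizer then jointly refines per-view depths, intrinsics
        # and poses by minimizing a pairwise 3D alignment error.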

        # Make sure scene is not None
        if scene is None:
            raise RuntimeError("Global optimization failed.")

        # Get the predictions
        intrinsics = scene.get_intrinsics()
        c2w_poses = scene.get_im_poses()
        depths = scene.get_depthmaps()

        # Convert the output to the MapAnything format
        with torch.autocast("cuda", enabled=False):
            res = []
            for view_idx in range(num_views):
                # Get the current view predictions
                curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0)
                curr_view_pose = c2w_poses[view_idx].unsqueeze(0)
                curr_view_depth_z = depths[view_idx].unsqueeze(0)

                # Convert the pose to quaternions and translation
                curr_view_cam_translations = curr_view_pose[..., :3, 3]
                curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])

                # Get the camera frame pointmaps
                curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
                    curr_view_depth_z, curr_view_intrinsic
                )

                # Convert the z depth to depth along ray
                curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
                    curr_view_depth_z, curr_view_intrinsic
                )
                curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)

                # Get the ray directions on the unit sphere in the camera frame
                _, curr_view_ray_dirs = get_rays_in_camera_frame(
                    curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
                )

                # Get the pointmaps
                curr_view_pts3d = (
                    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
                        curr_view_ray_dirs,
                        curr_view_depth_along_ray,
                        curr_view_cam_translations,
                        curr_view_cam_quats,
                    )
                )

                # Append the outputs to the result list
                res.append(
                    {
                        "pts3d": curr_view_pts3d,
                        "pts3d_cam": curr_view_pts3d_cam,
                        "ray_directions": curr_view_ray_dirs,
                        "depth_along_ray": curr_view_depth_along_ray,
                        "cam_trans": curr_view_cam_translations,
                        "cam_quats": curr_view_cam_quats,
                    }
                )

        return res
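

# A minimal usage sketch for the multi-view wrapper (hypothetical checkpoint path;
# each view dict must carry the keys read in forward(), with batch size 1):
#
#   ba_wrapper = Pow3RBAWrapper(
#       name="pow3r_ba",
#       ckpt_path="checkpoints/pow3r.pth",  # hypothetical path
#       geometric_input_config={
#           "overall_prob": 0.0,  # rely on images only
#           "ray_dirs_prob": 0.0,
#           "depth_prob": 0.0,
#           "cam_prob": 0.0,
#       },
#       scene_graph="complete",
#   )
#   res = ba_wrapper(views)  # one output dict per view, in a shared world frame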