Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # This source code is licensed under the Apache License, Version 2.0 | |
| # found in the LICENSE file in the root directory of this source tree. | |
| """ | |
| Inference wrapper for Pi3 | |
| """ | |
| import torch | |
| from mapanything.models.external.pi3.models.pi3 import Pi3 | |
| from mapanything.models.external.vggt.utils.rotation import mat_to_quat | |
class Pi3Wrapper(torch.nn.Module):
    """
    Inference wrapper that runs Pi3 on a list of views and repackages the
    model outputs as one dictionary per view (MapAnything output format).
    """

    def __init__(
        self,
        name,
        torch_hub_force_reload,
        load_pretrained_weights=True,
        pos_type="rope100",
        decoder_size="large",
    ):
        """
        Build the wrapped Pi3 model and pick the autocast dtype used in forward.

        Args:
            name: Identifier stored on the wrapper as ``self.name``.
            torch_hub_force_reload: When truthy (and pre-trained weights are
                requested), the checkpoint is fetched with ``force_download=True``
                instead of being taken from the huggingface cache.
            load_pretrained_weights: If True, load the "yyfz233/Pi3" checkpoint via
                ``Pi3.from_pretrained``; otherwise construct ``Pi3`` directly.
            pos_type: Forwarded to the ``Pi3`` constructor. Only used when
                ``load_pretrained_weights`` is False.
            decoder_size: Forwarded to the ``Pi3`` constructor. Only used when
                ``load_pretrained_weights`` is False.
        """
        super().__init__()
        self.name = name
        self.torch_hub_force_reload = torch_hub_force_reload

        if load_pretrained_weights:
            # Load pre-trained weights
            if not torch_hub_force_reload:
                # Initialize the Pi3 model from huggingface hub cache
                print("Loading Pi3 from huggingface cache ...")
                self.model = Pi3.from_pretrained(
                    "yyfz233/Pi3",
                )
            else:
                # Initialize the Pi3 model, forcing a fresh download
                self.model = Pi3.from_pretrained("yyfz233/Pi3", force_download=True)
        else:
            # Build the Pi3 architecture without loading a checkpoint
            self.model = Pi3(
                pos_type=pos_type,
                decoder_size=decoder_size,
            )

        # Get the dtype for Pi3 inference
        self.dtype = self._select_autocast_dtype()

    @staticmethod
    def _select_autocast_dtype():
        """
        Return the reduced-precision dtype to use for CUDA autocast.

        bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+); any other
        GPU gets float16. ``torch.cuda.get_device_capability()`` raises when no
        CUDA device is usable, so availability is checked first and float16 is
        returned as the fallback, which keeps the wrapper constructible on
        CPU-only hosts.
        """
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
            return torch.bfloat16
        return torch.float16

    def forward(self, views):
        """
        Forward pass wrapper for Pi3

        Assumption:
        - All the input views have the same image shape.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                                Each dictionary should contain the following keys:
                                    "img" (tensor): Image tensor of shape (B, C, H, W).
                                    "data_norm_type" (list): ["identity"]

        Returns:
            List[dict]: A list containing the final outputs for all N views. Each entry has the keys
                        "pts3d", "pts3d_cam", "ray_directions", "depth_along_ray",
                        "cam_trans", "cam_quats" and "conf".
        """
        # Unpack the first view's image shape. The values themselves are not used
        # below, but the unpacking fails early if the image tensor is not 4-D.
        _batch_size_per_view, _, _height, _width = views[0]["img"].shape
        num_views = len(views)

        # Check the data norm type
        # Pi3 expects a normalized image but without the DINOv2 mean and std applied ("identity")
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "identity", (
            "Pi3 expects a normalized image but without the DINOv2 mean and std applied"
        )

        # Stack the per-view images to create a single (B, V, C, H, W) tensor
        img_list = [view["img"] for view in views]
        images = torch.stack(img_list, dim=1)

        # Run the Pi3 model under reduced-precision autocast
        with torch.autocast("cuda", dtype=self.dtype):
            results = self.model(images)

        # Need high precision for transformations
        with torch.autocast("cuda", enabled=False):
            # Convert the output to MapAnything format
            res = []
            for view_idx in range(num_views):
                # Get the extrinsics: translation is the last column of the top
                # 3 rows, rotation is the top-left 3x3 block converted to a quaternion
                curr_view_extrinsic = results["camera_poses"][:, view_idx, ...]
                curr_view_cam_translations = curr_view_extrinsic[..., :3, 3]
                curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3])

                # Get the depth along ray, ray directions, local point cloud & global point cloud
                curr_view_pts3d_cam = results["local_points"][:, view_idx, ...]
                curr_view_depth_along_ray = torch.norm(
                    curr_view_pts3d_cam, dim=-1, keepdim=True
                )
                # NOTE(review): a local point exactly at the origin has zero norm and
                # yields NaN ray directions here - confirm whether callers mask these.
                curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray
                curr_view_pts3d = results["points"][:, view_idx, ...]

                # Get the confidence
                curr_view_confidence = results["conf"][:, view_idx, ...]

                # Append the outputs to the result list
                res.append(
                    {
                        "pts3d": curr_view_pts3d,
                        "pts3d_cam": curr_view_pts3d_cam,
                        "ray_directions": curr_view_ray_dirs,
                        "depth_along_ray": curr_view_depth_along_ray,
                        "cam_trans": curr_view_cam_translations,
                        "cam_quats": curr_view_cam_quats,
                        "conf": curr_view_confidence,
                    }
                )

        return res