# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
"""
Inference wrapper for VGGT
"""

import torch

from mapanything.models.external.vggt.models.vggt import VGGT
from mapanything.models.external.vggt.utils.geometry import closed_form_inverse_se3
from mapanything.models.external.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from mapanything.models.external.vggt.utils.rotation import mat_to_quat
from mapanything.utils.geometry import (
    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
    convert_z_depth_to_depth_along_ray,
    depthmap_to_camera_frame,
    get_rays_in_camera_frame,
)

class VGGTWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        torch_hub_force_reload,
        load_pretrained_weights=True,
        depth=24,
        num_heads=16,
        intermediate_layer_idx=(4, 11, 17, 23),  # tuple avoids a mutable default argument
        load_custom_ckpt=False,
        custom_ckpt_path=None,
    ):
        super().__init__()
        self.name = name
        self.torch_hub_force_reload = torch_hub_force_reload
        self.load_custom_ckpt = load_custom_ckpt
        self.custom_ckpt_path = custom_ckpt_path
        if load_pretrained_weights:
            # Load pre-trained weights
            if not torch_hub_force_reload:
                # Initialize the 1B VGGT model from the HuggingFace Hub cache
                print("Loading facebook/VGGT-1B from huggingface cache ...")
                self.model = VGGT.from_pretrained("facebook/VGGT-1B")
            else:
                # Re-download and initialize the 1B VGGT model
                print("Re-downloading facebook/VGGT-1B ...")
                self.model = VGGT.from_pretrained(
                    "facebook/VGGT-1B", force_download=True
                )
        else:
            # Initialize the VGGT architecture from scratch (no pre-trained weights)
            self.model = VGGT(
                depth=depth,
                num_heads=num_heads,
                intermediate_layer_idx=intermediate_layer_idx,
            )
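        # Note: the architecture defaults above (depth=24, num_heads=16, DPT taps
        # at layers (4, 11, 17, 23)) are assumed to mirror the released
        # facebook/VGGT-1B configuration, so the from-scratch path builds the
        # same model as the pre-trained path.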
        # Get the dtype for VGGT inference:
        # bfloat16 on Ampere GPUs and newer (compute capability 8.0+), else float16.
        # The availability check keeps initialization from crashing on CPU-only machines.
        self.dtype = (
            torch.bfloat16
            if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
            else torch.float16
        )
        # Load a custom checkpoint if requested
        if self.load_custom_ckpt:
            assert self.custom_ckpt_path is not None, (
                "custom_ckpt_path must be provided if load_custom_ckpt is set to True"
            )
            print(f"Loading checkpoint from {self.custom_ckpt_path} ...")
            custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False)
            print(self.model.load_state_dict(custom_ckpt, strict=True))
            del custom_ckpt  # free the checkpoint copy once the weights are loaded
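
    # Example (hypothetical path): loading a fine-tuned checkpoint instead of
    # the HuggingFace weights:
    #
    #   wrapper = VGGTWrapper(
    #       name="vggt",
    #       torch_hub_force_reload=False,
    #       load_pretrained_weights=False,
    #       load_custom_ckpt=True,
    #       custom_ckpt_path="/path/to/finetuned_vggt.pth",
    #   )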
    def forward(self, views):
        """
        Forward pass wrapper for VGGT.

        Assumption:
        - All input views share the same image shape.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Each dictionary must contain the following keys:
                    "img" (tensor): Image tensor of shape (B, C, H, W).
                    "data_norm_type" (list): ["identity"]

        Returns:
            List[dict]: A list of N dictionaries, one per view, each containing:
                "pts3d", "pts3d_cam", "ray_directions", "depth_along_ray",
                "cam_trans", "cam_quats", and "conf".
        """
        # Get the input image shape, number of views, and batch size per view
        batch_size_per_view, _, height, width = views[0]["img"].shape
        num_views = len(views)

        # Check the data norm type.
        # VGGT expects a normalized image but without the DINOv2 mean and std applied ("identity")
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "identity", (
            "VGGT expects a normalized image but without the DINOv2 mean and std applied"
        )
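        # ("identity" is assumed to mean images scaled to [0, 1] with no
        # ImageNet/DINOv2 mean-std normalization applied, which is what VGGT's
        # own preprocessing produces.)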
        # Stack the images into a single (B, V, C, H, W) tensor
        img_list = [view["img"] for view in views]
        images = torch.stack(img_list, dim=1)

        # Run the VGGT aggregator in mixed precision
        with torch.autocast("cuda", dtype=self.dtype):
            aggregated_tokens_list, ps_idx = self.model.aggregator(images)
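        # The prediction heads below run with autocast disabled, i.e., in full
        # precision; (assumption) this mirrors the official VGGT inference
        # recipe, which keeps the heads in fp32 for numerical stability.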
        # Run the camera and depth branches of VGGT
        with torch.autocast("cuda", enabled=False):
            # Predict cameras
            pose_enc = self.model.camera_head(aggregated_tokens_list)[-1]
            # Extrinsic and intrinsic matrices, following the OpenCV convention (camera from world)
            # Extrinsics shape: (B, V, 3, 4)
            # Intrinsics shape: (B, V, 3, 3)
            extrinsic, intrinsic = pose_encoding_to_extri_intri(
                pose_enc, images.shape[-2:]
            )

            # Predict depth maps
            # Depth shape: (B, V, H, W, 1)
            # Depth confidence shape: (B, V, H, W)
            depth_map, depth_conf = self.model.depth_head(
                aggregated_tokens_list, images, ps_idx
            )
        # Convert the outputs to the MapAnything format
        res = []
        for view_idx in range(num_views):
            # Get the extrinsics, intrinsics, and depth map for the current view
            curr_view_extrinsic = extrinsic[:, view_idx, ...]
            curr_view_extrinsic = closed_form_inverse_se3(
                curr_view_extrinsic
            )  # Convert to cam2world
            curr_view_intrinsic = intrinsic[:, view_idx, ...]
            curr_view_depth_z = depth_map[:, view_idx, ...]
            curr_view_depth_z = curr_view_depth_z.squeeze(-1)
            curr_view_confidence = depth_conf[:, view_idx, ...]

            # Get the camera-frame pointmaps
            curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
                curr_view_depth_z, curr_view_intrinsic
            )

            # Convert the extrinsics to quaternions and translations
            curr_view_cam_translations = curr_view_extrinsic[..., :3, 3]
            curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3])

            # Convert the z-depth to depth along ray
            curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
                curr_view_depth_z, curr_view_intrinsic
            )
            curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)

            # Get the ray directions on the unit sphere in the camera frame
            _, curr_view_ray_dirs = get_rays_in_camera_frame(
                curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
            )

            # Get the world-frame pointmaps
            curr_view_pts3d = (
                convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
                    curr_view_ray_dirs,
                    curr_view_depth_along_ray,
                    curr_view_cam_translations,
                    curr_view_cam_quats,
                )
            )
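            # (The helper above is assumed to compute, per pixel,
            #  pts3d = R(cam_quats) @ (ray_dirs * depth_along_ray) + cam_trans,
            #  i.e., lift each unit ray to 3D and transform it to the world frame.)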
            # Append the outputs to the result list
            res.append(
                {
                    "pts3d": curr_view_pts3d,
                    "pts3d_cam": curr_view_pts3d_cam,
                    "ray_directions": curr_view_ray_dirs,
                    "depth_along_ray": curr_view_depth_along_ray,
                    "cam_trans": curr_view_cam_translations,
                    "cam_quats": curr_view_cam_quats,
                    "conf": curr_view_confidence,
                }
            )

        return res
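
# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical; the dummy inputs below are illustrative
# assumptions, not part of this module). 518 x 518 matches the resolution used
# by the official VGGT demo; any patch-size-compatible resolution should
# behave the same way.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VGGTWrapper(name="vggt", torch_hub_force_reload=False).to(device).eval()

    # Two dummy views with images normalized to [0, 1] ("identity")
    views = [
        {
            "img": torch.rand(1, 3, 518, 518, device=device),
            "data_norm_type": ["identity"],
        }
        for _ in range(2)
    ]

    with torch.no_grad():
        preds = model(views)
    print(preds[0]["pts3d"].shape)  # expected: (1, 518, 518, 3)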