# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
"""
Inference wrapper for VGGT
"""
import torch

from mapanything.models.external.vggt.models.vggt import VGGT
from mapanything.models.external.vggt.utils.geometry import closed_form_inverse_se3
from mapanything.models.external.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from mapanything.models.external.vggt.utils.rotation import mat_to_quat
from mapanything.utils.geometry import (
    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
    convert_z_depth_to_depth_along_ray,
    depthmap_to_camera_frame,
    get_rays_in_camera_frame,
)

class VGGTWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        torch_hub_force_reload,
        load_pretrained_weights=True,
        depth=24,
        num_heads=16,
        intermediate_layer_idx=[4, 11, 17, 23],
        load_custom_ckpt=False,
        custom_ckpt_path=None,
    ):
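        """
        Initialize the VGGT inference wrapper.

        Args:
            name (str): Name of the model configuration.
            torch_hub_force_reload (bool): If True, re-download the facebook/VGGT-1B weights
                instead of using the local Hugging Face cache.
            load_pretrained_weights (bool): If True, initialize from the facebook/VGGT-1B checkpoint;
                otherwise build the VGGT architecture from the arguments below.
            depth (int): Aggregator depth used when building the model from scratch.
            num_heads (int): Number of attention heads used when building the model from scratch.
            intermediate_layer_idx (list): Aggregator layer indices used by the prediction heads
                when building the model from scratch.
            load_custom_ckpt (bool): If True, load weights from custom_ckpt_path after initialization.
            custom_ckpt_path (str): Path to the custom checkpoint (required if load_custom_ckpt is True).
        """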
        super().__init__()
        self.name = name
        self.torch_hub_force_reload = torch_hub_force_reload
        self.load_custom_ckpt = load_custom_ckpt
        self.custom_ckpt_path = custom_ckpt_path

        if load_pretrained_weights:
            # Load pre-trained weights
            if not torch_hub_force_reload:
                # Initialize the 1B VGGT model from the Hugging Face Hub cache
                print("Loading facebook/VGGT-1B from huggingface cache ...")
                self.model = VGGT.from_pretrained(
                    "facebook/VGGT-1B",
                )
            else:
                # Re-download and initialize the 1B VGGT model
                print("Re-downloading facebook/VGGT-1B ...")
                self.model = VGGT.from_pretrained(
                    "facebook/VGGT-1B", force_download=True
                )
        else:
            # Initialize the VGGT architecture without pre-trained weights
            self.model = VGGT(
                depth=depth,
                num_heads=num_heads,
                intermediate_layer_idx=intermediate_layer_idx,
            )

        # Get the dtype for VGGT inference
        # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
        self.dtype = (
            torch.bfloat16
            if torch.cuda.get_device_capability()[0] >= 8
            else torch.float16
        )

        # Load custom checkpoint if requested
        if self.load_custom_ckpt:
            assert self.custom_ckpt_path is not None, (
                "custom_ckpt_path must be provided if load_custom_ckpt is set to True"
            )
            print(f"Loading checkpoint from {self.custom_ckpt_path} ...")
            custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False)
            print(self.model.load_state_dict(custom_ckpt, strict=True))
            del custom_ckpt  # free the checkpoint state dict from memory

    def forward(self, views):
        """
        Forward pass wrapper for VGGT.

        Assumption:
            - All the input views have the same image shape.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Each dictionary should contain the following keys:
                    "img" (tensor): Image tensor of shape (B, C, H, W).
                    "data_norm_type" (list): ["identity"]

        Returns:
            List[dict]: A list containing the final outputs for all N views.
        """
        # Get input shape of the images, number of views, and batch size per view
        batch_size_per_view, _, height, width = views[0]["img"].shape
        num_views = len(views)

        # Check the data norm type
        # VGGT expects a normalized image but without the DINOv2 mean and std applied ("identity")
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "identity", (
            "VGGT expects a normalized image but without the DINOv2 mean and std applied"
        )

        # Stack the images to create a single (B, V, C, H, W) tensor
        img_list = [view["img"] for view in views]
        images = torch.stack(img_list, dim=1)

        # Run the VGGT aggregator
        with torch.autocast("cuda", dtype=self.dtype):
            aggregated_tokens_list, ps_idx = self.model.aggregator(images)

        # Run the camera and depth branches of VGGT in full precision
        with torch.autocast("cuda", enabled=False):
            # Predict Cameras
            pose_enc = self.model.camera_head(aggregated_tokens_list)[-1]
            # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
            # Extrinsics Shape: (B, V, 3, 4)
            # Intrinsics Shape: (B, V, 3, 3)
            extrinsic, intrinsic = pose_encoding_to_extri_intri(
                pose_enc, images.shape[-2:]
            )

            # Predict Depth Maps
            # Depth Shape: (B, V, H, W, 1)
            # Depth Confidence Shape: (B, V, H, W)
            depth_map, depth_conf = self.model.depth_head(
                aggregated_tokens_list, images, ps_idx
            )

        # Convert the output to MapAnything format
        res = []
        for view_idx in range(num_views):
            # Get the extrinsics, intrinsics, and depth map for the current view
            curr_view_extrinsic = extrinsic[:, view_idx, ...]
            curr_view_extrinsic = closed_form_inverse_se3(
                curr_view_extrinsic
            )  # Convert to cam2world
            curr_view_intrinsic = intrinsic[:, view_idx, ...]
            curr_view_depth_z = depth_map[:, view_idx, ...]
            curr_view_depth_z = curr_view_depth_z.squeeze(-1)
            curr_view_confidence = depth_conf[:, view_idx, ...]

            # Get the camera frame pointmaps
            curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
                curr_view_depth_z, curr_view_intrinsic
            )

            # Convert the extrinsics to quaternions and translations
            curr_view_cam_translations = curr_view_extrinsic[..., :3, 3]
            curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3])

            # Convert the z depth to depth along ray
            curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
                curr_view_depth_z, curr_view_intrinsic
            )
            curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)

            # Get the ray directions on the unit sphere in the camera frame
            _, curr_view_ray_dirs = get_rays_in_camera_frame(
                curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
            )

            # Get the world frame pointmaps from the ray directions, depth along ray, and camera poses
            curr_view_pts3d = (
                convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
                    curr_view_ray_dirs,
                    curr_view_depth_along_ray,
                    curr_view_cam_translations,
                    curr_view_cam_quats,
                )
            )

            # Append the outputs to the result list
            res.append(
                {
                    "pts3d": curr_view_pts3d,
                    "pts3d_cam": curr_view_pts3d_cam,
                    "ray_directions": curr_view_ray_dirs,
                    "depth_along_ray": curr_view_depth_along_ray,
                    "cam_trans": curr_view_cam_translations,
                    "cam_quats": curr_view_cam_quats,
                    "conf": curr_view_confidence,
                }
            )

        return res
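

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes a CUDA device,
# the pre-trained facebook/VGGT-1B weights, and images already resized and
# normalized to the "identity" convention (no DINOv2 mean/std applied).
# The image size and view count below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = VGGTWrapper(
        name="vggt",
        torch_hub_force_reload=False,
    )
    model = model.cuda().eval()

    # Two example views, each holding a batch of one 518 x 518 image
    views = [
        {
            "img": torch.rand(1, 3, 518, 518, device="cuda"),
            "data_norm_type": ["identity"],
        }
        for _ in range(2)
    ]

    with torch.no_grad():
        outputs = model(views)

    # One output dict per view, e.g. the world-frame pointmap of view 0
    print(outputs[0]["pts3d"].shape)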