# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
"""
Inference wrapper for VGGT
"""
import torch
from mapanything.models.external.vggt.models.vggt import VGGT
from mapanything.models.external.vggt.utils.geometry import closed_form_inverse_se3
from mapanything.models.external.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from mapanything.models.external.vggt.utils.rotation import mat_to_quat
from mapanything.utils.geometry import (
convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
convert_z_depth_to_depth_along_ray,
depthmap_to_camera_frame,
get_rays_in_camera_frame,
)
class VGGTWrapper(torch.nn.Module):
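    # Thin inference wrapper around VGGT: it runs the aggregator plus the camera and
    # depth heads, then converts the raw predictions into the MapAnything output format.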
def __init__(
self,
name,
torch_hub_force_reload,
load_pretrained_weights=True,
depth=24,
num_heads=16,
intermediate_layer_idx=[4, 11, 17, 23],
load_custom_ckpt=False,
custom_ckpt_path=None,
):
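        # name: identifier for this wrapper instance.
        # torch_hub_force_reload: re-download facebook/VGGT-1B instead of using the hub cache.
        # load_pretrained_weights: if False, build VGGT from scratch with the config below.
        # depth / num_heads / intermediate_layer_idx: architecture config, only used when
        #     load_pretrained_weights is False.
        # load_custom_ckpt / custom_ckpt_path: optionally load a custom state dict after init.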
super().__init__()
self.name = name
self.torch_hub_force_reload = torch_hub_force_reload
self.load_custom_ckpt = load_custom_ckpt
self.custom_ckpt_path = custom_ckpt_path
if load_pretrained_weights:
# Load pre-trained weights
if not torch_hub_force_reload:
# Initialize the 1B VGGT model from huggingface hub cache
print("Loading facebook/VGGT-1B from huggingface cache ...")
self.model = VGGT.from_pretrained(
"facebook/VGGT-1B",
)
else:
# Initialize the 1B VGGT model
print("Re-downloading facebook/VGGT-1B ...")
self.model = VGGT.from_pretrained(
"facebook/VGGT-1B", force_download=True
)
else:
            # Initialize VGGT from scratch with the given architecture config (no pretrained weights)
self.model = VGGT(
depth=depth,
num_heads=num_heads,
intermediate_layer_idx=intermediate_layer_idx,
)
# Get the dtype for VGGT inference
# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
self.dtype = (
torch.bfloat16
if torch.cuda.get_device_capability()[0] >= 8
else torch.float16
)
# Load custom checkpoint if requested
        if self.load_custom_ckpt:
            assert self.custom_ckpt_path is not None, (
                "custom_ckpt_path must be provided if load_custom_ckpt is set to True"
            )
            print(f"Loading checkpoint from {self.custom_ckpt_path} ...")
            custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False)
            print(self.model.load_state_dict(custom_ckpt, strict=True))
            del custom_ckpt  # free the checkpoint memory
def forward(self, views):
"""
Forward pass wrapper for VGGT
Assumption:
- All the input views have the same image shape.
Args:
views (List[dict]): List of dictionaries containing the input views' images and instance information.
Each dictionary should contain the following keys:
"img" (tensor): Image tensor of shape (B, C, H, W).
"data_norm_type" (list): ["identity"]
Returns:
List[dict]: A list containing the final outputs for all N views.
"""
# Get input shape of the images, number of views, and batch size per view
batch_size_per_view, _, height, width = views[0]["img"].shape
num_views = len(views)
# Check the data norm type
        # VGGT expects normalized images without the DINOv2 mean and std applied ("identity")
data_norm_type = views[0]["data_norm_type"][0]
assert data_norm_type == "identity", (
"VGGT expects a normalized image but without the DINOv2 mean and std applied"
)
# Concatenate the images to create a single (B, V, C, H, W) tensor
img_list = [view["img"] for view in views]
images = torch.stack(img_list, dim=1)
# Run the VGGT aggregator
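        # The aggregator returns the list of per-layer aggregated tokens and ps_idx,
        # the index where the patch tokens start in the token sequence (consumed by the heads)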
with torch.autocast("cuda", dtype=self.dtype):
aggregated_tokens_list, ps_idx = self.model.aggregator(images)
        # Run the camera and depth heads of VGGT in full precision (autocast disabled)
with torch.autocast("cuda", enabled=False):
# Predict Cameras
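            # camera_head returns a list of iteratively refined pose encodings; keep the last one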
pose_enc = self.model.camera_head(aggregated_tokens_list)[-1]
# Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
# Extrinsics Shape: (B, V, 3, 4)
# Intrinsics Shape: (B, V, 3, 3)
extrinsic, intrinsic = pose_encoding_to_extri_intri(
pose_enc, images.shape[-2:]
)
# Predict Depth Maps
# Depth Shape: (B, V, H, W, 1)
# Depth Confidence Shape: (B, V, H, W)
depth_map, depth_conf = self.model.depth_head(
aggregated_tokens_list, images, ps_idx
)
# Convert the output to MapAnything format
res = []
for view_idx in range(num_views):
# Get the extrinsics, intrinsics, depth map for the current view
curr_view_extrinsic = extrinsic[:, view_idx, ...]
curr_view_extrinsic = closed_form_inverse_se3(
curr_view_extrinsic
) # Convert to cam2world
curr_view_intrinsic = intrinsic[:, view_idx, ...]
curr_view_depth_z = depth_map[:, view_idx, ...]
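            # Drop the trailing channel dim of the depth map: (B, H, W, 1) -> (B, H, W)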
curr_view_depth_z = curr_view_depth_z.squeeze(-1)
curr_view_confidence = depth_conf[:, view_idx, ...]
# Get the camera frame pointmaps
curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
curr_view_depth_z, curr_view_intrinsic
)
# Convert the extrinsics to quaternions and translations
curr_view_cam_translations = curr_view_extrinsic[..., :3, 3]
curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3])
# Convert the z depth to depth along ray
curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
curr_view_depth_z, curr_view_intrinsic
)
curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)
# Get the ray directions on the unit sphere in the camera frame
_, curr_view_ray_dirs = get_rays_in_camera_frame(
curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
)
# Get the pointmaps
curr_view_pts3d = (
convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
curr_view_ray_dirs,
curr_view_depth_along_ray,
curr_view_cam_translations,
curr_view_cam_quats,
)
)
# Append the outputs to the result list
res.append(
{
"pts3d": curr_view_pts3d,
"pts3d_cam": curr_view_pts3d_cam,
"ray_directions": curr_view_ray_dirs,
"depth_along_ray": curr_view_depth_along_ray,
"cam_trans": curr_view_cam_translations,
"cam_quats": curr_view_cam_quats,
"conf": curr_view_confidence,
}
)
return res
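

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the wrapper API).
# Assumptions: a CUDA device is available, the facebook/VGGT-1B weights can be
# downloaded, and the two random 518x518 tensors stand in for real frames that
# are already resized and normalized to the "identity" convention expected above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda"
    model = VGGTWrapper(name="vggt", torch_hub_force_reload=False).to(device).eval()
    # Two example views with identical image shapes, as required by forward()
    views = [
        {
            "img": torch.rand(1, 3, 518, 518, device=device),
            "data_norm_type": ["identity"],
        }
        for _ in range(2)
    ]
    with torch.no_grad():
        preds = model(views)
    # Each entry holds the world-frame pointmap, camera-frame pointmap, ray directions,
    # depth along ray, camera pose (translation + quaternion), and confidence
    print(preds[0]["pts3d"].shape, preds[0]["cam_quats"].shape)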