# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
"""
Inference wrapper for Pi3
"""
import torch

from mapanything.models.external.pi3.models.pi3 import Pi3
from mapanything.models.external.vggt.utils.rotation import mat_to_quat

class Pi3Wrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        torch_hub_force_reload,
        load_pretrained_weights=True,
        pos_type="rope100",
        decoder_size="large",
    ):
        super().__init__()
        self.name = name
        self.torch_hub_force_reload = torch_hub_force_reload

        if load_pretrained_weights:
            # Load pre-trained weights
            if not torch_hub_force_reload:
                # Initialize the Pi3 model from the Hugging Face Hub cache
                print("Loading Pi3 from huggingface cache ...")
                self.model = Pi3.from_pretrained("yyfz233/Pi3")
            else:
                # Force a fresh download of the checkpoint instead of using the cache
                self.model = Pi3.from_pretrained("yyfz233/Pi3", force_download=True)
        else:
            # Initialize the Pi3 architecture without pretrained weights
            self.model = Pi3(
                pos_type=pos_type,
                decoder_size=decoder_size,
            )

        # Get the autocast dtype for Pi3 inference
        # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+); fall back to float16 otherwise
        self.dtype = (
            torch.bfloat16
            if torch.cuda.get_device_capability()[0] >= 8
            else torch.float16
        )
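        # NOTE: torch.cuda.get_device_capability() queries the current CUDA
        # device and raises if CUDA is unavailable, so this wrapper assumes it
        # is constructed on a machine with a CUDA-capable GPU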

    def forward(self, views):
        """
        Forward pass wrapper for Pi3.

        Assumption:
        - All the input views have the same image shape.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Each dictionary should contain the following keys:
                    "img" (tensor): Image tensor of shape (B, C, H, W).
                    "data_norm_type" (list): ["identity"]

        Returns:
            List[dict]: A list containing the final outputs for all N views.
        """
        # Get the input image shape, number of views, and batch size per view
        batch_size_per_view, _, height, width = views[0]["img"].shape
        num_views = len(views)

        # Check the data norm type
        # Pi3 expects a normalized image but without the DINOv2 mean and std applied ("identity")
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "identity", (
            "Pi3 expects a normalized image but without the DINOv2 mean and std applied"
        )

        # Stack the per-view images into a single (B, V, C, H, W) tensor
        img_list = [view["img"] for view in views]
        images = torch.stack(img_list, dim=1)
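        # NOTE: torch.stack requires every view's "img" tensor to have the same
        # (B, C, H, W) shape, which is the same-image-shape assumption from the docstring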

        # Run the Pi3 model under mixed precision
        with torch.autocast("cuda", dtype=self.dtype):
            results = self.model(images)

        # The geometric conversions below need full precision, so disable autocast
        with torch.autocast("cuda", enabled=False):
            # Convert the output to MapAnything format
            res = []
            for view_idx in range(num_views):
                # Split the camera pose into translation and quaternion rotation
                curr_view_extrinsic = results["camera_poses"][:, view_idx, ...]
                curr_view_cam_translations = curr_view_extrinsic[..., :3, 3]
                curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3])

                # Get the depth along ray, ray directions, local point cloud & global point cloud
                curr_view_pts3d_cam = results["local_points"][:, view_idx, ...]
                curr_view_depth_along_ray = torch.norm(
                    curr_view_pts3d_cam, dim=-1, keepdim=True
                )
                curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray
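                # ray_dirs rescales pts3d_cam to unit norm, so the local points
                # factor exactly as depth_along_ray * ray_dirs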
                curr_view_pts3d = results["points"][:, view_idx, ...]

                # Get the confidence
                curr_view_confidence = results["conf"][:, view_idx, ...]

                # Append the outputs to the result list
                res.append(
                    {
                        "pts3d": curr_view_pts3d,
                        "pts3d_cam": curr_view_pts3d_cam,
                        "ray_directions": curr_view_ray_dirs,
                        "depth_along_ray": curr_view_depth_along_ray,
                        "cam_trans": curr_view_cam_translations,
                        "cam_quats": curr_view_cam_quats,
                        "conf": curr_view_confidence,
                    }
                )

        return res
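

# Minimal usage sketch, assuming a CUDA GPU, access to the pretrained
# "yyfz233/Pi3" checkpoint, and images already normalized per the "identity"
# convention expected by forward(); the shapes below are illustrative
if __name__ == "__main__":
    views = [
        {
            "img": torch.rand(1, 3, 224, 224, device="cuda"),  # (B, C, H, W)
            "data_norm_type": ["identity"],
        }
        for _ in range(2)  # two input views of the same scene
    ]
    wrapper = Pi3Wrapper(name="pi3", torch_hub_force_reload=False).cuda().eval()
    with torch.no_grad():
        outputs = wrapper(views)
    # One output dict per view; the global point map should be (B, H, W, 3)
    print(outputs[0]["pts3d"].shape)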