Spaces:
Sleeping
Sleeping
File size: 4,912 Bytes
b74998d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
"""
Inference wrapper for MoGe
"""
import torch
from mapanything.models.external.moge.models.v1 import MoGeModel as MoGeModelV1
from mapanything.models.external.moge.models.v2 import MoGeModel as MoGeModelV2
class MoGeWrapper(torch.nn.Module):
    """
    Inference wrapper that runs a pretrained MoGe model (v1 or v2) on a single view
    and repackages its predictions into the MapAnything output dictionary format.
    """

    def __init__(
        self,
        name,
        model_string="Ruicheng/moge-2-vitl",
        torch_hub_force_reload=False,
        load_custom_ckpt=False,
        custom_ckpt_path=None,
    ):
        """
        Build the wrapper and load the requested MoGe weights.

        Args:
            name (str): Identifier stored on the wrapper as ``self.name``.
            model_string (str): Pretrained checkpoint identifier. Must be one of the
                strings listed in ``self.moge_model_map``.
            torch_hub_force_reload (bool): Stored on the instance; it is not otherwise
                used inside this class.
            load_custom_ckpt (bool): If True, the state dict stored at
                ``custom_ckpt_path`` is loaded (strictly) on top of the pretrained weights.
            custom_ckpt_path (str | None): Path to the custom checkpoint. Required when
                ``load_custom_ckpt`` is True.

        Raises:
            ValueError: If ``model_string`` is not a known MoGe checkpoint string.
            AssertionError: If ``load_custom_ckpt`` is True but ``custom_ckpt_path`` is None.
        """
        super().__init__()
        self.name = name
        self.model_string = model_string
        self.torch_hub_force_reload = torch_hub_force_reload
        self.load_custom_ckpt = load_custom_ckpt
        self.custom_ckpt_path = custom_ckpt_path

        # Mapping of MoGe model version to checkpoint strings
        self.moge_model_map = {
            "v1": ["Ruicheng/moge-vitl"],
            "v2": [
                "Ruicheng/moge-2-vits-normal",
                "Ruicheng/moge-2-vitb-normal",
                "Ruicheng/moge-2-vitl-normal",
                "Ruicheng/moge-2-vitl",
            ],
        }

        # Validate all arguments up front so that a misconfiguration fails immediately
        # instead of after the (expensive) pretrained model has been fetched and built.
        if self.model_string in self.moge_model_map["v1"]:
            model_version = "v1"
        elif self.model_string in self.moge_model_map["v2"]:
            model_version = "v2"
        else:
            raise ValueError(
                f"Invalid model string: {self.model_string}. Valid strings are: {self.moge_model_map}"
            )
        if self.load_custom_ckpt:
            # Kept as an assert so the exception type seen by callers (AssertionError)
            # is unchanged.
            assert self.custom_ckpt_path is not None, (
                "custom_ckpt_path must be provided if load_custom_ckpt is set to True"
            )

        # Initialize the model
        model_cls = MoGeModelV1 if model_version == "v1" else MoGeModelV2
        self.model = model_cls.from_pretrained(self.model_string)

        # Load custom checkpoint if requested
        if self.load_custom_ckpt:
            print(f"Loading checkpoint from {self.custom_ckpt_path} ...")
            # map_location="cpu": the checkpoint can be restored even when the device it
            # was saved from is not available on this host; load_state_dict then copies
            # the values into the model's existing parameters.
            # NOTE(security): weights_only=False allows arbitrary pickled objects to be
            # unpickled. Only load checkpoints that come from a trusted source.
            custom_ckpt = torch.load(
                self.custom_ckpt_path, map_location="cpu", weights_only=False
            )
            print(self.model.load_state_dict(custom_ckpt, strict=True))
            del custom_ckpt  # in case it occupies memory

    def forward(self, views):
        """
        Forward pass wrapper for MoGe.
        The predicted MoGe mask is not applied to the outputs (``apply_mask=False``).
        The number of tokens for inference is determined by the image shape.

        Assumption:
            - The number of input views is 1.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                                Length of the list should be 1.
                                Each dictionary should contain the following keys:
                                    "img" (tensor): Image tensor of shape (B, C, H, W).
                                    "data_norm_type" (list): ["identity"]

        Returns:
            List[dict]: A list containing the final outputs for the single view. Length of the list will be 1.
                        The dictionary contains the following keys:
                            "pts3d" and "pts3d_cam": the model's "points" output (the same tensor for both).
                            "depth_z": the model's "depth" output with a trailing singleton dimension added.
                            "intrinsics": the model's "intrinsics" output.
                            "non_ambiguous_mask": the model's "mask" output.
                            "ray_directions": "points" divided by their norm over the last dimension.
                            "depth_along_ray": norm of "points" over the last dimension (kept as a singleton dim).

        Raises:
            AssertionError: If more than one view is passed or the data norm type is not "identity".
        """
        # Check that the number of input views is 1
        assert len(views) == 1, "MoGe only supports 1 input view."
        view = views[0]

        # Get input shape of the images and derive the number of tokens for inference
        # (14 is presumably the ViT patch size of the backbone - TODO confirm)
        _, _, height, width = view["img"].shape
        num_tokens = int(height // 14) * int(width // 14)

        # Check the data norm type
        # MoGe expects a normalized image but without the DINOv2 mean and std applied ("identity")
        data_norm_type = view["data_norm_type"][0]
        assert data_norm_type == "identity", (
            "MoGe expects a normalized image but without the DINOv2 mean and std applied"
        )

        # Run MoGe inference
        # Output dict contains: "points", "depth", "mask", "intrinsics", "normal" (based on model config)
        model_outputs = self.model.infer(
            image=view["img"], num_tokens=num_tokens, apply_mask=False
        )
        points = model_outputs["points"]

        # Get the ray directions and depth along ray (with CUDA autocast disabled)
        # NOTE(review): a point with zero norm yields non-finite ray directions; this is
        # not guarded here.
        with torch.autocast("cuda", enabled=False):
            depth_along_ray = torch.norm(points, dim=-1, keepdim=True)
            ray_directions = points / depth_along_ray

        # Convert the output to MapAnything format
        result_dict = {
            "pts3d": points,
            "pts3d_cam": points,
            "depth_z": model_outputs["depth"].unsqueeze(-1),
            "intrinsics": model_outputs["intrinsics"],
            "non_ambiguous_mask": model_outputs["mask"],
            "ray_directions": ray_directions,
            "depth_along_ray": depth_along_ray,
        }

        return [result_dict]
|