File size: 4,912 Bytes
b74998d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""

Inference wrapper for MoGe

"""

import torch

from mapanything.models.external.moge.models.v1 import MoGeModel as MoGeModelV1
from mapanything.models.external.moge.models.v2 import MoGeModel as MoGeModelV2


class MoGeWrapper(torch.nn.Module):
    """
    Wrapper around a pretrained MoGe (v1 or v2) model that runs single-view
    inference and returns the predictions in the MapAnything output format.
    """

    def __init__(
        self,
        name,
        model_string="Ruicheng/moge-2-vitl",
        torch_hub_force_reload=False,
        load_custom_ckpt=False,
        custom_ckpt_path=None,
    ):
        """
        Build the wrapper and load the pretrained MoGe model.

        Args:
            name (str): Identifier stored on the wrapper as ``self.name``.
            model_string (str): Checkpoint id passed to ``from_pretrained``.
                It must appear in ``self.moge_model_map``; the list it is found
                in selects the MoGe v1 or v2 model class.
            torch_hub_force_reload (bool): Stored as an attribute; not used by
                this class itself.
            load_custom_ckpt (bool): If True, overwrite the pretrained weights
                with the state dict stored at ``custom_ckpt_path``
                (``strict=True`` loading).
            custom_ckpt_path (str | None): Path to the custom state dict.
                Required when ``load_custom_ckpt`` is True.

        Raises:
            ValueError: If ``model_string`` is not a known MoGe checkpoint.
            AssertionError: If ``load_custom_ckpt`` is True but
                ``custom_ckpt_path`` is None.
        """
        super().__init__()
        self.name = name
        self.model_string = model_string
        self.torch_hub_force_reload = torch_hub_force_reload
        self.load_custom_ckpt = load_custom_ckpt
        self.custom_ckpt_path = custom_ckpt_path

        # Mapping of MoGe model version to checkpoint strings
        self.moge_model_map = {
            "v1": ["Ruicheng/moge-vitl"],
            "v2": [
                "Ruicheng/moge-2-vits-normal",
                "Ruicheng/moge-2-vitb-normal",
                "Ruicheng/moge-2-vitl-normal",
                "Ruicheng/moge-2-vitl",
            ],
        }

        # Initialize the model
        if self.model_string in self.moge_model_map["v1"]:
            self.model = MoGeModelV1.from_pretrained(self.model_string)
        elif self.model_string in self.moge_model_map["v2"]:
            self.model = MoGeModelV2.from_pretrained(self.model_string)
        else:
            raise ValueError(
                f"Invalid model string: {self.model_string}. Valid strings are: {self.moge_model_map}"
            )

        # Load custom checkpoint if requested
        if self.load_custom_ckpt:
            # Validate before announcing the load so that a missing path is not
            # reported as "Loading checkpoint from None ...".
            assert self.custom_ckpt_path is not None, (
                "custom_ckpt_path must be provided if load_custom_ckpt is set to True"
            )
            print(f"Loading checkpoint from {self.custom_ckpt_path} ...")
            # map_location="cpu": load_state_dict copies the tensors into the
            # model's existing parameters, so deserializing on the CPU avoids
            # failures when the checkpoint was saved on an unavailable device
            # and avoids a transient extra copy on the GPU.
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load checkpoints from trusted sources.
            custom_ckpt = torch.load(
                self.custom_ckpt_path, map_location="cpu", weights_only=False
            )
            print(self.model.load_state_dict(custom_ckpt, strict=True))
            del custom_ckpt  # in case it occupies memory

    def forward(self, views):
        """
        Forward pass wrapper for MoGe (v1 or v2, depending on ``model_string``).

        The predicted MoGe mask is not applied to the outputs; it is returned
        unchanged under ``"non_ambiguous_mask"``. The number of tokens used for
        inference is derived from the image shape as ``(H // 14) * (W // 14)``.

        Assumption:
            - The number of input views is 1.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                                Length of the list should be 1.
                                Each dictionary should contain the following keys:
                                    "img" (tensor): Image tensor of shape (B, C, H, W).
                                    "data_norm_type" (list): ["identity"]

        Returns:
            List[dict]: A list containing the final outputs for the single view. Length of the list will be 1.
                The dict has the keys "pts3d", "pts3d_cam" (both the predicted
                points), "depth_z" (predicted depth with a trailing singleton
                dim), "intrinsics", "non_ambiguous_mask", "ray_directions"
                (points divided by their norm) and "depth_along_ray" (norm of
                the points, trailing singleton dim kept).

        Raises:
            AssertionError: If more than one view is given or the data norm
                type is not "identity".
        """
        # Check that the number of input views is 1
        assert len(views) == 1, "MoGe only supports 1 input view."

        # Get the input image shape and derive the number of tokens for inference
        _, _, height, width = views[0]["img"].shape
        num_tokens = int(height // 14) * int(width // 14)

        # Check the data norm type
        # MoGe expects a normalized image but without the DINOv2 mean and std applied ("identity")
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "identity", (
            "MoGe expects a normalized image but without the DINOv2 mean and std applied"
        )

        # Run MoGe inference
        # Output dict contains: "points", "depth", "mask", "intrinsics", "normal" (based on model config)
        model_outputs = self.model.infer(
            image=views[0]["img"], num_tokens=num_tokens, apply_mask=False
        )

        # Get the ray directions and depth along ray
        with torch.autocast("cuda", enabled=False):
            depth_along_ray = torch.norm(model_outputs["points"], dim=-1, keepdim=True)
            ray_directions = model_outputs["points"] / depth_along_ray

        # Convert the output to MapAnything format.
        # Single view: the same camera-frame points are used for both the
        # world-frame ("pts3d") and camera-frame ("pts3d_cam") outputs.
        # NOTE(review): intrinsics are forwarded as returned by MoGe's infer();
        # confirm their units (normalized vs. pixel) match what consumers expect.
        result_dict = {
            "pts3d": model_outputs["points"],
            "pts3d_cam": model_outputs["points"],
            "depth_z": model_outputs["depth"].unsqueeze(-1),
            "intrinsics": model_outputs["intrinsics"],
            "non_ambiguous_mask": model_outputs["mask"],
            "ray_directions": ray_directions,
            "depth_along_ray": depth_along_ray,
        }
        res = [result_dict]

        return res