File size: 4,894 Bytes
b74998d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""

Inference wrapper for Pi3

"""

import torch

from mapanything.models.external.pi3.models.pi3 import Pi3
from mapanything.models.external.vggt.utils.rotation import mat_to_quat


class Pi3Wrapper(torch.nn.Module):
    """
    Inference wrapper around the Pi3 model.

    Stacks the per-view images into a single tensor, runs Pi3 under CUDA
    autocast, and splits the predictions back into one output dictionary
    per input view.
    """

    def __init__(
        self,
        name,
        torch_hub_force_reload,
        load_pretrained_weights=True,
        pos_type="rope100",
        decoder_size="large",
    ):
        """
        Build the wrapped Pi3 model and select the autocast dtype.

        Args:
            name: Identifier stored on the wrapper as ``self.name``.
            torch_hub_force_reload: When truthy, the pre-trained checkpoint is
                requested with ``force_download=True`` instead of being loaded
                through the default ``from_pretrained`` path.
            load_pretrained_weights: When True, load "yyfz233/Pi3" via
                ``Pi3.from_pretrained``; otherwise construct ``Pi3`` directly.
            pos_type: Forwarded to ``Pi3`` only when
                ``load_pretrained_weights`` is False.
            decoder_size: Forwarded to ``Pi3`` only when
                ``load_pretrained_weights`` is False.
        """
        super().__init__()
        self.name = name
        self.torch_hub_force_reload = torch_hub_force_reload

        if load_pretrained_weights:
            # Load pre-trained weights
            if not torch_hub_force_reload:
                # Initialize the Pi3 model from huggingface hub cache
                print("Loading Pi3 from huggingface cache ...")
                self.model = Pi3.from_pretrained(
                    "yyfz233/Pi3",
                )
            else:
                # Initialize the Pi3 model, forcing a fresh download
                self.model = Pi3.from_pretrained("yyfz233/Pi3", force_download=True)
        else:
            # Build the Pi3 architecture without loading a checkpoint
            self.model = Pi3(
                pos_type=pos_type,
                decoder_size=decoder_size,
            )

        # Get the dtype for Pi3 inference
        self.dtype = self._select_autocast_dtype()

    @staticmethod
    def _select_autocast_dtype():
        """
        Pick the reduced-precision dtype used for autocast during inference.

        Returns:
            torch.dtype: ``torch.bfloat16`` when a CUDA device with compute
            capability 8.0+ (Ampere or newer) is available, otherwise
            ``torch.float16``.
        """
        # get_device_capability() raises when no CUDA device is usable, which
        # would make the wrapper impossible to construct on CPU-only hosts,
        # so availability is checked first.
        # NOTE(review): on such hosts the float16 fallback is only a
        # placeholder; torch.autocast("cuda") is expected to disable itself
        # there - confirm against the installed torch version.
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
            return torch.bfloat16
        return torch.float16

    def forward(self, views):
        """
        Forward pass wrapper for Pi3

        Assumption:
        - All the input views have the same image shape.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                                Each dictionary should contain the following keys:
                                    "img" (tensor): Image tensor of shape (B, C, H, W).
                                    "data_norm_type" (list): ["identity"]

        Returns:
            List[dict]: A list containing the final outputs for all N views.
                        Each dictionary contains the following keys:
                            "pts3d": Slice of the model's "points" output for the view.
                            "pts3d_cam": Slice of the model's "local_points" output for the view.
                            "ray_directions": "pts3d_cam" divided by its norm over the last dim.
                            "depth_along_ray": Norm of "pts3d_cam" over the last dim (kept as a size-1 dim).
                            "cam_trans": Translation column ([..., :3, 3]) of the view's "camera_poses" matrix.
                            "cam_quats": Quaternion from mat_to_quat applied to the [..., :3, :3] rotation block.
                            "conf": Slice of the model's "conf" output for the view.

        Raises:
            AssertionError: If the first view's data norm type is not "identity".
        """
        num_views = len(views)

        # Check the data norm type (only the first view's first entry is inspected)
        # Pi3 expects a normalized image but without the DINOv2 mean and std applied ("identity")
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "identity", (
            "Pi3 expects a normalized image but without the DINOv2 mean and std applied"
        )

        # Stack the images to create a single (B, V, C, H, W) tensor
        images = torch.stack([view["img"] for view in views], dim=1)

        # Run the Pi3 aggregator
        with torch.autocast("cuda", dtype=self.dtype):
            results = self.model(images)

        # Need high precision for transformations
        with torch.autocast("cuda", enabled=False):
            # Convert the output to MapAnything format
            res = []
            for view_idx in range(num_views):
                # Get the extrinsics
                curr_view_extrinsic = results["camera_poses"][:, view_idx, ...]
                curr_view_cam_translations = curr_view_extrinsic[..., :3, 3]
                curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3])

                # Get the depth along ray, ray directions, local point cloud & global point cloud
                curr_view_pts3d_cam = results["local_points"][:, view_idx, ...]
                curr_view_depth_along_ray = torch.norm(
                    curr_view_pts3d_cam, dim=-1, keepdim=True
                )
                curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray
                curr_view_pts3d = results["points"][:, view_idx, ...]

                # Get the confidence
                curr_view_confidence = results["conf"][:, view_idx, ...]

                # Append the outputs to the result list
                res.append(
                    {
                        "pts3d": curr_view_pts3d,
                        "pts3d_cam": curr_view_pts3d_cam,
                        "ray_directions": curr_view_ray_dirs,
                        "depth_along_ray": curr_view_depth_along_ray,
                        "cam_trans": curr_view_cam_translations,
                        "cam_quats": curr_view_cam_quats,
                        "conf": curr_view_confidence,
                    }
                )

        return res