File size: 7,744 Bytes
c20d7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""Contains definition of RGB-only gaussian predictor.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

import logging

import torch
from torch import nn

from sharp.models.monodepth import MonodepthWithEncodingAdaptor
from sharp.utils.gaussians import Gaussians3D

from .composer import GaussianComposer

LOGGER = logging.getLogger(__name__)


class DepthAlignment(nn.Module):
    """Depth alignment in a dedicated nn.Module.

    Wrap scale_map_estimator to perform the conditional logic in a separated torch
    module outside the forward of RGBGaussianPredictor. This module can be then
    excluded during symbolic tracing.
    """

    def __init__(self, scale_map_estimator: nn.Module | None):
        """Initialize DepthAlignmentWrapper.

        Args:
            scale_map_estimator: Module to align monodepth to ground truth depth.
        """
        super().__init__()
        self.scale_map_estimator = scale_map_estimator

    def forward(
        self,
        monodepth: torch.Tensor,
        depth: torch.Tensor,
        depth_decoder_features: torch.Tensor | None = None,
    ):
        """Optionally align monodepth to ground truth with a local scale map.

        Args:
            monodepth: The monodepth model with intermediate features to use.
            depth: Ground truth depth to align predicted depth to.
            depth_decoder_features: The (optional) monodepth decoder features.
        """
        if depth is not None and self.scale_map_estimator is not None:
            depth_alignment_map = self.scale_map_estimator(
                monodepth[:, 0:1], depth, depth_decoder_features
            )
            monodepth = depth_alignment_map * monodepth
        else:
            # Some losses rely on the presence of an alignment map.
            # We ensure that they can be computed by creating a fake alignment map.
            depth_alignment_map = torch.ones_like(monodepth)
        return monodepth, depth_alignment_map


class RGBGaussianPredictor(nn.Module):
    """Predicts 3D Gaussians from images."""

    # Image-to-image backbone that turns init features into Gaussian features.
    feature_model: nn.Module

    def __init__(
        self,
        init_model: nn.Module,
        monodepth_model: MonodepthWithEncodingAdaptor,
        feature_model: nn.Module,
        prediction_head: nn.Module,
        gaussian_composer: GaussianComposer,
        scale_map_estimator: nn.Module | None,
    ) -> None:
        """Initialize RGBGaussianPredictor.

        Args:
            init_model: A model mapping image and depth to base values.
            monodepth_model: The monodepth model with intermediate features to use.
            feature_model: The image2image model to predict Gaussians from.
            prediction_head: Head to decode image features.
            gaussian_composer: Module to compose final prediction from deltas and
                base values.
            scale_map_estimator: Module to align monodepth to ground truth depth.

        Note:
        ----
            when monodepth_model is trainable, using local depth alignment can
            result in the monodepth model losing its ability to predict shapes. It is
            hence recommended to deactivate the corresponding flag.
        """
        super().__init__()
        self.init_model = init_model
        self.feature_model = feature_model
        self.monodepth_model = monodepth_model
        self.prediction_head = prediction_head
        self.gaussian_composer = gaussian_composer
        # Conditional alignment lives in its own module so it can be excluded
        # from symbolic tracing (see DepthAlignment above).
        self.depth_alignment = DepthAlignment(scale_map_estimator)

    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor,
        depth: torch.Tensor | None = None,
    ) -> Gaussians3D:
        """Predict 3D Gaussians.

        Args:
            image: The image to process.
            disparity_factor: Factor to convert depth to disparities.
            depth: Ground truth depth to align predicted depth to.

        Returns:
            The predicted 3D Gaussians.

        Note:
        ----
        During training, it is recommended to feed an additional ground truth depth
        map to the network to align the predicted depth to. During inference, it is
        recommended to use depth=None and use monodepth_disparity output from the
        model instead to compute depth.
        """
        # Estimate depth and align to ground truth (if available).
        monodepth_output = self.monodepth_model(image)
        monodepth_disparity = monodepth_output.disparity

        # Reshape the per-batch factor so it broadcasts over (C, H, W).
        disparity_factor = disparity_factor[:, None, None, None]
        # Clamp guards against division by (near-)zero and extreme disparities
        # when inverting disparity to metric depth.
        monodepth = disparity_factor / monodepth_disparity.clamp(min=1e-4, max=1e4)

        # In the model we apply additional alignment to provided ground truth depth
        # as well as additional normalization.
        #
        # The overall graph looks as follows:
        #
        #     monodepth        depth    # Both monodepth and depth are metric here.
        #         |              |
        #         +------+-------+
        #                |
        #        +-------+--------+     # Optionally align monodepth to ground truth
        #        | depth_alignment|     # with a local scale map.
        #        +-------+--------+
        #                |
        #                v
        #       monodepth (aligned)     # Monodepth is now aligned to ground truth.
        #                |
        #          +-----+----+         # Normalize depth and compute base gaussians.
        #          |init_model|         # in these normalized coordinates.
        #          +-----+----+
        #                |
        #                v
        #   +------ init_output         # Init_output consists of features, base
        #   |            |              # gaussians and a global scale.
        #   |     +------+-----+
        #   |     |main network|        # Compute delta values to base gaussians.
        #   |     +------+-----+
        #   |            |
        #   |            V
        #   |        delta_values       # The delta values are computed with normalized depth.
        #   |            |
        #   |    +-------+---------+
        #   +--> |gaussian_composer|    # Add delta to base values and unscale gaussians.
        #        +-------+---------+
        #                |
        #                v
        #            gaussians          # The final Gaussians are metric again.
        #

        # The logic to decide whether to align monodepth to the ground truth is wrapped
        # in a submodule 'DepthAlignment' to facilitate the symbolic tracing of the
        # predictor. This way, the depth alignment submodule containing the conditional
        # logic can be excluded during the tracing and the graph of the predictors is
        # static.
        monodepth, _ = self.depth_alignment(
            monodepth,
            depth,
            monodepth_output.decoder_features,
        )

        init_output = self.init_model(image, monodepth)
        # Condition the image-to-image backbone on the monodepth encoder output.
        image_features = self.feature_model(
            init_output.feature_input, encodings=monodepth_output.output_features
        )
        delta_values = self.prediction_head(image_features)
        gaussians = self.gaussian_composer(
            delta=delta_values,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )
        return gaussians

    def internal_resolution(self) -> int:
        """Internal resolution, delegated to the monodepth model."""
        return self.monodepth_model.internal_resolution()

    @property
    def output_resolution(self) -> int:
        """Output resolution of Gaussians (half the internal resolution)."""
        return self.internal_resolution() // 2