File size: 9,199 Bytes
c20d7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
"""Contains Dense Transformer Prediction architecture.

Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

from typing import NamedTuple

import torch
import torch.nn as nn

from sharp.models.blocks import (
    FeatureFusionBlock2d,
    NormLayerName,
    residual_block_2d,
)
from sharp.models.decoders import BaseDecoder, MultiresConvDecoder
from sharp.models.params import DPTImageEncoderType, GaussianDecoderParams


def create_gaussian_decoder(
    params: GaussianDecoderParams, dims_depth_features: list[int]
) -> GaussianDensePredictionTransformer:
    """Build a GaussianDensePredictionTransformer from ``params``.

    A ``MultiresConvDecoder`` is constructed first and handed to the
    transformer as its feature decoder.

    Args:
        params: Configuration of the Gaussian decoder. The same object is also
            forwarded as ``image_encoder_params``.
        dims_depth_features: Forwarded as the first argument of
            ``MultiresConvDecoder`` (presumably the channel dimensions of the
            incoming depth feature maps -- confirm against the decoder).

    Returns:
        The assembled GaussianDensePredictionTransformer.
    """
    multires_decoder = MultiresConvDecoder(
        dims_depth_features,
        params.dims_decoder,
        grad_checkpointing=params.grad_checkpointing,
        upsampling_mode=params.upsampling_mode,
    )

    # Collect the constructor arguments by name before instantiating.
    transformer_kwargs = {
        "decoder": multires_decoder,
        "dim_in": params.dim_in,
        "dim_out": params.dim_out,
        "stride_out": params.stride,
        "norm_type": params.norm_type,
        "norm_num_groups": params.norm_num_groups,
        "use_depth_input": params.use_depth_input,
        "grad_checkpointing": params.grad_checkpointing,
        "image_encoder_type": params.image_encoder_type,
        "image_encoder_params": params,
    }
    return GaussianDensePredictionTransformer(**transformer_kwargs)


def _create_project_upsample_block(
    dim_in: int,
    dim_out: int,
    upsample_layers: int,
    dim_intermediate: int | None = None,
) -> nn.Module:
    """Build a 1x1 projection followed by a chain of 2x transposed convolutions.

    Args:
        dim_in: Number of input channels of the projection.
        dim_out: Number of output channels of every upsampling layer.
        upsample_layers: Number of stride-2 transposed convolutions to append.
        dim_intermediate: Number of output channels of the projection;
            falls back to ``dim_out`` when None.

    Returns:
        An ``nn.Sequential`` holding the projection and the upsampling layers.
    """
    projected_dim = dim_out if dim_intermediate is None else dim_intermediate

    # Projection.
    layers: list[nn.Module] = [
        nn.Conv2d(dim_in, projected_dim, kernel_size=1, stride=1, padding=0, bias=False)
    ]

    # Upsampling: only the first layer consumes the projected width, every
    # later layer maps dim_out -> dim_out.
    current_dim = projected_dim
    for _ in range(upsample_layers):
        layers.append(
            nn.ConvTranspose2d(
                current_dim, dim_out, kernel_size=2, stride=2, padding=0, bias=False
            )
        )
        current_dim = dim_out

    return nn.Sequential(*layers)


class ImageFeatures(NamedTuple):
    """Pair of feature tensors returned by the modules in this file.

    ``SkipConvBackbone`` fills both fields with the same tensor, whereas
    ``GaussianDensePredictionTransformer`` fills them with the outputs of its
    separate texture and geometry heads.
    """

    # Features of the texture branch.
    texture_features: torch.Tensor
    # Features of the geometry branch.
    geometry_features: torch.Tensor


class SkipConvBackbone(nn.Module):
    """Single convolution exposed through a backbone-like interface.

    The module offers ``forward(input_features, encodings)`` and a ``stride``
    property so that it can be used where a BaseBackbone is expected.
    """

    def __init__(self, dim_in: int, dim_out: int, kernel_size: int, stride_out: int):
        """Set up the convolution.

        Args:
            dim_in: Number of input channels.
            dim_out: Number of output channels.
            kernel_size: Kernel size of the convolution.
            stride_out: Stride of the convolution, also reported by ``stride``.

        Raises:
            ValueError: If ``stride_out`` is 1 while ``kernel_size`` is not 1.
        """
        super().__init__()
        self.stride_out = stride_out

        unsupported_kernel = kernel_size != 1 and stride_out == 1
        if unsupported_kernel:
            raise ValueError("We only support kernel_size = 1 if stride_out is 1.")

        half_kernel: int = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            stride=stride_out,
            padding=half_kernel,
        )

    def forward(
        self,
        input_features: torch.Tensor,
        encodings: list[torch.Tensor] | None = None,
    ) -> ImageFeatures:
        """Convolve the input and return the result for both feature slots.

        Args:
            input_features: Tensor fed to the convolution.
            encodings: Accepted for interface compatibility; not used.

        Returns:
            ImageFeatures whose texture and geometry entries are the same tensor.
        """
        conv_out = self.conv(input_features)
        return ImageFeatures(texture_features=conv_out, geometry_features=conv_out)

    @property
    def stride(self) -> int:
        """Effective downsampling stride."""
        return self.stride_out


class GaussianDensePredictionTransformer(nn.Module):
    """Dense Prediction Transformer for Gaussian.

    Reuse monodepth decoded features for processing: the decoded features are
    fused with features of a light-weight image encoder and passed through two
    separate heads producing texture and geometry features.
    """

    norm_type: NormLayerName

    def __init__(
        self,
        decoder: BaseDecoder,
        dim_in: int,
        dim_out: int,
        stride_out: int,
        image_encoder_params: GaussianDecoderParams,
        image_encoder_type: DPTImageEncoderType = "skip_conv",
        norm_type: NormLayerName = "group_norm",
        norm_num_groups: int = 8,
        use_depth_input: bool = True,
        grad_checkpointing: bool = False,
    ):
        """Initialize Dense Prediction Transformer for Gaussian.

        Args:
            decoder: Decoder to decode features.
            dim_in: Input dimension.
            dim_out: Final output dimension.
            stride_out: Stride of output feature map. Must be 1 or 2.
            image_encoder_params: The backbone parameters to configure the image
                encoder. NOTE: its ``dim_in`` and ``dim_out`` attributes are
                overwritten in place by this constructor.
            image_encoder_type: Type of image encoder to use.
            norm_type: Type of norm layers.
            norm_num_groups: Num groups for norm layers.
            use_depth_input: Whether to use depth input.
            grad_checkpointing: Whether to use gradient checkpointing.

        Raises:
            ValueError: If ``stride_out`` is neither 1 nor 2, or if
                ``image_encoder_type`` is not supported.
        """
        super().__init__()

        self.decoder = decoder
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.stride_out = stride_out
        self.norm_type = norm_type
        self.norm_num_groups = norm_num_groups
        self.use_depth_input = use_depth_input
        self.grad_checkpointing = grad_checkpointing
        self.image_encoder_type = image_encoder_type

        # Adopt an image encoder to lift dimension to monodepth feature and
        # resize to be the same resolution as the decoder output.
        # Without depth input the encoder sees one channel less.
        encoder_dim_in = self.dim_in if use_depth_input else self.dim_in - 1
        # Kept for backward compatibility: the caller's params object is
        # updated in place, so it reflects the encoder dimensions afterwards.
        image_encoder_params.dim_in = encoder_dim_in
        image_encoder_params.dim_out = decoder.dim_out
        self.image_encoder = self._create_image_encoder(image_encoder_params, stride_out)

        self.fusion = FeatureFusionBlock2d(decoder.dim_out)

        if stride_out == 1:
            # A single stride-2 transposed convolution doubles the resolution
            # of the decoder features.
            self.upsample = _create_project_upsample_block(
                decoder.dim_out,
                decoder.dim_out,
                upsample_layers=1,
            )
        elif stride_out == 2:
            # Decoder features are used at their native resolution.
            self.upsample = nn.Identity()
        else:
            raise ValueError("We only support stride is 1 or 2 for DPT backbone.")

        self.texture_head = self._create_head(dim_decoder=decoder.dim_out, dim_out=self.dim_out)
        self.geometry_head = self._create_head(dim_decoder=decoder.dim_out, dim_out=self.dim_out)

    def _create_head(self, dim_decoder: int, dim_out: int) -> nn.Module:
        """Create a prediction head.

        The head consists of two residual blocks operating at ``dim_decoder``
        channels (hidden width ``dim_decoder // 2``), followed by ReLU, a 1x1
        convolution to ``dim_out`` channels and a final ReLU.

        Args:
            dim_decoder: Number of channels of the fused decoder features.
            dim_out: Number of output channels of the head.

        Returns:
            The head as an ``nn.Sequential``.
        """
        return nn.Sequential(
            residual_block_2d(
                dim_in=dim_decoder,
                dim_out=dim_decoder,
                dim_hidden=dim_decoder // 2,
                norm_type=self.norm_type,
                norm_num_groups=self.norm_num_groups,
            ),
            residual_block_2d(
                dim_in=dim_decoder,
                dim_hidden=dim_decoder // 2,
                dim_out=dim_decoder,
                norm_type=self.norm_type,
                norm_num_groups=self.norm_num_groups,
            ),
            nn.ReLU(),
            nn.Conv2d(dim_decoder, dim_out, kernel_size=1, stride=1),
            nn.ReLU(),
        )

    def _create_image_encoder(
        self, image_encoder_params: GaussianDecoderParams, stride_out: int
    ) -> nn.Module:
        """Create image encoder and return based on parameters.

        Args:
            image_encoder_params: Provides ``dim_in`` and ``dim_out`` of the encoder.
            stride_out: Stride of the encoder convolution.

        Returns:
            A SkipConvBackbone configured for ``self.image_encoder_type``.

        Raises:
            ValueError: If ``self.image_encoder_type`` is not supported.
        """
        if self.image_encoder_type == "skip_conv":
            # Use kernel_size = 1 only if stride_out is 1.
            return SkipConvBackbone(
                image_encoder_params.dim_in,
                image_encoder_params.dim_out,
                kernel_size=3 if stride_out != 1 else 1,
                stride_out=stride_out,
            )
        elif self.image_encoder_type == "skip_conv_kernel2":
            # Kernel size equals the stride, i.e. non-overlapping windows.
            return SkipConvBackbone(
                image_encoder_params.dim_in,
                image_encoder_params.dim_out,
                kernel_size=stride_out,
                stride_out=stride_out,
            )
        else:
            raise ValueError(f"Unsupported image encoder type: {self.image_encoder_type}")

    def forward(self, input_features: torch.Tensor, encodings: list[torch.Tensor]) -> ImageFeatures:
        """Run monodepth and fuse features with input image to predict Gaussians.

        Args:
            input_features: The input features to use.
            encodings: Feature encodings (e.g. from monodepth network).

        Returns:
            ImageFeatures holding the outputs of the texture and geometry heads.
        """
        features = self.decoder(encodings).contiguous()
        features = self.upsample(features)

        if self.use_depth_input:
            encoder_input = input_features
        else:
            # Only the first three channels are fed to the image encoder
            # (presumably RGB, with depth stored after them -- confirm).
            encoder_input = input_features[:, :3].contiguous()
        # The image encoder returns ImageFeatures in both cases; the fusion
        # block needs the tensor, not the tuple.
        skip_features = self.image_encoder(encoder_input).texture_features
        features = self.fusion(features, skip_features)

        texture_features = self.texture_head(features)
        geometry_features = self.geometry_head(features)

        return ImageFeatures(
            texture_features=texture_features,  # type: ignore
            geometry_features=geometry_features,  # type: ignore
        )

    @property
    def stride(self) -> int:
        """Internal stride of GaussianDensePredictionTransformer."""
        return self.stride_out