File size: 9,199 Bytes
c20d7cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 |
"""Contains Dense Transformer Prediction architecture.
Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
from typing import NamedTuple
import torch
import torch.nn as nn
from sharp.models.blocks import (
FeatureFusionBlock2d,
NormLayerName,
residual_block_2d,
)
from sharp.models.decoders import BaseDecoder, MultiresConvDecoder
from sharp.models.params import DPTImageEncoderType, GaussianDecoderParams
def create_gaussian_decoder(
    params: GaussianDecoderParams, dims_depth_features: list[int]
) -> GaussianDensePredictionTransformer:
    """Create gaussian_decoder model specified by gaussian_decoder_name."""
    # Multi-resolution convolutional decoder over the depth-feature pyramid.
    multires_decoder = MultiresConvDecoder(
        dims_depth_features,
        params.dims_decoder,
        grad_checkpointing=params.grad_checkpointing,
        upsampling_mode=params.upsampling_mode,
    )
    # The full params object doubles as the image-encoder configuration.
    transformer = GaussianDensePredictionTransformer(
        decoder=multires_decoder,
        dim_in=params.dim_in,
        dim_out=params.dim_out,
        stride_out=params.stride,
        norm_type=params.norm_type,
        norm_num_groups=params.norm_num_groups,
        use_depth_input=params.use_depth_input,
        grad_checkpointing=params.grad_checkpointing,
        image_encoder_type=params.image_encoder_type,
        image_encoder_params=params,
    )
    return transformer
def _create_project_upsample_block(
dim_in: int,
dim_out: int,
upsample_layers: int,
dim_intermediate: int | None = None,
) -> nn.Module:
if dim_intermediate is None:
dim_intermediate = dim_out
# Projection.
blocks = [
nn.Conv2d(
in_channels=dim_in,
out_channels=dim_intermediate,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
]
# Upsampling.
blocks += [
nn.ConvTranspose2d(
in_channels=dim_intermediate if i == 0 else dim_out,
out_channels=dim_out,
kernel_size=2,
stride=2,
padding=0,
bias=False,
)
for i in range(upsample_layers)
]
return nn.Sequential(*blocks)
class ImageFeatures(NamedTuple):
    """Image feature extracted from decoder."""

    # Feature map feeding the texture (appearance) branch.
    texture_features: torch.Tensor
    # Feature map feeding the geometry branch; producers may reuse the same
    # tensor for both fields (e.g. SkipConvBackbone returns one conv output twice).
    geometry_features: torch.Tensor
class SkipConvBackbone(nn.Module):
    """A wrapper around a conv layer that behaves like a BaseBackbone."""

    def __init__(self, dim_in: int, dim_out: int, kernel_size: int, stride_out: int):
        """Initialize SkipConvBackbone."""
        super().__init__()
        self.stride_out = stride_out
        # At stride 1 only a 1x1 kernel is supported.
        if kernel_size != 1 and stride_out == 1:
            raise ValueError("We only support kernel_size = 1 if stride_out is 1.")
        # "Same"-style padding for odd kernels; zero for kernel sizes 1 and 2.
        same_padding: int = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            dim_in, dim_out, kernel_size=kernel_size, stride=stride_out, padding=same_padding
        )

    def forward(
        self,
        input_features: torch.Tensor,
        encodings: list[torch.Tensor] | None = None,
    ) -> ImageFeatures:
        """Apply SkipConvBackbone to image."""
        # `encodings` is accepted for interface compatibility but not used here.
        convolved = self.conv(input_features)
        # The single conv output serves as both texture and geometry features.
        return ImageFeatures(
            texture_features=convolved,
            geometry_features=convolved,
        )

    @property
    def stride(self) -> int:
        """Effective downsampling stride."""
        return self.stride_out
class GaussianDensePredictionTransformer(nn.Module):
    """Dense Prediction Transformer for Gaussian.

    Reuse monodepth decoded features for processing.
    """

    norm_type: NormLayerName

    def __init__(
        self,
        decoder: BaseDecoder,
        dim_in: int,
        dim_out: int,
        stride_out: int,
        image_encoder_params: GaussianDecoderParams,
        image_encoder_type: DPTImageEncoderType = "skip_conv",
        norm_type: NormLayerName = "group_norm",
        norm_num_groups: int = 8,
        use_depth_input: bool = True,
        grad_checkpointing: bool = False,
    ):
        """Initialize Dense Prediction Transformer for Gaussian.

        Args:
            decoder: Decoder to decode features.
            dim_in: Input dimension.
            dim_out: Final output dimension.
            stride_out: Stride of output feature map; must be 1 or 2.
            image_encoder_params: The backbone parameters to configurate the image
                encoder. NOTE: mutated in place — dim_in/dim_out are overwritten
                below to match this module's configuration.
            image_encoder_type: Type of image encoder to use.
            norm_type: Type of norm layers.
            norm_num_groups: Num groups for norm layers.
            use_depth_input: Whether to use depth input.
            grad_checkpointing: Whether to use gradient checkpointing.

        Raises:
            ValueError: If stride_out is not 1 or 2.
        """
        super().__init__()
        self.decoder = decoder
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.stride_out = stride_out
        self.norm_type = norm_type
        self.norm_num_groups = norm_num_groups
        self.use_depth_input = use_depth_input
        self.grad_checkpointing = grad_checkpointing
        self.image_encoder_type = image_encoder_type
        # Adopt an image encoder to lift dimension to monodepth feature and
        # resize to be the same resolution as the decoder output.
        # When depth input is disabled, forward() drops one input channel,
        # so the encoder sees dim_in - 1 channels.
        dim_in = self.dim_in if use_depth_input else self.dim_in - 1
        image_encoder_params.dim_in = dim_in
        image_encoder_params.dim_out = decoder.dim_out
        self.image_encoder = self._create_image_encoder(image_encoder_params, stride_out)
        self.fusion = FeatureFusionBlock2d(decoder.dim_out)
        if stride_out == 1:
            # One 2x transposed-conv stage brings the decoder output to stride 1.
            self.upsample = _create_project_upsample_block(
                decoder.dim_out,
                decoder.dim_out,
                upsample_layers=1,
            )
        elif stride_out == 2:
            self.upsample = nn.Identity()
        else:
            raise ValueError("We only support stride is 1 or 2 for DPT backbone.")
        self.texture_head = self._create_head(dim_decoder=decoder.dim_out, dim_out=self.dim_out)
        self.geometry_head = self._create_head(dim_decoder=decoder.dim_out, dim_out=self.dim_out)

    def _create_head(self, dim_decoder: int, dim_out: int) -> nn.Module:
        """Build a prediction head: two residual blocks, ReLU, then a 1x1 projection.

        Args:
            dim_decoder: Channel count of the fused decoder features.
            dim_out: Channel count of the head output.
        """
        return nn.Sequential(
            residual_block_2d(
                dim_in=dim_decoder,
                dim_out=dim_decoder,
                dim_hidden=dim_decoder // 2,
                norm_type=self.norm_type,
                norm_num_groups=self.norm_num_groups,
            ),
            residual_block_2d(
                dim_in=dim_decoder,
                dim_hidden=dim_decoder // 2,
                dim_out=dim_decoder,
                norm_type=self.norm_type,
                norm_num_groups=self.norm_num_groups,
            ),
            nn.ReLU(),
            nn.Conv2d(dim_decoder, dim_out, kernel_size=1, stride=1),
            # Trailing ReLU clamps head outputs to be non-negative.
            nn.ReLU(),
        )

    def _create_image_encoder(
        self, image_encoder_params: GaussianDecoderParams, stride_out: int
    ) -> nn.Module:
        """Create image encoder and return based on parameters.

        Raises:
            ValueError: If self.image_encoder_type is not a supported type.
        """
        if self.image_encoder_type == "skip_conv":
            # Use kernel_size = 1 only if stride_out is 1 (SkipConvBackbone
            # rejects larger kernels at stride 1).
            return SkipConvBackbone(
                image_encoder_params.dim_in,
                image_encoder_params.dim_out,
                kernel_size=3 if stride_out != 1 else 1,
                stride_out=stride_out,
            )
        elif self.image_encoder_type == "skip_conv_kernel2":
            # Kernel size tied to stride: non-overlapping patches.
            return SkipConvBackbone(
                image_encoder_params.dim_in,
                image_encoder_params.dim_out,
                kernel_size=stride_out,
                stride_out=stride_out,
            )
        else:
            raise ValueError(f"Unsupported image encoder type: {self.image_encoder_type}")

    def forward(self, input_features: torch.Tensor, encodings: list[torch.Tensor]) -> ImageFeatures:
        """Run monodepth and fuse features with input image to predict Gaussians.

        Args:
            input_features: The input features to use.
            encodings: Feature encodings (e.g. from monodepth network).
        """
        features = self.decoder(encodings).contiguous()
        features = self.upsample(features)
        if self.use_depth_input:
            skip_features = self.image_encoder(input_features).texture_features
        else:
            # BUG FIX: the encoder returns an ImageFeatures tuple; previously the
            # raw tuple (not a tensor) was passed to self.fusion. Select
            # .texture_features to mirror the branch above.
            # NOTE(review): slicing to the first 3 channels assumes an
            # RGB(+extra) channel layout — confirm against callers, since
            # __init__ computes the encoder width as dim_in - 1, not 3.
            skip_features = self.image_encoder(
                input_features[:, :3].contiguous()
            ).texture_features
        features = self.fusion(features, skip_features)
        texture_features = self.texture_head(features)
        geometry_features = self.geometry_head(features)
        return ImageFeatures(
            texture_features=texture_features,
            geometry_features=geometry_features,
        )

    @property
    def stride(self) -> int:
        """Internal stride of GaussianDensePredictionTransformer."""
        return self.stride_out
|