# ml-sharp / src/sharp/models/gaussian_decoder.py
# (HuggingFace raw-page metadata removed: amael-apple, "Initial commit", c20d7cc, 9.2 kB)
"""Contains Dense Transformer Prediction architecture.
Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
from typing import NamedTuple
import torch
import torch.nn as nn
from sharp.models.blocks import (
FeatureFusionBlock2d,
NormLayerName,
residual_block_2d,
)
from sharp.models.decoders import BaseDecoder, MultiresConvDecoder
from sharp.models.params import DPTImageEncoderType, GaussianDecoderParams
def create_gaussian_decoder(
    params: GaussianDecoderParams, dims_depth_features: list[int]
) -> GaussianDensePredictionTransformer:
    """Build a ``GaussianDensePredictionTransformer`` from configuration.

    Args:
        params: Configuration for the decoder, image encoder, and heads.
        dims_depth_features: Channel dimensions of the multi-resolution
            depth features consumed by the convolutional decoder.

    Returns:
        The assembled Gaussian dense-prediction model.
    """
    multires_decoder = MultiresConvDecoder(
        dims_depth_features,
        params.dims_decoder,
        grad_checkpointing=params.grad_checkpointing,
        upsampling_mode=params.upsampling_mode,
    )
    # The full params object doubles as the image-encoder configuration.
    return GaussianDensePredictionTransformer(
        decoder=multires_decoder,
        dim_in=params.dim_in,
        dim_out=params.dim_out,
        stride_out=params.stride,
        norm_type=params.norm_type,
        norm_num_groups=params.norm_num_groups,
        use_depth_input=params.use_depth_input,
        grad_checkpointing=params.grad_checkpointing,
        image_encoder_type=params.image_encoder_type,
        image_encoder_params=params,
    )
def _create_project_upsample_block(
    dim_in: int,
    dim_out: int,
    upsample_layers: int,
    dim_intermediate: int | None = None,
) -> nn.Module:
    """Project channels with a 1x1 conv, then upsample by ``2 ** upsample_layers``.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels of each upsampling layer.
        upsample_layers: Number of stride-2 transposed convolutions; each
            doubles the spatial resolution.
        dim_intermediate: Channel count after the initial projection.
            Defaults to ``dim_out`` when not given.

    Returns:
        An ``nn.Sequential`` of the projection followed by the upsampling stack.
    """
    if dim_intermediate is None:
        dim_intermediate = dim_out
    # 1x1 channel projection; spatial size is unchanged.
    layers: list[nn.Module] = [
        nn.Conv2d(
            in_channels=dim_in,
            out_channels=dim_intermediate,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
    ]
    # Each kernel-2 / stride-2 transposed conv doubles height and width.
    for index in range(upsample_layers):
        layers.append(
            nn.ConvTranspose2d(
                in_channels=dim_intermediate if index == 0 else dim_out,
                out_channels=dim_out,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=False,
            )
        )
    return nn.Sequential(*layers)
class ImageFeatures(NamedTuple):
    """Image feature extracted from decoder."""

    # Feature map for the texture (appearance) branch.
    texture_features: torch.Tensor
    # Feature map for the geometry branch. Some producers (e.g. the
    # skip-conv backbone) return the same tensor for both fields.
    geometry_features: torch.Tensor
class SkipConvBackbone(nn.Module):
    """Single-convolution image encoder exposing a BaseBackbone-like interface."""

    def __init__(self, dim_in: int, dim_out: int, kernel_size: int, stride_out: int):
        """Initialize SkipConvBackbone.

        Args:
            dim_in: Number of input channels.
            dim_out: Number of output channels.
            kernel_size: Convolution kernel size; must be 1 when ``stride_out`` is 1.
            stride_out: Convolution stride, i.e. the downsampling factor.
        """
        super().__init__()
        self.stride_out = stride_out
        if kernel_size != 1 and stride_out == 1:
            raise ValueError("We only support kernel_size = 1 if stride_out is 1.")
        # "Same"-style padding for odd kernels (and zero for kernel_size 1 or 2).
        self.conv = nn.Conv2d(
            dim_in,
            dim_out,
            kernel_size=kernel_size,
            stride=stride_out,
            padding=(kernel_size - 1) // 2,
        )

    def forward(
        self,
        input_features: torch.Tensor,
        encodings: list[torch.Tensor] | None = None,
    ) -> ImageFeatures:
        """Encode the image; ``encodings`` is accepted for interface parity and unused."""
        projected = self.conv(input_features)
        # A single feature map serves both the texture and geometry branches.
        return ImageFeatures(texture_features=projected, geometry_features=projected)

    @property
    def stride(self) -> int:
        """Effective downsampling stride."""
        return self.stride_out
class GaussianDensePredictionTransformer(nn.Module):
    """Dense Prediction Transformer for Gaussian.

    Reuses monodepth decoded features, fuses them with a skip encoding of the
    input image, and predicts texture and geometry feature maps.
    """

    norm_type: NormLayerName

    def __init__(
        self,
        decoder: BaseDecoder,
        dim_in: int,
        dim_out: int,
        stride_out: int,
        image_encoder_params: GaussianDecoderParams,
        image_encoder_type: DPTImageEncoderType = "skip_conv",
        norm_type: NormLayerName = "group_norm",
        norm_num_groups: int = 8,
        use_depth_input: bool = True,
        grad_checkpointing: bool = False,
    ):
        """Initialize Dense Prediction Transformer for Gaussian.

        Args:
            decoder: Decoder to decode features.
            dim_in: Input dimension.
            dim_out: Final output dimension of each head.
            stride_out: Stride of output feature map; must be 1 or 2.
            image_encoder_params: The backbone parameters to configure the image
                encoder. NOTE: ``dim_in``/``dim_out`` on this object are
                overwritten in place below.
            image_encoder_type: Type of image encoder to use.
            norm_type: Type of norm layers.
            norm_num_groups: Num groups for norm layers.
            use_depth_input: Whether to use depth input.
            grad_checkpointing: Whether to use gradient checkpointing.

        Raises:
            ValueError: If ``stride_out`` is not 1 or 2.
        """
        super().__init__()
        self.decoder = decoder
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.stride_out = stride_out
        self.norm_type = norm_type
        self.norm_num_groups = norm_num_groups
        self.use_depth_input = use_depth_input
        self.grad_checkpointing = grad_checkpointing
        self.image_encoder_type = image_encoder_type
        # Adopt an image encoder to lift dimension to monodepth feature and
        # resize to be the same resolution as the decoder output.
        # Without depth input, one channel (presumably depth; see forward) is dropped.
        dim_in = self.dim_in if use_depth_input else self.dim_in - 1
        image_encoder_params.dim_in = dim_in
        image_encoder_params.dim_out = decoder.dim_out
        self.image_encoder = self._create_image_encoder(image_encoder_params, stride_out)
        self.fusion = FeatureFusionBlock2d(decoder.dim_out)
        if stride_out == 1:
            # One transposed-conv layer doubles the decoder output resolution.
            self.upsample = _create_project_upsample_block(
                decoder.dim_out,
                decoder.dim_out,
                upsample_layers=1,
            )
        elif stride_out == 2:
            self.upsample = nn.Identity()
        else:
            raise ValueError("We only support stride is 1 or 2 for DPT backbone.")
        self.texture_head = self._create_head(dim_decoder=decoder.dim_out, dim_out=self.dim_out)
        self.geometry_head = self._create_head(dim_decoder=decoder.dim_out, dim_out=self.dim_out)

    def _create_head(self, dim_decoder: int, dim_out: int) -> nn.Module:
        """Build a prediction head: two residual blocks, then a 1x1 conv with ReLUs."""
        return nn.Sequential(
            residual_block_2d(
                dim_in=dim_decoder,
                dim_out=dim_decoder,
                dim_hidden=dim_decoder // 2,
                norm_type=self.norm_type,
                norm_num_groups=self.norm_num_groups,
            ),
            residual_block_2d(
                dim_in=dim_decoder,
                dim_hidden=dim_decoder // 2,
                dim_out=dim_decoder,
                norm_type=self.norm_type,
                norm_num_groups=self.norm_num_groups,
            ),
            nn.ReLU(),
            nn.Conv2d(dim_decoder, dim_out, kernel_size=1, stride=1),
            # Final ReLU means head outputs are non-negative.
            nn.ReLU(),
        )

    def _create_image_encoder(
        self, image_encoder_params: GaussianDecoderParams, stride_out: int
    ) -> nn.Module:
        """Create image encoder and return based on parameters.

        Raises:
            ValueError: If ``self.image_encoder_type`` is not supported.
        """
        if self.image_encoder_type == "skip_conv":
            # Use kernel_size = 1 only if stride_out is 1 (SkipConvBackbone requirement).
            return SkipConvBackbone(
                image_encoder_params.dim_in,
                image_encoder_params.dim_out,
                kernel_size=3 if stride_out != 1 else 1,
                stride_out=stride_out,
            )
        elif self.image_encoder_type == "skip_conv_kernel2":
            # Kernel equals stride, so patches are non-overlapping.
            return SkipConvBackbone(
                image_encoder_params.dim_in,
                image_encoder_params.dim_out,
                kernel_size=stride_out,
                stride_out=stride_out,
            )
        else:
            raise ValueError(f"Unsupported image encoder type: {self.image_encoder_type}")

    def forward(self, input_features: torch.Tensor, encodings: list[torch.Tensor]) -> ImageFeatures:
        """Run monodepth and fuse features with input image to predict Gaussians.

        Args:
            input_features: The input features to use.
            encodings: Feature encodings (e.g. from monodepth network).

        Returns:
            ImageFeatures with texture and geometry feature maps.
        """
        features = self.decoder(encodings).contiguous()
        features = self.upsample(features)
        if self.use_depth_input:
            skip_features = self.image_encoder(input_features).texture_features
        else:
            # Keep the first 3 channels only (presumably RGB without depth —
            # TODO confirm channel layout against callers).
            # BUGFIX: the image encoder returns an ImageFeatures tuple; the
            # tensor must be selected (as in the branch above) before fusion —
            # previously the NamedTuple itself was passed to self.fusion.
            skip_features = self.image_encoder(
                input_features[:, :3].contiguous()
            ).texture_features
        features = self.fusion(features, skip_features)
        texture_features = self.texture_head(features)
        geometry_features = self.geometry_head(features)
        return ImageFeatures(
            texture_features=texture_features,
            geometry_features=geometry_features,
        )

    @property
    def stride(self) -> int:
        """Internal stride of GaussianDensePredictionTransformer."""
        return self.stride_out