|
|
"""Contains Dense Transformer Prediction architecture. |
|
|
|
|
|
Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413 |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import copy |
|
|
from typing import NamedTuple, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from sharp.models import normalizers |
|
|
from sharp.models.decoders import MultiresConvDecoder, create_monodepth_decoder |
|
|
from sharp.models.encoders import ( |
|
|
SlidingPyramidNetwork, |
|
|
create_monodepth_encoder, |
|
|
) |
|
|
from sharp.utils import module_surgery |
|
|
|
|
|
from .params import MonodepthAdaptorParams, MonodepthParams |
|
|
|
|
|
DimsDecoder = Tuple[int, int, int, int, int] |
|
|
|
|
|
|
|
|
class MonodepthDensePredictionTransformer(nn.Module):
    """Dense Prediction Transformer for monodepth.

    Attach the disparity prediction head for monodepth prediction.
    """

    def __init__(
        self,
        encoder: SlidingPyramidNetwork,
        decoder: MultiresConvDecoder,
        last_dims: tuple[int, int],
    ):
        """Initialize Dense Prediction Transformer.

        Args:
            encoder: The SlidingPyramidTransformer backbone.
            decoder: The MultiresConvDecoder decoder.
            last_dims: The dimension for the last convolution layers.
        """
        super().__init__()

        # Map inputs from [0, 1] to the [-1, 1] range the backbone expects.
        self.normalizer = normalizers.AffineRangeNormalizer(
            input_range=(0, 1), output_range=(-1, 1)
        )
        self.encoder = encoder
        self.decoder = decoder

        dim_fused = decoder.dim_out
        dim_half = dim_fused // 2
        # Prediction head: 3x3 conv -> learned 2x upsampling -> 3x3 conv + ReLU
        # -> 1x1 projection to the output channels + ReLU.
        self.head = nn.Sequential(
            nn.Conv2d(dim_fused, dim_half, kernel_size=3, stride=1, padding=1),
            nn.ConvTranspose2d(
                in_channels=dim_half,
                out_channels=dim_half,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
            ),
            nn.Conv2d(
                dim_half,
                last_dims[0],
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(True),
            nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
        )

        # Start the final projection (index 4 in the sequential head) with zero bias.
        self.head[4].bias.data.fill_(0)

        self.grad_checkpointing = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, is_enabled=True):
        """Enable grad checkpointing."""
        self.grad_checkpointing = is_enabled
        # Propagate the setting to both sub-networks.
        for submodule in (self.encoder, self.decoder):
            submodule.set_grad_checkpointing(self.grad_checkpointing)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Decode by projection and fusion of multi-resolution encodings."""
        multires = self.encoder(self.normalizer(image))
        # Only the first len(dims_encoder) maps feed the decoder; any extra
        # encoder outputs are ignored here.
        fused = self.decoder(multires[: len(self.encoder.dims_encoder)])
        return self.head(fused)

    def internal_resolution(self) -> int:
        """Return the internal image size of the network."""
        return self.encoder.internal_resolution()
|
|
|
|
|
|
|
|
def create_monodepth_dpt(
    params: MonodepthParams | None = None,
) -> MonodepthDensePredictionTransformer:
    """Creates DepthDensePredictionTransformer model.

    Args:
        params: Parameters of monodepth network.

    Returns:
        The configured monodepth DPT.
    """
    if params is None:
        params = MonodepthParams()

    model = MonodepthDensePredictionTransformer(
        encoder=create_monodepth_encoder(
            params.patch_encoder_preset,
            params.image_encoder_preset,
            use_patch_overlap=params.use_patch_overlap,
            last_encoder=params.dims_decoder[0],
        ),
        decoder=create_monodepth_decoder(params.patch_encoder_preset, params.dims_decoder),
        last_dims=(32, 1),
    )

    # Freeze everything first, then selectively re-enable gradients as configured.
    model.requires_grad_(False)
    model.encoder.set_requires_grad_(
        patch_encoder=params.unfreeze_patch_encoder,
        image_encoder=params.unfreeze_image_encoder,
    )
    model.decoder.requires_grad_(params.unfreeze_decoder)
    model.head.requires_grad_(params.unfreeze_head)

    if not params.unfreeze_norm_layers:
        module_surgery.freeze_norm_layer(model)

    model.set_grad_checkpointing(params.grad_checkpointing)

    return model
|
|
|
|
|
|
|
|
class MonodepthOutput(NamedTuple):
    """Output of the monodepth model."""

    # Predicted disparity; presumably (batch, num_layers, H, W) — TODO confirm.
    disparity: torch.Tensor

    # Multi-resolution encoder feature maps (the ones consumed by the decoder).
    encoder_features: list[torch.Tensor]

    # Fused decoder feature map; this is the input to the prediction head.
    decoder_features: torch.Tensor

    # Encoder and/or decoder features selected for downstream consumers.
    output_features: list[torch.Tensor]

    # Extra encoder outputs beyond those consumed by the decoder.
    # NOTE(review): mutable default on a NamedTuple field — the *same* list object
    # is shared by every instance constructed without this field; never mutate it
    # in place.
    intermediate_features: list[torch.Tensor] = []
|
|
|
|
|
|
|
|
class MonodepthWithEncodingAdaptor(nn.Module):
    """Monodepth model with feature maps.

    Wraps a MonodepthDensePredictionTransformer and, besides the disparity
    prediction, exposes the encoder/decoder feature maps requested at
    construction time.
    """

    def __init__(
        self,
        monodepth_predictor: MonodepthDensePredictionTransformer,
        return_encoder_features: bool,
        return_decoder_features: bool,
        num_monodepth_layers: int,
        sorting_monodepth: bool,
    ):
        """Initialize MonodepthWithEncodingAdaptor.

        Args:
            monodepth_predictor: The monodepth model.
            return_encoder_features: Whether to return encoder features from monodepth model.
            return_decoder_features: Whether to return decoder features from monodepth model.
            num_monodepth_layers: How many layers the monodepth model predicts.
            sorting_monodepth: Whether to sort the monodepth output (for two layer monodepth).
        """
        super().__init__()
        self.monodepth_predictor = monodepth_predictor
        self.return_encoder_features = return_encoder_features
        self.return_decoder_features = return_decoder_features
        self.num_monodepth_layers = num_monodepth_layers
        self.sorting_monodepth = sorting_monodepth

    def forward(self, image: torch.Tensor) -> MonodepthOutput:
        """Process image and return disparity and feature maps."""
        inputs = self.monodepth_predictor.normalizer(image)
        encoder_output = self.monodepth_predictor.encoder(inputs)

        num_encoder_features = len(self.monodepth_predictor.encoder.dims_encoder)

        # The encoder may emit more maps than the decoder consumes; the surplus
        # is exposed separately as intermediate features.
        encoder_features = encoder_output[:num_encoder_features]
        intermediate_features = encoder_output[num_encoder_features:]
        decoder_features = self.monodepth_predictor.decoder(encoder_features)
        disparity = self.monodepth_predictor.head(decoder_features)

        if self.num_monodepth_layers == 2 and self.sorting_monodepth:
            # Sort the two predicted layers per pixel so channel 0 always holds
            # the larger disparity. Use the canonical `keepdim` kwarg; the
            # original `keepdims` spelling is only a deprecated NumPy-compat alias.
            first_layer_disparity = disparity.max(dim=1, keepdim=True).values
            second_layer_disparity = disparity.min(dim=1, keepdim=True).values
            disparity = torch.cat([first_layer_disparity, second_layer_disparity], dim=1)

        output_features = []
        if self.return_encoder_features:
            output_features.extend(encoder_features)

        if self.return_decoder_features:
            output_features.append(decoder_features)

        return MonodepthOutput(
            disparity=disparity,
            encoder_features=encoder_features,
            decoder_features=decoder_features,
            output_features=output_features,
            intermediate_features=intermediate_features,
        )

    def get_feature_dims(self) -> list[int]:
        """Return dimensions of output feature maps.

        The order matches `output_features` in `forward`: encoder dims first
        (if requested), then the decoder output dim (if requested).
        """
        dims = []
        if self.return_encoder_features:
            dims.extend(self.monodepth_predictor.encoder.dims_encoder)

        if self.return_decoder_features:
            dims.append(self.monodepth_predictor.decoder.dim_out)

        return dims

    def internal_resolution(self) -> int:
        """Return the internal image size of the network."""
        return self.monodepth_predictor.internal_resolution()

    def replicate_head(self, num_repeat: int):
        """Replicate the last convolution layer (head[4] in DPT) for multi layer depth.

        Turns the pretrained single-channel disparity projection into a
        `num_repeat`-channel one by tiling its weights and bias along the
        output-channel dimension, so each layer starts from the same weights.
        """
        conv_last = copy.deepcopy(self.monodepth_predictor.head[4])
        self.monodepth_predictor.head[4].out_channels = num_repeat
        self.monodepth_predictor.head[4].weight = nn.Parameter(
            conv_last.weight.repeat(num_repeat, 1, 1, 1)
        )
        self.monodepth_predictor.head[4].bias = nn.Parameter(conv_last.bias.repeat(num_repeat))
|
|
|
|
|
|
|
|
def create_monodepth_adaptor(
    monodepth_predictor: MonodepthDensePredictionTransformer,
    params: MonodepthAdaptorParams,
    num_monodepth_layers: int,
    sorting_monodepth: bool,
) -> MonodepthWithEncodingAdaptor:
    """Create an adaptor that returns both disparity and features."""
    # Thin factory: the params object selects which feature maps are exposed.
    return MonodepthWithEncodingAdaptor(
        monodepth_predictor=monodepth_predictor,
        return_encoder_features=params.encoder_features,
        return_decoder_features=params.decoder_features,
        num_monodepth_layers=num_monodepth_layers,
        sorting_monodepth=sorting_monodepth,
    )
|
|
|