| """Contains factory function for loading/creating monodepth decoder. | |
| For licensing see accompanying LICENSE file. | |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. | |
| """ | |
| from __future__ import annotations | |
| from sharp.models.presets import ( | |
| MONODEPTH_ENCODER_DIMS_MAP, | |
| ViTPreset, | |
| ) | |
| from .multires_conv_decoder import MultiresConvDecoder | |
def create_monodepth_decoder(
    patch_encoder_preset: ViTPreset,
    dims_decoder: int | list[int] | tuple[int, ...] | None = None,
) -> MultiresConvDecoder:
    """Create the multi-resolution convolutional decoder for monodepth.

    Args:
        patch_encoder_preset: The preset patch encoder architecture in SPN. It
            is used as the key into ``MONODEPTH_ENCODER_DIMS_MAP`` to look up
            the encoder feature dimensions.
        dims_decoder: The decoder dimension(s). A single int is wrapped in a
            one-element list; a list/tuple is copied into a new list. When
            ``None`` (the default), the first encoder dimension is used.

    Returns:
        A ``MultiresConvDecoder`` built with ``dims_encoder`` set to the first
        decoder dimension followed by the preset's encoder dimensions, and
        ``dims_decoder`` set to the normalized decoder dimension list.

    Raises:
        KeyError: If ``patch_encoder_preset`` is not a key of
            ``MONODEPTH_ENCODER_DIMS_MAP``.
        ValueError: If ``dims_decoder`` is an empty sequence.
    """
    dims_encoder = list(MONODEPTH_ENCODER_DIMS_MAP[patch_encoder_preset])

    if dims_decoder is None:
        dims_decoder = dims_encoder[0]

    if isinstance(dims_decoder, int):
        dims_decoder = [dims_decoder]
    else:
        # Copy so tuples are handled like lists and the caller's sequence is
        # not shared with the constructed decoder.
        dims_decoder = list(dims_decoder)

    if not dims_decoder:
        # Without this check, dims_decoder[0] below would fail with an
        # unhelpful IndexError.
        raise ValueError("dims_decoder must contain at least one dimension.")

    # An extra leading input dimension equal to the first decoder dimension is
    # prepended to the encoder dimensions.
    # NOTE(review): presumably this matches a feature map the decoder receives
    # at decoder width; confirm against MultiresConvDecoder.
    return MultiresConvDecoder(
        dims_encoder=[dims_decoder[0], *dims_encoder],
        dims_decoder=dims_decoder,
    )