ml-sharp / src /sharp /models /decoders /monodepth_decoder.py
amael-apple's picture
Initial commit
c20d7cc
"""Contains factory function for loading/creating monodepth decoder.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
from sharp.models.presets import (
MONODEPTH_ENCODER_DIMS_MAP,
ViTPreset,
)
from .multires_conv_decoder import MultiresConvDecoder
def create_monodepth_decoder(
    patch_encoder_preset: ViTPreset,
    dims_decoder=None,
) -> MultiresConvDecoder:
    """Build the multi-resolution convolutional decoder used for monodepth.

    Args:
        patch_encoder_preset: Preset used as the key into
            ``MONODEPTH_ENCODER_DIMS_MAP`` to obtain the encoder dimensions.
        dims_decoder: Decoder dimensions. A single ``int`` is wrapped into a
            one-element list; any other value is used as given and is
            indexed at position 0. When ``None``, the first encoder
            dimension from the preset is used instead.

    Returns:
        A ``MultiresConvDecoder`` whose ``dims_encoder`` is the first decoder
        dimension followed by the preset's encoder dimensions, and whose
        ``dims_decoder`` is the resolved decoder dimensions.
    """
    encoder_dims = MONODEPTH_ENCODER_DIMS_MAP[patch_encoder_preset]

    # Fall back to the first encoder dimension when the caller gave nothing.
    decoder_dims = encoder_dims[0] if dims_decoder is None else dims_decoder

    # Normalise a scalar width into list form so it can be indexed below.
    if isinstance(decoder_dims, int):
        decoder_dims = [decoder_dims]

    # The decoder's first dimension is prepended to the encoder dimensions.
    return MultiresConvDecoder(
        dims_encoder=[decoder_dims[0], *encoder_dims],
        dims_decoder=decoder_dims,
    )