File size: 1,022 Bytes
c20d7cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
"""Contains factory function for loading/creating monodepth decoder.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
from sharp.models.presets import (
MONODEPTH_ENCODER_DIMS_MAP,
ViTPreset,
)
from .multires_conv_decoder import MultiresConvDecoder
def create_monodepth_decoder(
    patch_encoder_preset: ViTPreset,
    dims_decoder=None,
) -> MultiresConvDecoder:
    """Build the ``MultiresConvDecoder`` used for monodepth prediction.

    The encoder feature dimensions are looked up in
    ``MONODEPTH_ENCODER_DIMS_MAP`` using ``patch_encoder_preset``; the
    decoder is then constructed from those dimensions and ``dims_decoder``.

    Args:
        patch_encoder_preset: Preset of the patch encoder, used as the key
            into ``MONODEPTH_ENCODER_DIMS_MAP``.
        dims_decoder: Decoder dimensions. ``None`` falls back to the first
            encoder dimension; a single ``int`` is wrapped into a
            one-element list; any other value is passed through as given
            and must support indexing with ``[0]``.

    Returns:
        The newly constructed ``MultiresConvDecoder``.
    """
    encoder_dims = MONODEPTH_ENCODER_DIMS_MAP[patch_encoder_preset]

    # Resolve the decoder dimensions: default to the first encoder dimension.
    decoder_dims = encoder_dims[0] if dims_decoder is None else dims_decoder

    # Normalise a scalar into a sequence so it can be indexed below.
    if isinstance(decoder_dims, int):
        decoder_dims = [decoder_dims]

    # The decoder receives the encoder dims prefixed with the first decoder
    # dim (presumably an extra input level; confirm in MultiresConvDecoder).
    leading_dim = decoder_dims[0]
    return MultiresConvDecoder(
        dims_encoder=[leading_dim, *encoder_dims],
        dims_decoder=decoder_dims,
    )
|