File size: 4,261 Bytes
c20d7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""Contains Dense Transformer Prediction architecture.

Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

import torch
import torch.nn as nn

from sharp.models.presets import (
    MONODEPTH_ENCODER_DIMS_MAP,
    MONODEPTH_HOOK_IDS_MAP,
    ViTPreset,
)

from .base_encoder import BaseEncoder
from .spn_encoder import SlidingPyramidNetwork
from .vit_encoder import create_vit


def create_monodepth_encoder(
    patch_encoder_preset: ViTPreset,
    image_encoder_preset: ViTPreset,
    use_patch_overlap: bool = True,
    last_encoder: int = 256,
) -> SlidingPyramidNetwork:
    """Creates DepthDensePredictionTransformer model.

    Args:
        patch_encoder_preset: The preset patch encoder architecture in SPN.
        image_encoder_preset: The preset image encoder architecture in SPN.
        use_patch_overlap: Whether to use overlap between patches in SPN.
        last_encoder: last number of encoder features.
    """
    # Encoder feature dims: the extra `last_encoder` entry is prepended to the
    # preset-defined dims of the patch encoder.
    encoder_dims = [last_encoder, *MONODEPTH_ENCODER_DIMS_MAP[patch_encoder_preset]]
    hook_ids = MONODEPTH_HOOK_IDS_MAP[patch_encoder_preset]

    # The patch encoder always outputs intermediate features for assembly;
    # the image encoder only provides its final features.
    patch_encoder = create_vit(
        preset=patch_encoder_preset,
        intermediate_features_ids=hook_ids,
    )
    image_encoder = create_vit(
        preset=image_encoder_preset,
        intermediate_features_ids=None,
    )

    return SlidingPyramidNetwork(
        dims_encoder=encoder_dims,
        patch_encoder=patch_encoder,
        image_encoder=image_encoder,
        use_patch_overlap=use_patch_overlap,
    )


class ProjectionModule(nn.Module):
    """Apply projection of features.

    Holds one 1x1 convolution per feature level and applies each to the
    corresponding entry of the input list.
    """

    def __init__(self, dims_in: list[int], dims_out: list[int]) -> None:
        """Initialize projection module."""
        super().__init__()
        if len(dims_in) != len(dims_out):
            raise ValueError("Length of dims_in must be same as length of dims_out.")
        projections = []
        for channels_in, channels_out in zip(dims_in, dims_out):
            projections.append(nn.Conv2d(channels_in, channels_out, 1))
        self.convs = nn.ModuleList(projections)

    def forward(self, encodings: list[torch.Tensor]) -> list[torch.Tensor]:
        """Apply projection module."""
        if len(encodings) != len(self.convs):
            raise ValueError("Number of encodings must be equal to number of projections.")
        projected = []
        for projection, feature_map in zip(self.convs, encodings):
            projected.append(projection(feature_map))
        return projected


class MonodepthFeatureEncoder(BaseEncoder):
    """A wrapper around monodepth network to extract features."""

    def __init__(
        self,
        monodepth_encoder: SlidingPyramidNetwork,
        output_dims: list[int] | None = None,
        freeze_projection: bool = False,
    ) -> None:
        """Initialize MonodepthFeatureExtractor.

        Args:
            monodepth_encoder: SPN backbone whose multi-resolution features
                are exposed.
            output_dims: Optional per-level output channel counts. When set,
                a 1x1-conv projection maps each feature level to the given
                dimension; the length must match the encoder's
                ``dims_encoder``. When None, features pass through unchanged.
            freeze_projection: If True, exclude the projection parameters
                from gradient updates.

        Raises:
            ValueError: If ``output_dims`` is set and its length differs from
                the number of encoder feature levels.
        """
        super().__init__()

        self.encoder = monodepth_encoder

        # NOTE(review): an earlier revision's comment claimed the monodepth
        # network returns two feature maps for the first entry of
        # dims_encoder, and a redundant self-assignment suggested removed
        # adjustment code; the current code uses dims_encoder as-is.
        monodepth_dims = self.encoder.dims_encoder

        if output_dims is not None:
            if len(output_dims) != len(monodepth_dims):
                raise ValueError(
                    "When set, number of output dimensions must be equal to output "
                    f"dimensions of monodepth model {len(monodepth_dims)}."
                )

            self.projection = ProjectionModule(monodepth_dims, output_dims)
            self.output_dims = output_dims
        else:
            # No projection requested: forward features unchanged.
            self.projection = nn.Identity()
            self.output_dims = monodepth_dims

        if freeze_projection:
            self.projection.requires_grad_(False)

    def forward(self, input_features: torch.Tensor) -> list[torch.Tensor]:
        """Extract multi-resolution features.

        Only the first three channels of ``input_features`` are fed to the
        encoder; extra channels (if any) are ignored.
        """
        encodings = self.encoder(input_features[:, :3].contiguous())
        return self.projection(encodings)

    def internal_resolution(self) -> int:
        """Internal resolution of the encoder."""
        return self.encoder.internal_resolution()