File size: 4,108 Bytes
c20d7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Contains multi-res convolutional decoder.

Implements the decoder for Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

from typing import Iterable

import torch
import torch.nn as nn

from sharp.models.blocks import FeatureFusionBlock2d, UpsamplingMode
from sharp.utils.training import checkpoint_wrapper

from .base_decoder import BaseDecoder


class MultiresConvDecoder(BaseDecoder):
    """Decoder for multi-resolution encodings."""

    def __init__(
        self,
        dims_encoder: Iterable[int],
        dims_decoder: Iterable[int] | int,
        grad_checkpointing: bool = False,
        upsampling_mode: UpsamplingMode = "transposed_conv",
    ):
        """Initialize multiresolution convolutional decoder.

        Args:
            dims_encoder: Expected dims at each level from the encoder.
            dims_decoder: Dim of decoder features. A single int is broadcast
                to every encoder level.
            grad_checkpointing: Whether to checkpoint gradient during training.
            upsampling_mode: What method to use for upsampling.

        Raises:
            ValueError: If dims_encoder and dims_decoder have different lengths.
        """
        super().__init__()
        self.dims_encoder = list(dims_encoder)

        # Normalize to one decoder dim per encoder level.
        self.dims_decoder = (
            [dims_decoder] * len(self.dims_encoder)
            if isinstance(dims_decoder, int)
            else list(dims_decoder)
        )

        if len(self.dims_decoder) != len(self.dims_encoder):
            raise ValueError("Received dims_encoder and dims_decoder of different sizes.")

        # Output dim of the decoder is the highest-resolution (level-0) dim.
        self.dim_out = self.dims_decoder[0]

        # Project each encoder level to its decoder dim. Level 0 uses a 1x1
        # projection only when the dims mismatch (identity otherwise, which is
        # the default behavior of monodepth); deeper levels use a 3x3 conv.
        projections: list[nn.Module] = []
        for level, (dim_enc, dim_dec) in enumerate(zip(self.dims_encoder, self.dims_decoder)):
            if level == 0:
                projection: nn.Module = (
                    nn.Conv2d(dim_enc, dim_dec, kernel_size=1, bias=False)
                    if dim_enc != dim_dec
                    else nn.Identity()
                )
            else:
                projection = nn.Conv2d(
                    dim_enc,
                    dim_dec,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                )
            projections.append(projection)
        self.convs = nn.ModuleList(projections)

        # One fusion block per level. Each block maps its level's dim to the
        # next-finer level's dim; level 0 emits dim_out and does not upsample.
        self.fusions = nn.ModuleList(
            FeatureFusionBlock2d(
                dim_in=self.dims_decoder[level],
                dim_out=self.dim_out if level == 0 else self.dims_decoder[level - 1],
                upsampling_mode=None if level == 0 else upsampling_mode,
                batch_norm=False,
            )
            for level in range(len(self.dims_encoder))
        )

        self.grad_checkpointing = grad_checkpointing

    @torch.jit.ignore
    def set_grad_checkpointing(self, is_enabled=True):
        """Enable grad checkpointing."""
        self.grad_checkpointing = is_enabled

    def forward(self, encodings: list[torch.Tensor]) -> torch.Tensor:
        """Decode the multi-resolution encodings."""
        num_levels = len(encodings)
        num_encoders = len(self.dims_encoder)

        if num_levels != num_encoders:
            raise ValueError(
                f"Encoder output levels={num_levels} at runtime "
                f"mismatch with expected levels={num_encoders}."
            )

        # Start from the coarsest level (last entry) and fuse towards the
        # finest (level 0), projecting each encoding to the decoder dim first.
        fused = self.convs[-1](encodings[-1])
        fused = checkpoint_wrapper(self, self.fusions[-1], fused)
        for level in reversed(range(num_levels - 1)):
            skip = self.convs[level](encodings[level])
            fused = checkpoint_wrapper(self, self.fusions[level], fused, skip)
        return fused