File size: 3,591 Bytes
c20d7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Contains backbone models for feature extraction from RGBD input.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

from typing import List

import torch
from torch import nn

from sharp.models.blocks import (
    NormLayerName,
    norm_layer_2d,
    residual_block_2d,
)

from .base_encoder import BaseEncoder


class UNetEncoder(BaseEncoder):
    """Encoder of UNet model.

    Applies an input convolution, then ``steps`` downsampling stages. Each
    stage halves the spatial resolution via average pooling and applies
    ``blocks_per_layer`` residual blocks. ``forward`` returns the feature
    maps of all ``steps + 1`` levels, finest first.
    """

    def __init__(
        self,
        dim_in: int,
        width: list[int] | int,
        steps: int = 6,
        norm_type: NormLayerName = "group_norm",
        norm_num_groups: int = 8,
        blocks_per_layer: int = 2,
    ) -> None:
        """Initialize UNet Encoder.

        Args:
            dim_in: The number of input channels.
            width: Width of the first layer (the channel count then doubles at
                each downsampling step), or an explicit per-level width list
                of length ``steps + 1``.
            steps: The number of downsampling steps.
            norm_type: Which kind of normalization layer to use.
            norm_num_groups: How many groups to use for group norm (if relevant).
            blocks_per_layer: How many residual blocks per layer to use.

        Raises:
            ValueError: If ``blocks_per_layer`` is less than one, or if
                ``width`` is a list whose length is not ``steps + 1``.
        """
        super().__init__()

        if blocks_per_layer < 1:
            raise ValueError("blocks_per_layer must be greater or equal to one.")

        self.dim_in = dim_in
        self.width = width
        self.num_steps = steps

        self.convs_down = nn.ModuleList()

        self.output_dims: list[int]
        # If only one number is specified, we assume each layer will double the channel dimension.
        if isinstance(width, int):
            self.output_dims = [width << i for i in range(0, steps + 1)]
        else:
            if len(width) != (steps + 1):
                raise ValueError("Length of width should match the steps for UNetEncoder.")
            # Copy defensively: storing the caller's list by reference would let
            # later external mutation desynchronize the recorded layer widths.
            self.output_dims = list(width)

        # Stem: bring the input up to the first level's width at full resolution.
        self.conv_in = nn.Sequential(
            nn.Conv2d(self.dim_in, self.output_dims[0], 3, stride=1, padding=1),
            norm_layer_2d(self.output_dims[0], norm_type, num_groups=norm_num_groups),
            nn.ReLU(),
        )

        for i_step in range(steps):
            input_width = self.output_dims[i_step]
            current_width = self.output_dims[i_step + 1]
            # Each stage: 2x downsample, one width-changing residual block,
            # then (blocks_per_layer - 1) width-preserving residual blocks.
            convs_down_i = nn.Sequential(
                nn.AvgPool2d(2, stride=2),
                residual_block_2d(
                    input_width,
                    current_width,
                    norm_type=norm_type,
                    norm_num_groups=norm_num_groups,
                ),
                *[
                    residual_block_2d(
                        current_width,
                        current_width,
                        norm_type=norm_type,
                        norm_num_groups=norm_num_groups,
                    )
                    for _ in range(blocks_per_layer - 1)
                ],
            )
            self.convs_down.append(convs_down_i)

    def forward(self, input: torch.Tensor) -> list[torch.Tensor]:
        """Apply UNet Encoder to image.

        Args:
            input: The input image.

        Returns:
            The output multi-level feature map from encoder, finest level
            first (``steps + 1`` tensors in total).
        """
        features = []

        # Full-resolution features from the stem, then one entry per stage.
        feat_i = self.conv_in(input)
        features.append(feat_i)

        for conv_down in self.convs_down:
            feat_i = conv_down(feat_i)
            features.append(feat_i)

        return features

    @property
    def out_width(self) -> int:
        """Channel width of the deepest encoder output (consumed by the decoder)."""
        return self.output_dims[-1]