| """Contains backbone models for feature extraction from RGBD input. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import List |
|
|
| import torch |
| from torch import nn |
|
|
| from sharp.models.blocks import ( |
| NormLayerName, |
| norm_layer_2d, |
| residual_block_2d, |
| ) |
|
|
| from .base_encoder import BaseEncoder |
|
|
|
|
class UNetEncoder(BaseEncoder):
    """Encoder of UNet model.

    The encoder applies a stem convolution (conv + norm + ReLU) at the input
    resolution and then ``steps`` downsampling stages. Each stage average-pools
    with a 2x2 window (stride 2) and applies ``blocks_per_layer`` residual
    blocks. The feature map after the stem and after every stage is returned,
    so ``forward`` yields ``steps + 1`` tensors.
    """

    def __init__(
        self,
        dim_in: int,
        width: List[int] | int,
        steps: int = 6,
        norm_type: NormLayerName = "group_norm",
        norm_num_groups: int = 8,
        blocks_per_layer: int = 2,
    ) -> None:
        """Initialize UNet Encoder.

        Args:
            dim_in: The number of input channels.
            width: Either a base width (int), in which case level ``i`` uses
                ``width << i`` channels, or an explicit list with one channel
                count per level (``steps + 1`` entries in total).
            steps: The number of downsampling steps.
            norm_type: Which kind of normalization layer to use.
            norm_num_groups: How many groups to use for group norm (if relevant).
            blocks_per_layer: How many residual blocks per layer to use.

        Raises:
            ValueError: If ``blocks_per_layer`` is smaller than one, if ``steps``
                is negative, or if ``width`` is a sequence whose length is not
                ``steps + 1``.
        """
        super().__init__()

        if blocks_per_layer < 1:
            raise ValueError("blocks_per_layer must be greater or equal to one.")
        # Without this guard a negative value surfaces later as an opaque
        # IndexError when the stem reads output_dims[0].
        if steps < 0:
            raise ValueError("steps must be greater or equal to zero.")

        self.dim_in = dim_in
        self.width = width
        self.num_steps = steps

        # Registered before conv_in to keep the submodule (and state-dict key)
        # order of the original implementation.
        self.convs_down = nn.ModuleList()

        # Channel count of every returned feature level, finest level first.
        self.output_dims: list[int]
        if isinstance(width, int):
            # The channel count doubles at each level.
            self.output_dims = [width << level for level in range(steps + 1)]
        else:
            expected_levels = steps + 1
            if len(width) != expected_levels:
                raise ValueError(
                    "Length of width should match the steps for UNetEncoder. "
                    f"Expected {expected_levels} entries (steps + 1), got {len(width)}."
                )
            # Copy so that later mutation of the caller's list cannot silently
            # change the widths this encoder reports (also accepts tuples).
            self.output_dims = list(width)

        stem_width = self.output_dims[0]
        self.conv_in = nn.Sequential(
            nn.Conv2d(self.dim_in, stem_width, 3, stride=1, padding=1),
            norm_layer_2d(stem_width, norm_type, num_groups=norm_num_groups),
            nn.ReLU(),
        )

        # One stage per pair of consecutive levels: pool, change the channel
        # count in the first residual block, then refine at constant width.
        for input_width, current_width in zip(self.output_dims[:-1], self.output_dims[1:]):
            stage_layers: list[nn.Module] = [
                nn.AvgPool2d(2, stride=2),
                residual_block_2d(
                    input_width,
                    current_width,
                    norm_type=norm_type,
                    norm_num_groups=norm_num_groups,
                ),
            ]
            for _ in range(blocks_per_layer - 1):
                stage_layers.append(
                    residual_block_2d(
                        current_width,
                        current_width,
                        norm_type=norm_type,
                        norm_num_groups=norm_num_groups,
                    )
                )
            self.convs_down.append(nn.Sequential(*stage_layers))

    def forward(self, input: torch.Tensor) -> list[torch.Tensor]:
        """Apply UNet Encoder to image.

        Args:
            input: The input image.

        Returns:
            The output multi-level feature map from encoder: the stem output
            followed by the output of each downsampling stage, i.e.
            ``len(self.convs_down) + 1`` tensors.
        """
        feat_i = self.conv_in(input)
        features: list[torch.Tensor] = [feat_i]

        # Each stage consumes the previous level's features.
        for conv_down in self.convs_down:
            feat_i = conv_down(feat_i)
            features.append(feat_i)

        return features

    @property
    def out_width(self) -> int:
        """Return the channel count of the last (coarsest) feature level."""
        return self.output_dims[-1]
|
|