# ml-sharp/src/sharp/models/encoders/unet_encoder.py
# Initial commit (c20d7cc), author: amael-apple
"""Contains backbone models for feature extraction from RGBD input.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
from typing import List
import torch
from torch import nn
from sharp.models.blocks import (
NormLayerName,
norm_layer_2d,
residual_block_2d,
)
from .base_encoder import BaseEncoder
class UNetEncoder(BaseEncoder):
    """Encoder of UNet model.

    Produces a multi-scale feature pyramid: one feature map at input
    resolution from the stem, followed by one map per downsampling step,
    each at half the spatial size of the previous.
    """

    def __init__(
        self,
        dim_in: int,
        width: List[int] | int,
        steps: int = 6,
        norm_type: NormLayerName = "group_norm",
        norm_num_groups: int = 8,
        blocks_per_layer: int = 2,
    ) -> None:
        """Initialize UNet Encoder.

        Args:
            dim_in: The number of input channels.
            width: Width multiplicator of intermediate layers or the width
                list of all layers. If an int, level ``i`` gets
                ``width * 2**i`` channels; if a list, it must contain
                ``steps + 1`` entries (one per pyramid level).
            steps: The number of downsampling steps.
            norm_type: Which kind of normalization layer to use.
            norm_num_groups: How many groups to use for group norm (if relevant).
            blocks_per_layer: How many residual blocks per layer to use.

        Raises:
            ValueError: If ``blocks_per_layer`` is less than one, or if a
                ``width`` list does not have exactly ``steps + 1`` entries.
        """
        super().__init__()
        if blocks_per_layer < 1:
            raise ValueError("blocks_per_layer must be greater or equal to one.")
        self.dim_in = dim_in
        self.width = width
        self.num_steps = steps
        self.convs_down = nn.ModuleList()
        self.output_dims: list[int]
        # If only one number is specified, we assume each layer will double the channel dimension.
        if isinstance(width, int):
            self.output_dims = [width << i for i in range(0, steps + 1)]
        else:
            if len(width) != (steps + 1):
                raise ValueError("Length of width should match the steps for UNetEncoder.")
            # Copy so later mutation of the caller's list cannot corrupt the model.
            self.output_dims = list(width)
        # Stem: lift the input to the first feature width at full resolution.
        self.conv_in = nn.Sequential(
            nn.Conv2d(self.dim_in, self.output_dims[0], 3, stride=1, padding=1),
            norm_layer_2d(self.output_dims[0], norm_type, num_groups=norm_num_groups),
            nn.ReLU(),
        )
        for i_step in range(steps):
            input_width = self.output_dims[i_step]
            current_width = self.output_dims[i_step + 1]
            # Each step: halve spatial resolution, change the channel width
            # with one residual block, then apply (blocks_per_layer - 1)
            # width-preserving residual blocks.
            convs_down_i = nn.Sequential(
                nn.AvgPool2d(2, stride=2),
                residual_block_2d(
                    input_width,
                    current_width,
                    norm_type=norm_type,
                    norm_num_groups=norm_num_groups,
                ),
                *[
                    residual_block_2d(
                        current_width,
                        current_width,
                        norm_type=norm_type,
                        norm_num_groups=norm_num_groups,
                    )
                    for _ in range(blocks_per_layer - 1)
                ],
            )
            self.convs_down.append(convs_down_i)

    def forward(self, input: torch.Tensor) -> list[torch.Tensor]:
        """Apply UNet Encoder to image.

        Args:
            input: The input image batch; presumably (batch, dim_in, H, W)
                with H and W divisible by 2**num_steps so the pooled sizes
                stay exact — TODO confirm against callers.

        Returns:
            The output multi-level feature maps from the encoder, ordered
            from full resolution to coarsest (``num_steps + 1`` tensors).
        """
        features = []
        feat_i = self.conv_in(input)
        features.append(feat_i)
        for conv_down in self.convs_down:
            feat_i = conv_down(feat_i)
            features.append(feat_i)
        return features

    @property
    def out_width(self) -> int:
        """Compute the output width for UNet decoder (coarsest level's channels)."""
        return self.output_dims[-1]