File size: 3,645 Bytes
c20d7cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
"""Contains the UNet decoder.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
from typing import List
import torch
from torch import nn
from sharp.models.blocks import (
NormLayerName,
norm_layer_2d,
residual_block_2d,
)
from .base_decoder import BaseDecoder
class UNetDecoder(BaseDecoder):
    """Decoder of UNet model."""

    def __init__(
        self,
        dim_out: int,
        width: List[int] | int,
        steps: int = 5,
        norm_type: NormLayerName = "group_norm",
        norm_num_groups: int = 8,
        blocks_per_layer: int = 2,
    ) -> None:
        """Initialize UNet Decoder.

        Args:
            dim_out: The number of output channels.
            width: Width of last input feature map from encoder
                or the width list of all input feature maps from encoder.
            steps: The number of upsampling steps.
            norm_type: Which kind of normalization layer to use.
            norm_num_groups: How many groups to use for group norm (if relevant).
            blocks_per_layer: How many blocks per layer to use.

        Raises:
            ValueError: If ``blocks_per_layer`` is less than one.
        """
        super().__init__()
        if blocks_per_layer < 1:
            raise ValueError("blocks_per_layer must be greater or equal to one.")
        self.dim_out = dim_out
        self.convs_up = nn.ModuleList()
        # NOTE(review): annotated but never assigned in this class — presumably
        # populated by BaseDecoder or a subclass; confirm before relying on it.
        self.output_dims: list[int]
        # If only one number is specified, we assume each upsampling step halves
        # the channel dimension (the deepest feature map has `width` channels).
        if isinstance(width, int):
            self.input_dims = [width >> i for i in range(0, steps + 1)]
        else:
            # Encoder widths are given shallow-to-deep; reverse so that index 0
            # is the deepest map, then keep one entry per upsampling stage.
            self.input_dims = width[::-1][: steps + 1]
        for i_step in range(steps):
            input_width = self.input_dims[i_step]
            current_width = self.input_dims[i_step + 1]
            convs_up_i = nn.Sequential(
                nn.Upsample(scale_factor=2),
                # Every stage except the first receives the upsampled output
                # concatenated with the matching encoder skip connection, which
                # doubles its input channel count.
                residual_block_2d(
                    input_width * (1 if i_step == 0 else 2),
                    current_width,
                    norm_type=norm_type,
                    norm_num_groups=norm_num_groups,
                ),
                *[
                    residual_block_2d(
                        current_width,
                        current_width,
                        norm_type=norm_type,
                        norm_num_groups=norm_num_groups,
                    )
                    for _ in range(blocks_per_layer - 1)
                ],
            )
            self.convs_up.append(convs_up_i)
            # (Removed dead assignments to input_width/current_width here: both
            # were unconditionally recomputed at the top of the next iteration
            # and never read in between.)
        last_width = self.input_dims[-1]
        # The final projection also consumes a concatenation with the shallowest
        # skip connection, hence the doubled input width.
        self.conv_out = nn.Sequential(
            norm_layer_2d(last_width * 2, norm_type, num_groups=norm_num_groups),
            nn.ReLU(),
            nn.Conv2d(last_width * 2, dim_out, 1),
            norm_layer_2d(dim_out, norm_type, num_groups=norm_num_groups),
            nn.ReLU(),
        )

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        """Apply UNet to image.

        Args:
            features: The input multi-level feature map from encoder, ordered
                shallow-to-deep (the last entry is the lowest-resolution map).

        Returns:
            The output feature map with ``dim_out`` channels.
        """
        # Walk the encoder features from deepest to shallowest.
        i_feature_layer = len(features) - 1
        out = self.convs_up[0](features[i_feature_layer])
        i_feature_layer -= 1
        for conv_up in self.convs_up[1:]:  # type: ignore
            # Fuse the upsampled output with the next-shallower skip connection.
            out = conv_up(torch.cat([out, features[i_feature_layer]], dim=1))
            i_feature_layer -= 1
        out = self.conv_out(torch.cat([out, features[i_feature_layer]], dim=1))
        return out
|