"""Contains reusable network components.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
from typing import Literal
import torch
from torch import nn
NormLayerName = Literal["noop", "batch_norm", "group_norm", "instance_norm"]
UpsamplingMode = Literal["transposed_conv", "nearest", "bilinear"]


def norm_layer_2d(
    num_features: int, norm_type: NormLayerName, num_groups: int = 8
) -> nn.Module:
    """Create normalization layer."""
    if norm_type == "noop":
        return nn.Identity()
    elif norm_type == "batch_norm":
        return nn.BatchNorm2d(num_features=num_features)
    elif norm_type == "group_norm":
        return nn.GroupNorm(num_channels=num_features, num_groups=num_groups)
    elif norm_type == "instance_norm":
        return nn.InstanceNorm2d(num_features=num_features)
    else:
        raise ValueError(f"Invalid normalization layer type: {norm_type}")
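
# A minimal usage sketch (illustrative only, not part of the original module; the
# feature and group counts are arbitrary assumptions): all variants preserve the
# input shape.
#
#   norm = norm_layer_2d(num_features=32, norm_type="group_norm", num_groups=8)
#   y = norm(torch.randn(2, 32, 16, 16))  # -> (2, 32, 16, 16)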


def upsampling_layer(
    upsampling_mode: UpsamplingMode, scale_factor: int, dim_in: int
) -> nn.Module:
    """Create upsampling layer."""
    if upsampling_mode == "transposed_conv":
        return nn.ConvTranspose2d(
            in_channels=dim_in,
            out_channels=dim_in,
            kernel_size=scale_factor,
            stride=scale_factor,
            padding=0,
            bias=False,
        )
    elif upsampling_mode in ("nearest", "bilinear"):
        return nn.Upsample(scale_factor=scale_factor, mode=upsampling_mode)
    else:
        raise ValueError(f"Invalid upsampling mode {upsampling_mode}.")
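
# A minimal usage sketch (illustrative only; the channel count and scale factor are
# assumptions): "transposed_conv" builds a learned upsampler, while "nearest" and
# "bilinear" are parameter-free interpolation.
#
#   up = upsampling_layer("bilinear", scale_factor=2, dim_in=64)
#   y = up(torch.randn(1, 64, 24, 24))  # -> (1, 64, 48, 48)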


class ResidualBlock(nn.Module):
    """Generic implementation of residual blocks.

    This implements a generic residual block from
    He et al. - Identity Mappings in Deep Residual Networks (2016),
    https://arxiv.org/abs/1603.05027
    which can be further customized via factory functions.
    """

    def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
        """Initialize ResidualBlock."""
        super().__init__()
        self.residual = residual
        self.shortcut = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply residual block."""
        delta_x = self.residual(x)

        if self.shortcut is not None:
            x = self.shortcut(x)

        return x + delta_x
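
# A minimal usage sketch (illustrative only; the modules and shapes are assumptions):
# the block computes shortcut(x) + residual(x), with the shortcut defaulting to the
# identity when not given.
#
#   block = ResidualBlock(
#       residual=nn.Conv2d(16, 32, kernel_size=3, padding=1),
#       shortcut=nn.Conv2d(16, 32, kernel_size=1),
#   )
#   y = block(torch.randn(1, 16, 8, 8))  # -> (1, 32, 8, 8)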


def residual_block_2d(
    dim_in: int,
    dim_out: int,
    dim_hidden: int | None = None,
    actvn: nn.Module | None = None,
    norm_type: NormLayerName = "noop",
    norm_num_groups: int = 8,
    dilation: int = 1,
    kernel_size: int = 3,
) -> ResidualBlock:
    """Create a simple 2D residual block."""
    if actvn is None:
        actvn = nn.ReLU()

    if dim_hidden is None:
        dim_hidden = dim_out // 2

    # Padding to maintain output size.
    # See https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    padding = (dilation * (kernel_size - 1)) // 2

    def _create_block(dim_in: int, dim_out: int) -> list[nn.Module]:
        layers = [
            norm_layer_2d(dim_in, norm_type, num_groups=norm_num_groups),
            actvn,
        ]
        layers.append(
            nn.Conv2d(
                dim_in,
                dim_out,
                kernel_size=kernel_size,
                stride=1,
                dilation=dilation,
                padding=padding,
            )
        )
        return layers

    residual = nn.Sequential(
        *_create_block(dim_in, dim_hidden),
        *_create_block(dim_hidden, dim_out),
    )

    shortcut = None
    if dim_in != dim_out:
        # Project the input with a 1x1 convolution when channel counts differ.
        shortcut = nn.Conv2d(dim_in, dim_out, 1)

    return ResidualBlock(residual, shortcut)
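
# A minimal usage sketch (illustrative only; channel counts, norm choice, and input
# shape are assumptions): because dim_in != dim_out here, a 1x1-conv shortcut is
# created automatically.
#
#   block = residual_block_2d(dim_in=64, dim_out=128, norm_type="group_norm")
#   y = block(torch.randn(2, 64, 32, 32))  # -> (2, 128, 32, 32)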


class FeatureFusionBlock2d(nn.Module):
    """Feature fusion for DPT."""

    # We use the name "deconv" for backward compatibility. However, "deconv" can also
    # refer to some other upsampling layer or a no-op.
    deconv: nn.Module

    def __init__(
        self,
        dim_in: int,
        dim_out: int | None = None,
        upsampling_mode: UpsamplingMode | None = None,
        batch_norm: bool = False,
    ):
        """Initialize feature fusion block.

        Args:
            dim_in: Dimensions of input.
            dim_out: Dimensions of output.
            upsampling_mode: What mode to use for upsampling. None if no upsampling
                is required.
            batch_norm: Whether to use batch normalization in resnet blocks.
        """
        super().__init__()

        if dim_out is None:
            dim_out = dim_in

        self.resnet1 = self._residual_block(dim_in, batch_norm)
        self.resnet2 = self._residual_block(dim_in, batch_norm)

        if upsampling_mode is not None:
            self.deconv = upsampling_layer(
                upsampling_mode, scale_factor=2, dim_in=dim_in
            )
        else:
            self.deconv = nn.Sequential()

        self.out_conv = nn.Conv2d(
            dim_in,
            dim_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
        """Process and fuse input features."""
        x = x0

        if x1 is not None:
            res = self.resnet1(x1)
            x = self.skip_add.add(x, res)

        x = self.resnet2(x)
        x = self.deconv(x)
        x = self.out_conv(x)

        return x

    @staticmethod
    def _residual_block(num_features: int, batch_norm: bool) -> ResidualBlock:
        """Create a residual block."""

        def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
            layers = [
                nn.ReLU(False),
                nn.Conv2d(
                    dim,
                    dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=not batch_norm,
                ),
            ]
            if batch_norm:
                layers.append(nn.BatchNorm2d(dim))
            return layers

        residual = nn.Sequential(
            *_create_block(dim=num_features, batch_norm=batch_norm),
            *_create_block(dim=num_features, batch_norm=batch_norm),
        )

        return ResidualBlock(residual)
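
# A minimal usage sketch (illustrative only; the channel count, spatial size, and
# upsampling mode are assumptions): x0 is fused with the optional second input x1
# (e.g., a skip feature), then refined and upsampled by a factor of 2.
#
#   fusion = FeatureFusionBlock2d(dim_in=256, upsampling_mode="bilinear")
#   x0 = torch.randn(1, 256, 24, 24)
#   x1 = torch.randn(1, 256, 24, 24)
#   y = fusion(x0, x1)  # -> (1, 256, 48, 48)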