File size: 13,312 Bytes
b4d7ac8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 |
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import torch
import torch.nn as nn
from monai.networks.layers.factories import Conv, Pad, Pool
from monai.networks.utils import icnr_init, pixelshuffle
from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option
# Public API of this module. Besides the two classes, the alternative spellings
# assigned to them at the bottom of the file are exported as well.
__all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"]
class UpSample(nn.Sequential):
    """
    Upsamples data by `scale_factor`.

    Supported modes are:

        - "deconv": uses a transposed convolution.
        - "deconvgroup": uses a transposed group convolution.
        - "nontrainable": uses :py:class:`torch.nn.Upsample`.
        - "pixelshuffle": uses :py:class:`monai.networks.blocks.SubpixelUpsample`.

    This operation will cause non-deterministic behaviour when ``mode`` is ``UpsampleMode.NONTRAINABLE``.
    Please check the link below for more details:
    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms

    This module can optionally take a pre-convolution
    (often used to map the number of features from `in_channels` to `out_channels`).
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int | None = None,
        out_channels: int | None = None,
        scale_factor: Sequence[float] | float = 2,
        kernel_size: Sequence[float] | float | None = None,
        size: tuple[int] | int | None = None,
        mode: UpsampleMode | str = UpsampleMode.DECONV,
        pre_conv: nn.Module | str | None = "default",
        interp_mode: str = InterpolateMode.LINEAR,
        align_corners: bool | None = True,
        bias: bool = True,
        apply_pad_pool: bool = True,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of channels of the input image.
            out_channels: number of channels of the output image. Defaults to `in_channels`.
            scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. Defaults to 2.
            kernel_size: kernel size used during transposed convolutions. Defaults to `scale_factor`.
            size: spatial size of the output image.
                Only used when ``mode`` is ``UpsampleMode.NONTRAINABLE``.
                In torch.nn.functional.interpolate, only one of `size` or `scale_factor` should be defined,
                thus if size is defined, `scale_factor` will not be used.
                Defaults to None.
            mode: {``"deconv"``, ``"deconvgroup"``, ``"nontrainable"``, ``"pixelshuffle"``}. Defaults to ``"deconv"``.
            pre_conv: a conv block applied before upsampling. Defaults to "default".
                When ``pre_conv`` is ``"default"``, one reserved conv layer will be utilized when needed
                (in the "nontrainable" mode: only if `out_channels` is not equal to `in_channels`).
                When ``pre_conv`` is ``None`` in the "nontrainable" mode, no conv is applied and `out_channels`
                must be left unset or equal `in_channels`.
                Only used in the "nontrainable" or "pixelshuffle" mode.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
                If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation.
                This corresponds to linear, bilinear, trilinear for 1D, 2D, and 3D respectively.
                The interpolation mode. Defaults to ``"linear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html
            align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True.
                Only used in the "nontrainable" mode. It is replaced by ``None`` when `interp_mode` is
                ``"nearest"``, ``"nearest-exact"`` or ``"area"``, because PyTorch only accepts the option
                for the linear/bilinear/bicubic/trilinear modes.
            bias: whether to have a bias term in the default preconv and deconv layers. Defaults to True.
            apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the
                size of `scale_factor` with a stride of 1. See also: :py:class:`monai.networks.blocks.SubpixelUpsample`.
                Only used in the "pixelshuffle" mode.

        Raises:
            ValueError: when `in_channels` is required by the selected mode but not given, or when
                ``pre_conv`` is ``None`` in the "nontrainable" mode and `out_channels` differs from `in_channels`.
            NotImplementedError: when ``mode`` is not one of the supported upsampling modes.
        """
        super().__init__()
        scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims)
        up_mode = look_up_option(mode, UpsampleMode)

        if not kernel_size:
            # kernel == stride: the transposed conv needs no padding to scale the size by `scale_factor`
            kernel_size_ = scale_factor_
            output_padding = padding = 0
        else:
            kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims)
            # With these two values the transposed conv output size,
            # (n - 1) * s - 2 * padding + k + output_padding, simplifies to n * s.
            padding = tuple((k - 1) // 2 for k in kernel_size_)  # type: ignore
            output_padding = tuple(s - 1 - (k - 1) % 2 for k, s in zip(kernel_size_, scale_factor_))  # type: ignore

        if up_mode == UpsampleMode.DECONV:
            if not in_channels:
                raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.")
            self.add_module(
                "deconv",
                Conv[Conv.CONVTRANS, spatial_dims](
                    in_channels=in_channels,
                    out_channels=out_channels or in_channels,
                    kernel_size=kernel_size_,
                    stride=scale_factor_,
                    padding=padding,
                    output_padding=output_padding,
                    bias=bias,
                ),
            )
        elif up_mode == UpsampleMode.DECONVGROUP:
            if not in_channels:
                raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.")
            if out_channels is None:
                out_channels = in_channels
            # one group per output channel when the channel counts allow it, otherwise a plain deconv
            groups = out_channels if in_channels % out_channels == 0 else 1
            self.add_module(
                "deconvgroup",
                Conv[Conv.CONVTRANS, spatial_dims](
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size_,
                    stride=scale_factor_,
                    padding=padding,
                    output_padding=output_padding,
                    groups=groups,
                    bias=bias,
                ),
            )
        elif up_mode == UpsampleMode.NONTRAINABLE:
            if pre_conv == "default" and (out_channels != in_channels):  # defaults to no conv if out_chns==in_chns
                # NOTE: an unset `out_channels` (None) also lands here and yields an
                # in_channels -> in_channels 1x1 conv; kept as is so that the registered
                # modules (and therefore state_dict keys) of existing models do not change.
                if not in_channels:
                    raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.")
                self.add_module(
                    "preconv",
                    Conv[Conv.CONV, spatial_dims](
                        in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias
                    ),
                )
            elif pre_conv is not None and pre_conv != "default":
                self.add_module("preconv", pre_conv)  # type: ignore
            elif pre_conv is None and out_channels is not None and (out_channels != in_channels):
                # an unset `out_channels` defaults to `in_channels`, so only an explicit mismatch is an error
                raise ValueError(
                    "in the nontrainable mode, if not setting pre_conv, out_channels should equal to in_channels."
                )

            interp_mode = InterpolateMode(interp_mode)
            linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR]
            if interp_mode in linear_mode:  # choose mode based on dimensions
                interp_mode = linear_mode[spatial_dims - 1]
            if interp_mode.value in ("nearest", "nearest-exact", "area"):
                # torch's interpolate raises if align_corners is set for these modes;
                # drop it so the default `align_corners=True` does not break the forward pass
                align_corners = None
            self.add_module(
                "upsample_non_trainable",
                nn.Upsample(
                    size=size,
                    scale_factor=None if size else scale_factor_,
                    mode=interp_mode.value,
                    align_corners=align_corners,
                ),
            )
        elif up_mode == UpsampleMode.PIXELSHUFFLE:
            self.add_module(
                "pixelshuffle",
                SubpixelUpsample(
                    spatial_dims=spatial_dims,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    scale_factor=scale_factor_[0],  # isotropic
                    conv_block=pre_conv,
                    apply_pad_pool=apply_pad_pool,
                    bias=bias,
                ),
            )
        else:
            raise NotImplementedError(f"Unsupported upsampling mode {mode}.")
class SubpixelUpsample(nn.Module):
    """
    Upsample using a subpixel CNN. This module supports 1D, 2D and 3D input images.
    The module consists of two parts. First, a convolutional layer is employed
    to increase the number of channels to: ``out_channels * (scale_factor ** dimensions)``.
    Secondly, a pixel shuffle manipulation is utilized to aggregate the feature maps from
    low resolution space and build the super resolution space.
    The first part of the module is not fixed, a sequence of layers can be used to replace the
    default single layer.

    See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution
    Using an Efficient Sub-Pixel Convolutional Neural Network."

    See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".

    The idea comes from:
    https://arxiv.org/abs/1609.05158

    The pixel shuffle mechanism refers to:
    https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#torch.nn.PixelShuffle.
    and:
    https://github.com/pytorch/pytorch/pull/6340.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int | None,
        out_channels: int | None = None,
        scale_factor: int = 2,
        conv_block: nn.Module | str | None = "default",
        apply_pad_pool: bool = True,
        bias: bool = True,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of channels of the input image.
                Required when ``conv_block`` is ``"default"``.
            out_channels: optional number of channels of the output image. Defaults to `in_channels`.
            scale_factor: multiplier for spatial size. Defaults to 2.
                Integral floats (e.g. ``2.0``) are accepted and converted to ``int``.
            conv_block: a conv block to extract feature maps before upsampling. Defaults to ``"default"``.

                - When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized.
                - When ``conv_block`` is ``None``, no convolution is applied before the pixel shuffle.
                - When ``conv_block`` is an ``nn.module``,
                  please ensure the output number of channels is divisible ``(scale_factor ** dimensions)``.

            apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the
                size of `scale_factor` with a stride of 1. This implements the nearest neighbour resize convolution
                component of subpixel convolutions described in Aitken et al.
            bias: whether to have a bias term in the default conv_block. Defaults to True.

        Raises:
            ValueError: when `scale_factor` is not a positive integer, or when ``conv_block`` is
                ``"default"`` and `in_channels` is not given.
        """
        super().__init__()

        # `<= 0` is tested first so that non-numeric input fails exactly as before;
        # the second test enforces the "integer" part that the message promises.
        if scale_factor <= 0 or int(scale_factor) != scale_factor:
            raise ValueError(f"The `scale_factor` multiplier must be an integer greater than 0, got {scale_factor}.")

        self.dimensions = spatial_dims
        self.scale_factor = int(scale_factor)

        if conv_block == "default":
            # the default conv cannot be built without a known number of input channels
            if not in_channels:
                raise ValueError("in_channels need to be specified.")
            out_channels = out_channels or in_channels
            # the pixel shuffle folds scale_factor ** dimensions channels into each output channel
            conv_out_channels = out_channels * (self.scale_factor**self.dimensions)
            self.conv_block = Conv[Conv.CONV, self.dimensions](
                in_channels=in_channels, out_channels=conv_out_channels, kernel_size=3, stride=1, padding=1, bias=bias
            )
            icnr_init(self.conv_block, self.scale_factor)
        elif conv_block is None:
            self.conv_block = nn.Identity()
        else:
            self.conv_block = conv_block

        self.pad_pool: nn.Module = nn.Identity()
        if apply_pad_pool:
            pool_type = Pool[Pool.AVG, self.dimensions]
            pad_type = Pad[Pad.CONSTANTPAD, self.dimensions]
            # pad (scale_factor - 1) zeros on the leading side of every spatial dim so that the
            # stride-1 average pool with kernel `scale_factor` keeps the spatial size unchanged
            self.pad_pool = nn.Sequential(
                pad_type(padding=(self.scale_factor - 1, 0) * self.dimensions, value=0.0),
                pool_type(kernel_size=self.scale_factor, stride=1),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...).

        Returns:
            The output of ``conv_block`` after the pixel shuffle and the optional pad/pool step.

        Raises:
            ValueError: when the number of channels produced by ``conv_block`` is not divisible by
                ``scale_factor ** dimensions``.
        """
        x = self.conv_block(x)
        shuffle_group = self.scale_factor**self.dimensions
        if x.shape[1] % shuffle_group != 0:
            raise ValueError(
                f"Number of channels after `conv_block` ({x.shape[1]}) must be evenly "
                "divisible by scale_factor ** dimensions "
                f"({self.scale_factor}^{self.dimensions}={shuffle_group})."
            )
        x = pixelshuffle(x, self.dimensions, self.scale_factor)
        x = self.pad_pool(x)
        return x
# Alternative spellings kept as aliases of the two classes above; every alias
# is the very same class object as its target.
Upsample = UpSample
SubpixelUpSample = SubpixelUpsample
Subpixelupsample = SubpixelUpsample
|