|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from collections.abc import Sequence |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from monai.networks.blocks import Convolution, UpSample |
|
|
from monai.networks.layers.factories import Conv, Pool |
|
|
from monai.utils import ensure_tuple_rep |
|
|
|
|
|
# Public API of this module: the ``BasicUNet`` class and its alternate-case aliases.
__all__ = ["BasicUnet", "Basicunet", "basicunet", "BasicUNet"]
|
|
|
|
|
|
|
|
class TwoConv(nn.Sequential):
    """Sequential stack of two ``Convolution`` blocks, registered as ``conv_0`` and ``conv_1``."""

    def __init__(
        self,
        spatial_dims: int,
        in_chns: int,
        out_chns: int,
        act: str | tuple,
        norm: str | tuple,
        bias: bool,
        dropout: float | tuple = 0.0,
    ):
        """
        Args:
            spatial_dims: count of spatial dimensions of the input.
            in_chns: channels entering the first convolution block.
            out_chns: channels produced by both convolution blocks.
            act: activation type and its arguments, forwarded to each block.
            norm: normalization type and its arguments, forwarded to each block.
            bias: if True, the convolution blocks include a bias term.
            dropout: dropout ratio forwarded to each block. Defaults to no dropout.

        """
        super().__init__()

        # The first block maps in_chns -> out_chns; the second keeps out_chns.
        channel_pairs = ((in_chns, out_chns), (out_chns, out_chns))
        for index, (src_chns, dst_chns) in enumerate(channel_pairs):
            block = Convolution(
                spatial_dims,
                src_chns,
                dst_chns,
                act=act,
                norm=norm,
                dropout=dropout,
                bias=bias,
                padding=1,
            )
            self.add_module(f"conv_{index}", block)
|
|
|
|
|
|
|
|
class Down(nn.Sequential):
    """Downsampling stage: a 2-wide max pooling followed by a ``TwoConv`` stack."""

    def __init__(
        self,
        spatial_dims: int,
        in_chns: int,
        out_chns: int,
        act: str | tuple,
        norm: str | tuple,
        bias: bool,
        dropout: float | tuple = 0.0,
    ):
        """
        Args:
            spatial_dims: count of spatial dimensions of the input.
            in_chns: channels entering the stage.
            out_chns: channels produced by the stage.
            act: activation type and its arguments.
            norm: normalization type and its arguments.
            bias: if True, the convolution blocks include a bias term.
            dropout: dropout ratio. Defaults to no dropout.

        """
        super().__init__()
        # Registration order defines the Sequential execution order: pool first, then convolve.
        self.add_module("max_pooling", Pool["MAX", spatial_dims](kernel_size=2))
        self.add_module(
            "convs",
            TwoConv(spatial_dims, in_chns, out_chns, act=act, norm=norm, bias=bias, dropout=dropout),
        )
|
|
|
|
|
|
|
|
class UpCat(nn.Module):
    """upsampling, concatenation with the encoder feature map, two convolutions"""

    def __init__(
        self,
        spatial_dims: int,
        in_chns: int,
        cat_chns: int,
        out_chns: int,
        act: str | tuple,
        norm: str | tuple,
        bias: bool,
        dropout: float | tuple = 0.0,
        upsample: str = "deconv",
        pre_conv: nn.Module | str | None = "default",
        interp_mode: str = "linear",
        align_corners: bool | None = True,
        halves: bool = True,
        is_pad: bool = True,
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels to be upsampled.
            cat_chns: number of channels from the encoder.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
            pre_conv: a conv block applied before upsampling.
                Only used in the "nontrainable" or "pixelshuffle" mode.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
            align_corners: set the align_corners parameter for upsample. Defaults to True.
                Only used in the "nontrainable" mode.
            halves: whether to halve the number of channels during upsampling.
                This parameter does not work on ``nontrainable`` mode if ``pre_conv`` is `None`.
            is_pad: whether to pad the upsampled features (replicate mode, on the trailing side of
                each spatial dimension) up to the spatial size of the encoder features. Defaults to True.

        """
        super().__init__()
        # Without a pre_conv, "nontrainable" upsampling cannot change the channel count.
        if upsample == "nontrainable" and pre_conv is None:
            up_chns = in_chns
        else:
            up_chns = in_chns // 2 if halves else in_chns
        self.upsample = UpSample(
            spatial_dims,
            in_chns,
            up_chns,
            2,
            mode=upsample,
            pre_conv=pre_conv,
            interp_mode=interp_mode,
            align_corners=align_corners,
        )
        self.convs = TwoConv(spatial_dims, cat_chns + up_chns, out_chns, act, norm, bias, dropout)
        self.is_pad = is_pad

    def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):
        """

        Args:
            x: features to be upsampled.
            x_e: optional features from the encoder, if None, this branch is not in use.

        Returns:
            ``self.convs`` applied to the channel-wise concatenation ``[x_e, upsampled x]``,
            or to the upsampled ``x`` alone when ``x_e`` is None.
        """
        x_0 = self.upsample(x)

        if x_e is not None and torch.jit.isinstance(x_e, torch.Tensor):
            if self.is_pad:
                # F.pad takes (before, after) pairs starting from the LAST dimension, so
                # sp[2 * i + 1] is the trailing-side pad of spatial dimension -(i + 1).
                # Pad by the actual shortfall (not a fixed 1) so any gap is closed.
                dimensions = len(x.shape) - 2
                sp = [0] * (dimensions * 2)
                for i in range(dimensions):
                    diff = x_e.shape[-i - 1] - x_0.shape[-i - 1]
                    if diff > 0:
                        sp[i * 2 + 1] = diff
                x_0 = torch.nn.functional.pad(x_0, sp, "replicate")
            x = self.convs(torch.cat([x_e, x_0], dim=1))
        else:
            x = self.convs(x_0)

        return x
|
|
|
|
|
|
|
|
class BasicUNet(nn.Module):
    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 2,
        features: Sequence[int] = (32, 32, 64, 128, 256, 32),
        act: str | tuple = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
        norm: str | tuple = ("instance", {"affine": True}),
        bias: bool = True,
        dropout: float | tuple = 0.0,
        upsample: str = "deconv",
    ):
        """
        A UNet implementation with 1D/2D/3D supports.

        Based on:

            Falk et al. "U-Net – Deep Learning for Cell Counting, Detection, and
            Morphometry". Nature Methods 16, 67–70 (2019), DOI:
            http://dx.doi.org/10.1038/s41592-018-0261-2

        Args:
            spatial_dims: number of spatial dimensions. Defaults to 3 for spatial 3D inputs.
            in_channels: number of input channels. Defaults to 1.
            out_channels: number of output channels. Defaults to 2.
            features: six integers as numbers of features. The value is normalised with
                ``ensure_tuple_rep(features, 6)``, so a single int is used for all six entries.
                Defaults to ``(32, 32, 64, 128, 256, 32)``,

                - the first five values correspond to the five-level encoder feature sizes.
                - the last value corresponds to the feature size after the last upsampling.

            act: activation type and arguments. Defaults to LeakyReLU.
            norm: feature normalization type and arguments. Defaults to instance norm.
            bias: whether to have a bias term in convolution blocks. Defaults to True.
                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
                if a conv layer is directly followed by a batch norm layer, bias should be False.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.

        Examples::

            # for spatial 2D
            >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128))

            # for spatial 2D, with group norm
            >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4}))

            # for spatial 3D
            >>> net = BasicUNet(spatial_dims=3, features=(32, 32, 64, 128, 256, 32))

        See Also

            - :py:class:`monai.networks.nets.DynUNet`
            - :py:class:`monai.networks.nets.UNet`

        """
        super().__init__()
        fea = ensure_tuple_rep(features, 6)
        print(f"BasicUNet features: {fea}.")

        # Encoder: every layer reads from the normalised ``fea`` tuple (not the raw
        # ``features`` argument) so that a scalar ``features`` works too.
        self.conv_0 = TwoConv(spatial_dims, in_channels, fea[0], act, norm, bias, dropout)
        self.down_1 = Down(spatial_dims, fea[0], fea[1], act, norm, bias, dropout)
        self.down_2 = Down(spatial_dims, fea[1], fea[2], act, norm, bias, dropout)
        self.down_3 = Down(spatial_dims, fea[2], fea[3], act, norm, bias, dropout)
        self.down_4 = Down(spatial_dims, fea[3], fea[4], act, norm, bias, dropout)

        # Decoder: each UpCat upsamples the deeper map and fuses the matching encoder map.
        self.upcat_4 = UpCat(spatial_dims, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample)
        self.upcat_3 = UpCat(spatial_dims, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample)
        self.upcat_2 = UpCat(spatial_dims, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample)
        # The last stage keeps all channels while upsampling (halves=False).
        self.upcat_1 = UpCat(spatial_dims, fea[1], fea[0], fea[5], act, norm, bias, dropout, upsample, halves=False)

        # 1x1 convolution projecting to the requested number of output channels.
        self.final_conv = Conv["conv", spatial_dims](fea[5], out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: input should have spatially N dimensions
                ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N-1])``, N is defined by `spatial_dims`.
                It is recommended to have ``dim_n % 16 == 0`` to ensure all maxpooling inputs have
                even edge lengths.

        Returns:
            A torch Tensor of "raw" predictions in shape
            ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N-1])``.
        """
        x0 = self.conv_0(x)

        x1 = self.down_1(x0)
        x2 = self.down_2(x1)
        x3 = self.down_3(x2)
        x4 = self.down_4(x3)

        u4 = self.upcat_4(x4, x3)
        u3 = self.upcat_3(u4, x2)
        u2 = self.upcat_2(u3, x1)
        u1 = self.upcat_1(u2, x0)

        logits = self.final_conv(u1)
        return logits
|
|
|
|
|
|
|
|
# Alternate spellings of ``BasicUNet``; every name below is bound to the same class object.
BasicUnet = BasicUNet
Basicunet = BasicUNet
basicunet = BasicUNet
|
|
|