| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from collections import OrderedDict |
| from typing import Callable, Sequence, Type, Union |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from monai.networks.layers.factories import Conv, Dropout, Norm, Pool |
|
|
|
|
class _DenseLayer(nn.Sequential):
    """A single dense layer: BN -> ReLU -> 1x1 conv bottleneck, then BN -> ReLU -> 3x3 conv.

    The forward pass concatenates the layer input with the ``growth_rate`` newly
    computed feature maps along dim 1 (the channel axis for PyTorch conv layers),
    so the output has ``in_channels + growth_rate`` channels.
    """

    def __init__(
        self, spatial_dims: int, in_channels: int, growth_rate: int, bn_size: int, dropout_prob: float
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of input channels.
            growth_rate: how many filters to add each layer (k in paper).
            bn_size: multiplicative factor for number of bottle neck layers.
                (i.e. bn_size * k features in the bottleneck layer)
            dropout_prob: dropout rate applied to the new features; a dropout
                module is only added when this is strictly positive.
        """
        super().__init__()

        bottleneck_channels = bn_size * growth_rate

        # Dimension-specific layer classes (1D/2D/3D) resolved from the MONAI factories.
        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
        dropout_type: Callable = Dropout[Dropout.DROPOUT, spatial_dims]

        # Sub-modules in execution order; the names form the state_dict keys.
        stages = [
            ("norm1", norm_type(in_channels)),
            ("relu1", nn.ReLU(inplace=True)),
            ("conv1", conv_type(in_channels, bottleneck_channels, kernel_size=1, bias=False)),
            ("norm2", norm_type(bottleneck_channels)),
            ("relu2", nn.ReLU(inplace=True)),
            ("conv2", conv_type(bottleneck_channels, growth_rate, kernel_size=3, padding=1, bias=False)),
        ]
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)

        if dropout_prob > 0:
            self.add_module("dropout", dropout_type(dropout_prob))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` concatenated along dim 1 with the features produced by this layer."""
        fresh = super().forward(x)
        return torch.cat((x, fresh), dim=1)
|
|
|
|
class _DenseBlock(nn.Sequential):
    """A stack of ``layers`` dense layers.

    Layer ``j`` (1-based) receives ``in_channels + (j - 1) * growth_rate``
    channels, because every preceding layer appends ``growth_rate`` channels
    to its input.
    """

    def __init__(
        self, spatial_dims: int, layers: int, in_channels: int, bn_size: int, growth_rate: int, dropout_prob: float
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            layers: number of layers in the block.
            in_channels: number of input channels of the first layer.
            bn_size: multiplicative factor for number of bottle neck layers.
                (i.e. bn_size * k features in the bottleneck layer)
            growth_rate: how many filters to add each layer (k in paper).
            dropout_prob: dropout rate after each dense layer.
        """
        super().__init__()
        for layer_no in range(1, layers + 1):
            # channels accumulated from the block input plus all earlier layers
            layer_in_channels = in_channels + (layer_no - 1) * growth_rate
            dense_layer = _DenseLayer(spatial_dims, layer_in_channels, growth_rate, bn_size, dropout_prob)
            self.add_module(f"denselayer{layer_no}", dense_layer)
|
|
|
|
class _Transition(nn.Sequential):
    """Transition layer placed between two dense blocks.

    Applies BN -> ReLU -> 1x1 convolution (mapping ``in_channels`` to
    ``out_channels``), followed by average pooling with kernel size 2 and
    stride 2, which downsamples every spatial dimension.
    """

    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of input channels.
            out_channels: number of output channels produced by the 1x1 convolution.
        """
        super().__init__()

        # Dimension-specific layer classes (1D/2D/3D) resolved from the MONAI factories.
        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
        pool_type: Callable = Pool[Pool.AVG, spatial_dims]

        self.add_module("norm", norm_type(in_channels))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", conv_type(in_channels, out_channels, kernel_size=1, bias=False))
        self.add_module("pool", pool_type(kernel_size=2, stride=2))
|
|
|
|
class DenseNet(nn.Module):
    """
    Densenet based on: `Densely Connected Convolutional Networks <https://arxiv.org/pdf/1608.06993.pdf>`_.
    Adapted from `PyTorch Hub 2D version
    <https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py>`_.

    The network has four parts:

    - a stem: a 7-wide stride-2 convolution, batch norm, ReLU and a 3-wide stride-2 max pool;
    - one dense block per entry of ``block_config``;
    - a transition layer (channels floor-halved, spatial size downsampled) between
      consecutive blocks, with a final batch norm after the last block;
    - a classification head: ReLU, global adaptive average pooling, flatten, then a linear layer.

    The head emits ``out_channels`` unnormalised scores per sample; no softmax is applied.

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of input channels.
        out_channels: number of the output classes.
        init_features: number of filters in the first convolution layer.
        growth_rate: how many filters to add each layer (k in paper).
        block_config: how many layers in each pooling block.
        bn_size: multiplicative factor for number of bottle neck layers.
            (i.e. bn_size * k features in the bottleneck layer)
        dropout_prob: dropout rate after each dense layer.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        init_features: int = 64,
        growth_rate: int = 32,
        block_config: Sequence[int] = (6, 12, 24, 16),
        bn_size: int = 4,
        dropout_prob: float = 0.0,
    ) -> None:
        super().__init__()

        # Dimension-specific layer classes (1D/2D/3D) resolved from the MONAI factories.
        conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
        norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims]
        pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims]
        avg_pool_type: Type[Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]] = Pool[
            Pool.ADAPTIVEAVG, spatial_dims
        ]

        # --- stem -------------------------------------------------------------
        stem = OrderedDict()
        stem["conv0"] = conv_type(in_channels, init_features, kernel_size=7, stride=2, padding=3, bias=False)
        stem["norm0"] = norm_type(init_features)
        stem["relu0"] = nn.ReLU(inplace=True)
        stem["pool0"] = pool_type(kernel_size=3, stride=2, padding=1)
        self.features = nn.Sequential(stem)

        # --- dense blocks / transitions; `num_features` is the running channel count ---
        num_features = init_features
        for block_idx, num_layers in enumerate(block_config):
            block_no = block_idx + 1
            dense_block = _DenseBlock(
                spatial_dims=spatial_dims,
                layers=num_layers,
                in_channels=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                dropout_prob=dropout_prob,
            )
            self.features.add_module(f"denseblock{block_no}", dense_block)
            # each layer of the block appends `growth_rate` channels
            num_features += num_layers * growth_rate

            if block_idx != len(block_config) - 1:
                # between blocks: compress channels by half (floor) via a transition
                halved = num_features // 2
                self.features.add_module(
                    f"transition{block_no}",
                    _Transition(spatial_dims, in_channels=num_features, out_channels=halved),
                )
                num_features = halved
            else:
                # after the final block: one batch norm instead of a transition
                self.features.add_module("norm5", norm_type(num_features))

        # --- classification head ----------------------------------------------
        head = OrderedDict()
        head["relu"] = nn.ReLU(inplace=True)
        # NOTE: despite the key "norm", this entry is the global adaptive average
        # pool; the key is left as-is so the submodule name `class_layers.norm`
        # that external code may reference does not change.
        head["norm"] = avg_pool_type(1)
        head["flatten"] = nn.Flatten(1)
        head["class"] = nn.Linear(num_features, out_channels)
        self.class_layers = nn.Sequential(head)

        # --- weight initialisation (depth-first over registered modules) -------
        for module in self.modules():
            if isinstance(module, conv_type):
                nn.init.kaiming_normal_(torch.as_tensor(module.weight))
            elif isinstance(module, norm_type):
                nn.init.ones_(torch.as_tensor(module.weight))
                nn.init.zeros_(torch.as_tensor(module.bias))
            elif isinstance(module, nn.Linear):
                nn.init.zeros_(torch.as_tensor(module.bias))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feature extractor, then the classification head."""
        features = self.features(x)
        return self.class_layers(features)
|
|
|
|
def densenet121(**kwargs) -> DenseNet:
    """Build DenseNet-121: 64 stem filters, growth rate 32, dense blocks of (6, 12, 24, 16) layers.

    Every other ``DenseNet`` argument (e.g. ``spatial_dims``, ``in_channels``,
    ``out_channels``) is forwarded through ``kwargs``. Passing ``init_features``,
    ``growth_rate`` or ``block_config`` raises ``TypeError``, because those are fixed here.
    """
    return DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
|
|
|
|
def densenet169(**kwargs) -> DenseNet:
    """Build DenseNet-169: 64 stem filters, growth rate 32, dense blocks of (6, 12, 32, 32) layers.

    Every other ``DenseNet`` argument (e.g. ``spatial_dims``, ``in_channels``,
    ``out_channels``) is forwarded through ``kwargs``. Passing ``init_features``,
    ``growth_rate`` or ``block_config`` raises ``TypeError``, because those are fixed here.
    """
    return DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
|
|
|
|
def densenet201(**kwargs) -> DenseNet:
    """Build DenseNet-201: 64 stem filters, growth rate 32, dense blocks of (6, 12, 48, 32) layers.

    Every other ``DenseNet`` argument (e.g. ``spatial_dims``, ``in_channels``,
    ``out_channels``) is forwarded through ``kwargs``. Passing ``init_features``,
    ``growth_rate`` or ``block_config`` raises ``TypeError``, because those are fixed here.
    """
    return DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
|
|
|
|
def densenet264(**kwargs) -> DenseNet:
    """Build DenseNet-264: 64 stem filters, growth rate 32, dense blocks of (6, 12, 64, 48) layers.

    Every other ``DenseNet`` argument (e.g. ``spatial_dims``, ``in_channels``,
    ``out_channels``) is forwarded through ``kwargs``. Passing ``init_features``,
    ``growth_rate`` or ``block_config`` raises ``TypeError``, because those are fixed here.
    """
    return DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 64, 48), **kwargs)
|
|