Spaces:
Running
Running
| from typing import Sequence, Union | |
| import torch | |
| import torch.nn as nn | |
| from monai.networks.blocks.convolutions import Convolution, ResidualUnit | |
| from monai.networks.layers.factories import Act, Norm | |
| from monai.networks.layers.simplelayers import SkipConnection | |
| from monai.utils import alias, export | |
class UNet_single(nn.Module):
    """UNet whose raw output is passed through a sigmoid.

    The encoder/decoder is built recursively from MONAI ``Convolution`` /
    ``ResidualUnit`` blocks, with a ``SkipConnection`` wrapped around each
    inner level. The construction appears to follow MONAI's
    ``monai.networks.nets.UNet``; NOTE(review): confirm against the MONAI
    version in use.

    Args:
        dimensions: number of spatial dimensions, passed as the first
            positional argument to ``Convolution`` / ``ResidualUnit``.
        in_channels: number of input channels.
        out_channels: number of output channels.
        channels: channels per level, top level first. Must contain at least
            two entries; the last entry is the bottom layer's channel count.
        strides: stride per down/up step, top level first. Must contain at
            least ``len(channels) - 1`` entries; any extra entries are unused.
        kernel_size: kernel size for the down-path and residual convolutions.
        up_kernel_size: kernel size for the transposed (up-path) convolutions.
        num_res_units: if > 0, down-path layers are ``ResidualUnit`` blocks
            with this many subunits, and each up convolution is followed by a
            one-subunit ``ResidualUnit``. If 0, plain ``Convolution`` layers
            are used.
        act: activation passed to every block.
        norm: normalization passed to every block.
        dropout: dropout passed to every block.

    Raises:
        ValueError: if ``channels`` has fewer than two entries, or
            ``strides`` has fewer than ``len(channels) - 1`` entries.
    """

    def __init__(
        self,
        dimensions: int,
        in_channels: int,
        out_channels: int,
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        up_kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 0,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout=0.0,
    ) -> None:
        super().__init__()
        # Validate up front. Otherwise a too-short `channels` or `strides`
        # only surfaces as an opaque IndexError deep inside the recursion.
        if len(channels) < 2:
            raise ValueError("the length of `channels` should be no less than 2.")
        if len(strides) < len(channels) - 1:
            raise ValueError(
                "the length of `strides` should be at least `len(channels) - 1`."
            )

        self.dimensions = dimensions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.strides = strides
        self.kernel_size = kernel_size
        self.up_kernel_size = up_kernel_size
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout

        def _create_block(
            inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool
        ) -> nn.Sequential:
            """
            Build the UNet from the bottom up.

            Recurse down to the bottom block, then create sequential blocks
            containing the downsample path, a skip connection around the
            previous block, and the upsample path.

            Args:
                inc: number of input channels.
                outc: number of output channels.
                channels: sequence of channels. Top block first.
                strides: convolution strides, one per level. Top block first.
                is_top: True if this is the top block.
            """
            c = channels[0]
            s = strides[0]

            subblock: nn.Module
            if len(channels) > 2:
                # Continue the recursion down. The inner block maps c -> c.
                subblock = _create_block(c, c, channels[1:], strides[1:], False)
                # Presumably SkipConnection concatenates its input (c channels)
                # with the sub-block output (c channels) in its default "cat"
                # mode, hence 2 * c.
                upc = c * 2
            else:
                # The next layer is the bottom, so stop the recursion and
                # create the bottom layer as the sub-block for this layer.
                subblock = self._get_bottom_layer(c, channels[1])
                # Skip input (c) concatenated with the bottom output (channels[1]).
                upc = c + channels[1]

            down = self._get_down_layer(inc, c, s, is_top)  # layer in the downsampling path
            up = self._get_up_layer(upc, outc, s, is_top)  # layer in the upsampling path
            return nn.Sequential(down, SkipConnection(subblock), up)

        self.model = _create_block(in_channels, out_channels, self.channels, self.strides, True)
        self.activation = nn.Sigmoid()

    def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:
        """
        Create one downsampling-path layer.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            strides: convolution stride.
            is_top: True if this is the top block. Not used here; it is kept
                for signature parity with ``_get_up_layer``.

        Returns:
            A ``ResidualUnit`` when ``num_res_units > 0``, else a ``Convolution``.
        """
        if self.num_res_units > 0:
            return ResidualUnit(
                self.dimensions,
                in_channels,
                out_channels,
                strides=strides,
                kernel_size=self.kernel_size,
                subunits=self.num_res_units,
                act=self.act,
                norm=self.norm,
                dropout=self.dropout,
            )
        return Convolution(
            self.dimensions,
            in_channels,
            out_channels,
            strides=strides,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
        )

    def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module:
        """
        Create the bottom layer: a stride-1 down-path layer.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
        """
        return self._get_down_layer(in_channels, out_channels, 1, False)

    def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:
        """
        Create one upsampling-path layer.

        The layer is a transposed ``Convolution``. When ``num_res_units > 0``,
        it is followed by a one-subunit ``ResidualUnit``.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            strides: convolution stride.
            is_top: True if this is the top block. The network's final
                convolution is then created with MONAI's conv-only flag
                (``conv_only`` / ``last_conv_only``).
        """
        conv: Union[Convolution, nn.Sequential]
        conv = Convolution(
            self.dimensions,
            in_channels,
            out_channels,
            strides=strides,
            kernel_size=self.up_kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            # With residual units, the residual block below produces the final
            # output instead, so only this branch marks the conv as conv-only.
            conv_only=is_top and self.num_res_units == 0,
            is_transposed=True,
        )
        if self.num_res_units > 0:
            ru = ResidualUnit(
                self.dimensions,
                out_channels,
                out_channels,
                strides=1,
                kernel_size=self.kernel_size,
                subunits=1,
                act=self.act,
                norm=self.norm,
                dropout=self.dropout,
                last_conv_only=is_top,
            )
            conv = nn.Sequential(conv, ru)
        return conv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the UNet and apply a sigmoid to its output."""
        return self.activation(self.model(x))

    def train_step(self, image, segment, criterion, segbox=None):
        """
        Compute the loss for one batch.

        Args:
            image: network input.
            segment: target passed as the second argument to ``criterion``.
            criterion: callable ``criterion(prediction, target)`` returning the loss.
            segbox: unused. It is presumably accepted for interface
                compatibility with other models' ``train_step``; confirm
                with the callers.

        Returns:
            ``criterion(self(image), segment)``.
        """
        # Call the module rather than `self.forward` so that registered
        # forward hooks run.
        forwarded = self(image)
        return criterion(forwarded, segment)