from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
class Conv2dAct(nn.Sequential):
    """Bias-free ``nn.Conv2d`` followed by a normalization layer and an activation.

    Submodules are registered as ``conv``, ``norm`` and ``act`` (in that order),
    and ``forward`` applies them in that same order.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        padding: Zero-padding added on each spatial side.
        stride: Convolution stride.
        norm_layer: ``"bn"`` for ``nn.BatchNorm2d`` or ``"gn"`` for ``nn.GroupNorm``.
        num_groups: Number of groups when ``norm_layer == "gn"``; ``nn.GroupNorm``
            itself rejects values that do not divide ``out_channels``.
        activation: Name of a ``torch.nn`` activation class, e.g. ``"ReLU"``.
            The class must accept an ``inplace`` keyword argument.
        inplace: Forwarded as ``inplace=`` to the activation constructor.

    Raises:
        ValueError: If ``norm_layer`` is not ``"bn"`` or ``"gn"``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        padding: int = 0,
        stride: int = 1,
        norm_layer: str = "bn",
        num_groups: int = 32,
        activation: str = "ReLU",
        inplace: bool = True,
    ):
        # Validate before building anything so a bad config fails fast.
        # ValueError subclasses Exception, so `except Exception` callers still work.
        if norm_layer not in ("bn", "gn"):
            raise ValueError(
                f"`norm_layer` must be one of [`bn`, `gn`], got `{norm_layer}`"
            )
        super().__init__()
        # The conv bias is omitted; the normalization layer that follows
        # re-centres the activations anyway.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        if norm_layer == "bn":
            self.norm = nn.BatchNorm2d(out_channels)
        else:
            # nn.GroupNorm's positional order is (num_groups, num_channels).
            # Pass both by keyword: a positional out_channels combined with a
            # num_groups= keyword would bind num_groups twice (TypeError).
            self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        # Activation is looked up by class name on torch.nn, e.g. nn.ReLU.
        self.act = getattr(nn, activation)(inplace=inplace)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> norm -> activation."""
        return self.act(self.norm(self.conv(x)))
|
|
|
|
class SCSEModule(nn.Module):
    """Concurrent channel and spatial squeeze-and-excitation gating.

    The output is ``x * sigmoid(cSE(x)) + x * sigmoid(sSE(x))``:

    * ``cSE`` is a per-channel gate. Its pipeline is global average pool,
      then a 1x1 conv down to ``in_channels // reduction`` channels, then an
      activation, then a 1x1 conv back to ``in_channels``.
    * ``sSE`` is a per-pixel gate from a single 1x1 conv to one channel.

    Args:
        in_channels: Channels of the input (and output) feature map.
        reduction: Divisor applied to ``in_channels`` for the hidden width
            of the channel branch.
        activation: Name of a ``torch.nn`` activation class, constructed with
            ``inplace=inplace``.
        inplace: Forwarded to the activation constructor.
    """

    def __init__(
        self,
        in_channels: int,
        reduction: int = 16,
        activation: str = "ReLU",
        inplace: bool = False,
    ):
        super().__init__()
        # NOTE(review): hidden width is 0 when in_channels < reduction; confirm
        # callers never hit that case.
        squeezed = in_channels // reduction
        # Layer order fixes the state_dict keys (cSE.1.*, cSE.3.*); keep it.
        channel_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, squeezed, 1),
            getattr(nn, activation)(inplace=inplace),
            nn.Conv2d(squeezed, in_channels, 1),
        ]
        self.cSE = nn.Sequential(*channel_layers)
        self.sSE = nn.Conv2d(in_channels, 1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sum of the channel-gated and spatially-gated inputs."""
        channel_gate = torch.sigmoid(self.cSE(x))
        channel_term = x * channel_gate
        spatial_gate = torch.sigmoid(self.sSE(x))
        return channel_term + x * spatial_gate
|
|
|
|
class Attention(nn.Module):
    """Select an attention module by name and apply it to the input.

    Args:
        name: ``None`` for no attention (``nn.Identity``) or ``"scse"`` for
            ``SCSEModule``.
        **params: Keyword arguments forwarded to the selected module's
            constructor. ``nn.Identity`` accepts and ignores arbitrary
            arguments, so callers can pass e.g. ``in_channels`` unconditionally.

    Raises:
        ValueError: If ``name`` is neither ``None`` nor ``"scse"``.
    """

    # `name` is Optional: None is an explicit, supported choice (no attention).
    def __init__(self, name: Optional[str], **params):
        super().__init__()
        if name is None:
            self.attention = nn.Identity(**params)
        elif name == "scse":
            self.attention = SCSEModule(**params)
        else:
            raise ValueError("Attention {} is not implemented".format(name))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the selected attention module."""
        return self.attention(x)
|
|
|
|
class DecoderBlock(nn.Module):
    """One decoder stage: resize, optionally fuse a skip, then two 3x3 convs.

    Args:
        in_channels: Channels of the incoming (coarser) feature map.
        skip_channels: Channels of the skip tensor concatenated after resizing
            (0 when the block is run without a skip).
        out_channels: Channels produced by both convolutions.
        norm_layer: Normalization selector forwarded to ``Conv2dAct``.
        activation: Activation class name forwarded to ``Conv2dAct``.
        attention_type: Attention selector forwarded to ``Attention``.
    """

    def __init__(
        self,
        in_channels: int,
        skip_channels: int,
        out_channels: int,
        norm_layer: str = "bn",
        activation: str = "ReLU",
        attention_type: Optional[str] = None,
    ):
        super().__init__()
        fused_channels = in_channels + skip_channels
        conv_opts = dict(
            kernel_size=3, padding=1, norm_layer=norm_layer, activation=activation
        )
        # Keep this construction order: it fixes parameter registration order
        # and the RNG draws used for weight initialization.
        self.conv1 = Conv2dAct(fused_channels, out_channels, **conv_opts)
        self.attention1 = Attention(attention_type, in_channels=fused_channels)
        self.conv2 = Conv2dAct(out_channels, out_channels, **conv_opts)
        self.attention2 = Attention(attention_type, in_channels=out_channels)

    def forward(
        self, x: torch.Tensor, skip: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Decode ``x``, fusing ``skip`` when one is provided."""
        if skip is None:
            # No skip: plain 2x nearest-neighbour upsampling.
            x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
        else:
            # Resize to the skip's exact spatial size, concatenate on the
            # channel axis, and gate the fused tensor.
            height, width = skip.shape[2:]
            resized = F.interpolate(x, size=(height, width), mode="nearest")
            x = self.attention1(torch.cat([resized, skip], dim=1))
        x = self.conv1(x)
        return self.attention2(self.conv2(x))
|
|
|
|
class CenterBlock(nn.Sequential):
    """Two stacked 3x3 ``Conv2dAct`` layers (padding 1).

    The first maps ``in_channels -> out_channels`` and the second maps
    ``out_channels -> out_channels``. They are registered as submodules
    ``"0"`` and ``"1"``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both layers.
        norm_layer: Normalization selector forwarded to ``Conv2dAct``.
        activation: Activation class name forwarded to ``Conv2dAct``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: str = "bn",
        activation: str = "ReLU",
    ):
        # Built in order (first layer, then second) to keep init RNG draws unchanged.
        stages = [
            Conv2dAct(
                stage_in,
                out_channels,
                kernel_size=3,
                padding=1,
                norm_layer=norm_layer,
                activation=activation,
            )
            for stage_in in (in_channels, out_channels)
        ]
        super().__init__(*stages)
|
|
|
|
class UnetDecoder(nn.Module):
    """U-Net style decoder built from ``CenterBlock`` and ``DecoderBlock`` stages.

    Args:
        decoder_n_blocks: Expected number of decoder blocks; must equal
            ``len(decoder_channels)``.
        decoder_channels: Output channels of each decoder block, in the order
            the blocks are applied.
        encoder_channels: Channels of each encoder feature map, in the same
            order as the ``features`` passed to ``forward``. The list is
            reversed internally, so its last entry becomes the "head" feature
            (presumably the deepest one; confirm with the encoder).
        decoder_center_block: If True, apply a ``CenterBlock`` to the head
            feature; otherwise pass it through ``nn.Identity``.
        decoder_norm_layer: Normalization selector (``"bn"``/``"gn"``)
            forwarded to every block.
        decoder_attention_type: Attention selector forwarded to every
            ``DecoderBlock``.

    Raises:
        ValueError: If ``decoder_n_blocks != len(decoder_channels)``.
    """

    def __init__(
        self,
        decoder_n_blocks: int,
        decoder_channels: List[int],
        encoder_channels: List[int],
        decoder_center_block: bool = False,
        decoder_norm_layer: str = "bn",
        decoder_attention_type: Optional[str] = None,
    ):
        super().__init__()

        self.decoder_n_blocks = decoder_n_blocks
        self.decoder_channels = decoder_channels
        self.encoder_channels = encoder_channels
        self.decoder_center_block = decoder_center_block
        self.decoder_norm_layer = decoder_norm_layer
        self.decoder_attention_type = decoder_attention_type

        if self.decoder_n_blocks != len(self.decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    self.decoder_n_blocks, len(self.decoder_channels)
                )
            )

        # Reverse so index 0 is the head feature and the rest are the skips
        # in the order the decoder blocks consume them.
        encoder_channels = encoder_channels[::-1]

        head_channels = encoder_channels[0]
        # Block 0 consumes the head; block i > 0 consumes block i-1's output.
        in_channels = [head_channels] + list(self.decoder_channels[:-1])
        # The trailing 0 is for a block run with skip=None (see forward), which
        # happens once the skips are exhausted.
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = self.decoder_channels
        # NOTE(review): zip() below truncates to the shortest list, so when
        # decoder_n_blocks > len(encoder_channels) fewer than decoder_n_blocks
        # blocks are built and trailing decoder_channels are silently ignored.

        if self.decoder_center_block:
            self.center = CenterBlock(
                head_channels, head_channels, norm_layer=self.decoder_norm_layer
            )
        else:
            self.center = nn.Identity()

        kwargs = dict(
            norm_layer=self.decoder_norm_layer,
            attention_type=self.decoder_attention_type,
        )
        blocks = [
            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """Run the decoder over a list of encoder features.

        Args:
            features: Encoder feature maps ordered like ``encoder_channels``
                (assumed one tensor per entry; TODO confirm with callers).

        Returns:
            ``[center(head), block_0_out, ..., block_{k-1}_out]``. The last
            element is the output of the final decoder block.
        """
        features = features[::-1]

        head = features[0]
        skips = features[1:]

        output = [self.center(head)]
        for i, decoder_block in enumerate(self.blocks):
            # Blocks beyond the available skips run without one.
            skip = skips[i] if i < len(skips) else None
            output.append(decoder_block(output[-1], skip))

        return output
|
|
|
|
class SegmentationHead(nn.Module):
    """Channel dropout, then a conv projection, then an optional bilinear resize.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the projection conv.
        size: Target spatial size for the output. Only a tuple or list
            triggers a bilinear resize to that size. Any other value,
            including an int, leaves the spatial size unchanged.
        kernel_size: Projection conv kernel size. Padding is
            ``kernel_size // 2``, which preserves spatial size for odd kernels.
        dropout: Probability for ``nn.Dropout2d`` applied before the conv.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        size: Optional[Union[int, Tuple[int, int], List[int]]],
        kernel_size: int = 3,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.drop = nn.Dropout2d(p=dropout)
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
        )
        if isinstance(size, (tuple, list)):
            # align_corners=False is PyTorch's default for bilinear resampling.
            # Stating it keeps the sampling convention explicit.
            self.up = nn.Upsample(size=size, mode="bilinear", align_corners=False)
        else:
            # NOTE(review): an int `size` disables resizing rather than producing
            # a square output; confirm this is the intended contract.
            self.up = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply dropout, then the projection conv, then the optional resize."""
        return self.up(self.conv(self.drop(x)))
|
|