| import torch.nn as nn |
| import torch |
|
|
|
|
class ResidualConv(nn.Module):
    """Pre-activation residual unit with a projected (conv + BN) shortcut.

    Main path:  BN -> ReLU -> Conv3x3(stride, padding)
                -> BN -> ReLU -> Conv3x3(stride=1, padding=1)
    Shortcut:   Conv3x3(stride, padding) -> BN

    The two paths are summed, so they must produce the same spatial size.
    The second main-path conv is size-preserving, which means the shortcut has
    to mirror the *first* main-path conv's stride and padding exactly.

    Args:
        input_dim: number of channels of the input feature map.
        output_dim: number of channels produced by both paths.
        stride: stride of the first main-path conv and of the shortcut conv.
        padding: zero-padding of the first main-path conv and of the
            shortcut conv.
    """

    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        # Use the same padding as the first main-path conv; a hard-coded
        # padding=1 here made every padding != 1 fail on the residual add.
        self.conv_skip = nn.Sequential(
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):
        """Return the sum of the main path and the projected shortcut."""
        return self.conv_block(x) + self.conv_skip(x)
|
|
|
|
class Upsample(nn.Module):
    """Learned upsampling implemented as a single transposed convolution.

    No padding is applied, so each spatial side grows to
    ``(size - 1) * stride + kernel``; with ``kernel == stride`` that is an
    exact ``stride``-fold enlargement.

    Args:
        input_dim: channels of the incoming feature map.
        output_dim: channels of the upsampled feature map.
        kernel: kernel size of the transposed convolution.
        stride: stride of the transposed convolution.
    """

    def __init__(self, input_dim, output_dim, kernel, stride):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=kernel,
            stride=stride,
        )

    def forward(self, x):
        """Run ``x`` through the transposed convolution."""
        deconv = self.upsample
        return deconv(x)
|
|
|
|
class Squeeze_Excite_Block(nn.Module):
    """Squeeze-and-Excitation channel gating.

    Each channel is global-average-pooled to a scalar. The resulting
    ``(b, c)`` vector goes through a bias-free bottleneck MLP
    (Linear -> ReLU -> Linear -> Sigmoid). Every input channel is then scaled
    by its gate in ``(0, 1)``.

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck reduction ratio; the hidden layer has
            ``max(1, channel // reduction)`` units.
    """

    def __init__(self, channel, reduction=16):
        super(Squeeze_Excite_Block, self).__init__()
        # Clamp to 1: for channel < reduction, ``channel // reduction`` is 0,
        # the MLP outputs zeros and the gate degenerates to a constant 0.5.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale each channel of the 4-D tensor ``x`` by its learned gate."""
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
|
|
|
|
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling block.

    There is one branch per dilation rate ``r``:
    Conv3x3(dilation=r, padding=r, stride=1) -> ReLU -> BN. Each branch
    preserves the input's spatial size. The branch outputs are concatenated
    along the channel dimension and fused back to ``out_dims`` channels by a
    1x1 conv.

    Branches are registered as ``aspp_block1 .. aspp_blockN`` (N = len(rate)),
    so the default three-rate configuration keeps the original parameter names.

    Args:
        in_dims: channels of the input feature map.
        out_dims: channels of every branch and of the fused output.
        rate: sequence of dilation rates, one branch per entry.

    Raises:
        ValueError: if ``rate`` is empty.
    """

    def __init__(self, in_dims, out_dims, rate=(6, 12, 18)):
        super(ASPP, self).__init__()
        # Tuple default avoids a shared mutable default; lists still accepted.
        rate = tuple(rate)
        if not rate:
            raise ValueError("ASPP requires at least one dilation rate")

        # Build one branch per rate, so the output conv's input width
        # (len(rate) * out_dims) always matches what forward concatenates.
        for idx, r in enumerate(rate, start=1):
            setattr(
                self,
                f"aspp_block{idx}",
                nn.Sequential(
                    nn.Conv2d(in_dims, out_dims, 3, stride=1, padding=r, dilation=r),
                    nn.ReLU(inplace=True),
                    nn.BatchNorm2d(out_dims),
                ),
            )

        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
        self._init_weights()

    def forward(self, x):
        """Run every dilated branch on ``x``, concatenate, and fuse with 1x1."""
        # named_children() yields modules in registration order (block1..N).
        branch_outputs = [
            module(x)
            for name, module in self.named_children()
            if name.startswith("aspp_block")
        ]
        out = torch.cat(branch_outputs, dim=1)
        return self.output(out)

    def _init_weights(self):
        """Kaiming-normal init for conv weights; BN weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
|
|
|
|
class Upsample_(nn.Module):
    """Parameter-free bilinear upsampling by a fixed scale factor.

    Args:
        scale: multiplicative factor applied to both spatial dimensions.
    """

    def __init__(self, scale=2):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale, mode="bilinear")

    def forward(self, x):
        """Bilinearly resize ``x`` by the configured scale."""
        resize = self.upsample
        return resize(x)
|
|
|
|
class AttentionBlock(nn.Module):
    """Spatial attention over decoder features, guided by encoder features.

    Encoder path: BN -> ReLU -> Conv3x3(pad 1) -> MaxPool2d(2, 2), which
    halves the encoder map's spatial size.
    Decoder path: BN -> ReLU -> Conv3x3(pad 1), size-preserving.
    The two are summed and passed through BN -> ReLU -> Conv1x1 down to a
    single-channel map. That map multiplies ``x2`` and is broadcast across
    its channels.

    NOTE(review): no sigmoid is applied, so the attention map is unbounded;
    confirm this is intended.

    Args:
        input_encoder: channels of ``x1`` (encoder features).
        input_decoder: channels of ``x2`` (decoder features).
        output_dim: channels of the intermediate projections.
    """

    def __init__(self, input_encoder, input_decoder, output_dim):
        super().__init__()

        self.conv_encoder = nn.Sequential(
            nn.BatchNorm2d(input_encoder),
            nn.ReLU(),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        )

        self.conv_decoder = nn.Sequential(
            nn.BatchNorm2d(input_decoder),
            nn.ReLU(),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        )

        self.conv_attn = nn.Sequential(
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, 1, 1),
        )

    def forward(self, x1, x2):
        """Return ``x2`` weighted by the attention map computed from x1 and x2.

        The pooled encoder map must match ``x2``'s spatial size for the sum to
        broadcast, i.e. ``x1`` is expected to be twice ``x2``'s height/width.
        """
        encoded = self.conv_encoder(x1)
        decoded = self.conv_decoder(x2)
        attention = self.conv_attn(encoded + decoded)
        return attention * x2