#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Residual Attention Network module - feature extractor in AVRA."""
import torch.nn as nn
import torch.nn.functional as F


class ResidualAttentionNet(nn.Module):
    """Feature extractor that stacks residual and attention modules."""

    def __init__(self, z=1):
        # z is the number of input channels.
        super(ResidualAttentionNet, self).__init__()
        num_filters = [8, 16, 32, 64, 128]
        k = 0
        # Stem: strided 7x7 convolution followed by max pooling.
        conv1 = nn.Sequential(
            nn.Conv2d(z, num_filters[k], kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(num_filters[k]),
            nn.ReLU(inplace=True)
        )
        maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Alternate residual and attention modules; each strided residual
        # block halves the spatial size while the channel count grows.
        resblock1 = ResidualModule(num_filters[k], num_filters[k + 1])
        k += 1
        attention_module1 = AttentionModule(num_filters[k], num_filters[k], stage=1)
        resblock2 = ResidualModule(num_filters[k], num_filters[k + 1], stride=2)
        k += 1
        attention_module2 = AttentionModule(num_filters[k], num_filters[k], stage=2)
        resblock3 = ResidualModule(num_filters[k], num_filters[k + 1], stride=2)
        k += 1
        attention_module3 = AttentionModule(num_filters[k], num_filters[k], stage=3)
        resblock4 = ResidualModule(num_filters[k], num_filters[k + 1], stride=2)
        k += 1
        resblock5 = ResidualModule(num_filters[k], num_filters[k])
        resblock6 = ResidualModule(num_filters[k], num_filters[k])
        avgpoolblock = nn.Sequential(
            nn.BatchNorm2d(num_filters[k]),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=3, stride=1)
        )
        self.features = nn.Sequential(
            conv1, maxpool,
            resblock1, attention_module1,
            resblock2, attention_module2,
            resblock3, attention_module3,
            resblock4, resblock5, resblock6,
            avgpoolblock
        )

    def forward(self, x):
        out = self.features(x)
        # Flatten to (batch_size, num_features) for downstream layers.
        out = out.view(out.size(0), -1)
        return out
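
# Note (illustrative, not part of the original pipeline): with the default
# z=1 and a 256x256 input, the map entering the final 3x3 average pooling is
# 128 x 8 x 8, so forward() returns 128 * 6 * 6 = 4608 features per image.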


class ResidualModule(nn.Module):
    """Pre-activation bottleneck residual block (BN -> LeakyReLU -> conv)."""

    def __init__(self, inplanes, planes, stride=1):
        super(ResidualModule, self).__init__()
        planes_4 = planes // 4  # bottleneck width
        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu1 = nn.LeakyReLU()
        self.conv1 = nn.Conv2d(inplanes, planes_4, kernel_size=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes_4)
        self.relu2 = nn.LeakyReLU()
        self.conv2 = nn.Conv2d(planes_4, planes_4, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes_4)
        self.relu3 = nn.LeakyReLU()
        self.conv3 = nn.Conv2d(planes_4, planes, kernel_size=1, stride=1, bias=False)
        # 1x1 projection for the skip path when shape or stride differs.
        self.conv4 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.downsample = (self.inplanes != self.planes) or (self.stride != 1)

    def forward(self, x):
        residual = x
        out = self.bn1(x)
        out1 = self.relu1(out)
        out = self.conv1(out1)
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu3(out)
        out = self.conv3(out)
        if self.downsample:
            # Project the pre-activated input so shapes match for addition.
            residual = self.conv4(out1)
        out += residual
        return out


class AttentionModule(nn.Module):
    """Attention module with a trunk branch and a bottom-up/top-down soft
    mask branch; `stage` sets how many times the mask branch downsamples."""

    def __init__(self, in_planes, out_planes, stage=1):
        super(AttentionModule, self).__init__()
        self.stage = stage
        self.res1 = ResidualModule(in_planes, out_planes)
        # Trunk branch: plain feature processing without attention.
        self.trunk_branch = nn.Sequential(
            ResidualModule(in_planes, out_planes),
            ResidualModule(in_planes, out_planes)
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Mask branch: earlier stages see larger feature maps and therefore
        # use more downsampling steps, each paired with a skip connection.
        if self.stage < 3:
            self.block1 = ResidualModule(in_planes, out_planes)
            self.skip1 = ResidualModule(in_planes, out_planes)
            self.block5 = ResidualModule(in_planes, out_planes)
        if self.stage == 1:
            self.block2 = ResidualModule(in_planes, out_planes)
            self.skip2 = ResidualModule(in_planes, out_planes)
            self.block4 = ResidualModule(in_planes, out_planes)
        self.block3 = nn.Sequential(
            ResidualModule(in_planes, out_planes),
            ResidualModule(in_planes, out_planes)
        )
        # Two 1x1 convolutions and a sigmoid map features to a soft mask in [0, 1].
        self.block_sigmoid = nn.Sequential(
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_planes, out_planes, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_planes, out_planes, kernel_size=1, stride=1, bias=False),
            nn.Sigmoid()
        )
        self.block6 = ResidualModule(in_planes, out_planes)

    def upsample(self, x):
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        x = self.res1(x)
        trunk_branch = self.trunk_branch(x)
        # Bottom-up pass: downsample while saving skip connections.
        if self.stage < 3:
            x = self.maxpool(x)
            x = self.block1(x)
            skip1 = self.skip1(x)
        if self.stage == 1:
            x = self.maxpool(x)
            x = self.block2(x)
            skip2 = self.skip2(x)
        x = self.maxpool(x)
        x = self.block3(x)
        # Top-down pass: upsample and merge with the saved skip connections.
        if self.stage == 1:
            x = self.upsample(x)
            x = x + skip2
            x = self.block4(x)
        if self.stage < 3:
            x = self.upsample(x)
            x = x + skip1
            x = self.block5(x)
        x = self.upsample(x)
        mask = self.block_sigmoid(x)
        # Attention residual learning: out = (1 + M(x)) * T(x); the 1 + M(x)
        # factor keeps the trunk features intact where the mask is zero.
        x = (1 + mask) * trunk_branch
        out_last = self.block6(x)
        return out_last


def conv_block(in_planes, out_planes, bigblock, convxd, norm, pooling, fs=3, stride=1, relu=nn.ReLU):
    """Build a stack of conv/activation/norm layers ending in a pooling layer.

    `convxd`, `norm` and `pooling` are layer classes (e.g. nn.Conv2d,
    nn.BatchNorm2d, nn.MaxPool2d); `bigblock` adds a third convolution.
    """
    if bigblock:
        block = nn.Sequential(
            convxd(in_planes, out_planes, fs, 1, fs // 2),
            relu(True),
            norm(out_planes),
            convxd(out_planes, out_planes, fs, 1, fs // 2),
            relu(True),
            norm(out_planes),
            convxd(out_planes, out_planes, fs, stride, fs // 2),
            relu(True),
            norm(out_planes),
            pooling(2, 2)
        )
    else:
        block = nn.Sequential(
            convxd(in_planes, out_planes, fs, 1, fs // 2),
            relu(True),
            norm(out_planes),
            convxd(out_planes, out_planes, fs, stride, fs // 2),
            relu(True),
            norm(out_planes),
            pooling(2, 2)
        )
    return block
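

# Minimal smoke test, a sketch rather than part of the original module. The
# input shape (1, 1, 256, 256) is an assumption chosen to mimic a
# single-channel image batch, not a size mandated by AVRA.
if __name__ == "__main__":
    import torch

    net = ResidualAttentionNet(z=1)
    net.eval()  # use running BatchNorm statistics for a deterministic pass
    with torch.no_grad():
        dummy = torch.randn(1, 1, 256, 256)  # (batch, channels, height, width)
        features = net(dummy)
    print(features.shape)  # matches the feature count noted above: [1, 4608]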