|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import functional as F |
|
|
from torchvision import models |
|
|
|
|
|
class PositionAttentionModule(nn.Module):
    """Spatial (position) self-attention block from DANet.

    Every spatial location attends to every other location; the attended
    features are added back to the input through a learnable scale ``gamma``
    that starts at 0, so the module behaves as the identity at initialization.
    """

    def __init__(self, in_dim):
        super(PositionAttentionModule, self).__init__()
        # (sic) original attribute name kept for compatibility with callers.
        self.chanel_in = in_dim
        reduced = in_dim // 8
        # Query/key are channel-reduced 1x1 projections; value keeps full width.
        self.query_conv = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        # Residual scale, zero-initialized: output == input until training moves it.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply position attention to ``x`` (B, C, H, W) -> same shape."""
        batch, channels, height, width = x.size()
        n = height * width
        # B x N x C' — one row per spatial position.
        q = self.query_conv(x).view(batch, -1, n).permute(0, 2, 1)
        # B x C' x N
        k = self.key_conv(x).view(batch, -1, n)
        # B x N x N affinity between all pairs of positions, softmax over keys.
        attn = self.softmax(torch.bmm(q, k))
        # B x C x N
        v = self.value_conv(x).view(batch, -1, n)
        # Aggregate values by attention, restore spatial layout.
        fused = torch.bmm(v, attn.permute(0, 2, 1)).view(batch, channels, height, width)
        return self.gamma * fused + x
|
|
|
|
|
class ChannelAttentionModule(nn.Module):
    """Channel self-attention block from DANet.

    Each channel attends to every other channel via a C x C affinity matrix;
    the result is blended back through a learnable scale ``beta`` that starts
    at 0, so the module is the identity at initialization. No projections are
    learned — the affinities come directly from the input features.
    """

    def __init__(self):
        super(ChannelAttentionModule, self).__init__()
        # Residual scale, zero-initialized.
        self.beta = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply channel attention to ``x`` (B, C, H, W) -> same shape."""
        batch, channels, height, width = x.size()
        # B x C x N — each channel flattened over space.
        flat = x.view(batch, channels, -1)
        # B x C x C channel affinity, softmax over the last dim.
        # NOTE(review): the reference DANet subtracts the per-row max from the
        # energy before softmax; this variant softmaxes raw energies — confirm
        # this is intentional.
        attn = self.softmax(torch.bmm(flat, flat.permute(0, 2, 1)))
        # Re-mix channels by attention and restore spatial layout.
        fused = torch.bmm(attn, flat).view(batch, channels, height, width)
        return self.beta * fused + x
|
|
|
|
|
class DANetHead(nn.Module):
    """DANet segmentation head.

    Runs two parallel branches — position attention and channel attention —
    each wrapped in 3x3 conv-BN-ReLU stages, sums them, and projects to
    ``out_channels`` class logits with dropout + a 1x1 conv.
    """

    def __init__(self, in_channels, out_channels):
        super(DANetHead, self).__init__()
        inter_channels = in_channels // 4

        def _conv_bn_relu(cin, cout):
            # 3x3 conv -> BN -> ReLU; conv bias omitted since BN follows.
            return nn.Sequential(
                nn.Conv2d(cin, cout, 3, padding=1, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(),
            )

        # Input reduction for each branch.
        self.conv5a = _conv_bn_relu(in_channels, inter_channels)
        self.conv5c = _conv_bn_relu(in_channels, inter_channels)
        # Attention modules: spatial (sa) and channel (sc).
        self.sa = PositionAttentionModule(inter_channels)
        self.sc = ChannelAttentionModule()
        # Post-attention refinement for each branch.
        self.conv51 = _conv_bn_relu(inter_channels, inter_channels)
        self.conv52 = _conv_bn_relu(inter_channels, inter_channels)
        # Classifier: light dropout then a 1x1 projection to logits.
        self.conv8 = nn.Sequential(
            nn.Dropout2d(0.1, False),
            nn.Conv2d(inter_channels, out_channels, 1),
        )

    def forward(self, x):
        """Fuse both attention branches and return (B, out_channels, H, W) logits."""
        position_branch = self.conv51(self.sa(self.conv5a(x)))
        channel_branch = self.conv52(self.sc(self.conv5c(x)))
        return self.conv8(position_branch + channel_branch)
|
|
|
|
|
class DANet(nn.Module):
    """Dual Attention Network with a dilated ResNet-50 backbone.

    The backbone's layer3/layer4 downsampling strides are removed so features
    are kept at 1/8 of the input resolution; a DANetHead produces class logits
    that are bilinearly upsampled back to the input size.

    Args:
        num_classes: number of output channels (segmentation classes).
        backbone: accepted for interface compatibility; only 'resnet50' is
            implemented and the value is not consulted.
        pretrained_base: load ImageNet weights into the ResNet-50 backbone.
        aux: stored but unused — no auxiliary head is built.
    """

    def __init__(self, num_classes=2, backbone='resnet50', pretrained_base=False, aux=False):
        super(DANet, self).__init__()
        self.aux = aux  # kept for interface compatibility; no aux branch exists

        # BUG FIX: `pretrained_base` was previously ignored (weights were
        # unconditionally None); it now actually selects ImageNet weights.
        weights = models.ResNet50_Weights.DEFAULT if pretrained_base else None
        resnet = models.resnet50(weights=weights)
        backbone_out_channels = 2048

        # Drop the stride-2 downsampling in the first block of layer3/layer4
        # so the final feature map keeps 1/8 of the input resolution.
        resnet.layer3[0].conv2.stride = (1, 1)
        resnet.layer3[0].downsample[0].stride = (1, 1)
        resnet.layer4[0].conv2.stride = (1, 1)
        resnet.layer4[0].downsample[0].stride = (1, 1)
        # Dilate every layer4 3x3 conv to recover receptive field lost above.
        # NOTE(review): layer3 is left undilated (the reference DANet dilates
        # it with rate 2) — confirm this is intentional.
        for block in resnet.layer4:
            block.conv2.dilation = (2, 2)
            block.conv2.padding = (2, 2)

        # Re-expose the backbone stages as submodules of this model.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        self.head = DANetHead(backbone_out_channels, num_classes)

    def forward(self, x):
        """Return per-pixel class logits of shape (B, num_classes, H, W)."""
        imsize = x.size()[2:]
        # ResNet stem: conv -> BN -> ReLU -> maxpool.
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer2(self.layer1(x))
        c4 = self.layer4(self.layer3(x))
        out = self.head(c4)
        # Upsample logits back to the input spatial size.
        return F.interpolate(out, size=imsize, mode='bilinear', align_corners=True)
|
|
|