File size: 4,271 Bytes
96138e1
cca733a
96138e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70

# DANet (Dual Attention Network) for semantic segmentation: a dilated ResNet
# backbone followed by parallel position- and channel-attention branches.
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import models

class PositionAttentionModule(nn.Module):
    """Spatial (position) self-attention block used by the DANet head.

    Every spatial location attends to every other location of the same
    feature map. Queries and keys are 1x1 projections down to ``in_dim // 8``
    channels; values are a 1x1 projection that keeps ``in_dim`` channels.
    The attended features are scaled by the learnable scalar ``gamma``
    (initialised to 0, so the block starts as an identity) and added to the
    input.

    Args:
        in_dim: number of channels of the input feature map.
    """

    def __init__(self, in_dim):
        super().__init__()
        # Attribute name kept as-is (including spelling) for external readers.
        self.chanel_in = in_dim
        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply position attention to a (B, C, H, W) tensor; output has the same shape."""
        batch, channels, height, width = x.size()
        n_pos = height * width

        queries = self.query_conv(x).view(batch, -1, n_pos)  # (B, C', N)
        keys = self.key_conv(x).view(batch, -1, n_pos)  # (B, C', N)
        # affinity[b, i, j]: similarity of position i (query) to position j (key).
        affinity = torch.bmm(queries.transpose(1, 2), keys)  # (B, N, N)
        weights = self.softmax(affinity)  # normalised over key positions j

        values = self.value_conv(x).view(batch, -1, n_pos)  # (B, C, N)
        # Each output position i is the weights-averaged mix of all value columns j.
        attended = torch.bmm(values, weights.transpose(1, 2))
        attended = attended.view(batch, channels, height, width)
        return self.gamma * attended + x

class ChannelAttentionModule(nn.Module):
    """Channel self-attention block used by the DANet head.

    Builds a (C x C) channel-affinity matrix from the Gram matrix of the
    flattened feature map, softmax-normalises each row, and uses it to mix
    channels. The result is scaled by the learnable scalar ``beta``
    (initialised to 0) and added to the input. Has no convolutions.
    """

    def __init__(self):
        super().__init__()
        self.beta = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply channel attention to a (B, C, H, W) tensor; output has the same shape."""
        b, c, h, w = x.size()
        flat = x.view(b, c, -1)  # (B, C, N)
        gram = torch.bmm(flat, flat.permute(0, 2, 1))  # (B, C, C)
        # Each output channel is a softmax-weighted combination of all channels.
        mixed = torch.bmm(self.softmax(gram), flat).view(b, c, h, w)
        return self.beta * mixed + x

class DANetHead(nn.Module):
    """Dual-attention segmentation head.

    Two parallel branches each reduce ``in_channels`` to ``in_channels // 4``
    with a 3x3 conv-BN-ReLU. One branch applies position attention and the
    other applies channel attention. Each branch is refined by another 3x3
    conv-BN-ReLU, and the two results are summed. Dropout plus a 1x1 conv then
    maps the fused features to ``out_channels``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the output map (the number of classes when
            used by DANet).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        mid = in_channels // 4

        def conv_bn_relu(cin):
            # 3x3 conv without bias (BatchNorm provides the shift) -> BN -> ReLU.
            return nn.Sequential(
                nn.Conv2d(cin, mid, 3, padding=1, bias=False),
                nn.BatchNorm2d(mid),
                nn.ReLU(),
            )

        # Construction order is deliberate: it fixes parameter registration
        # order (state_dict / optimizer state) and RNG use under a fixed seed.
        self.conv5a = conv_bn_relu(in_channels)
        self.conv5c = conv_bn_relu(in_channels)
        self.sa = PositionAttentionModule(mid)
        self.sc = ChannelAttentionModule()
        self.conv51 = conv_bn_relu(mid)
        self.conv52 = conv_bn_relu(mid)
        self.conv8 = nn.Sequential(
            nn.Dropout2d(0.1, False),
            nn.Conv2d(mid, out_channels, 1),
        )

    def forward(self, x):
        """Return the fused dual-attention prediction (B, out_channels, H, W)."""
        position_branch = self.conv51(self.sa(self.conv5a(x)))
        channel_branch = self.conv52(self.sc(self.conv5c(x)))
        return self.conv8(position_branch + channel_branch)

class DANet(nn.Module):
    """DANet semantic-segmentation model: dilated ResNet backbone + dual-attention head.

    The backbone's ``layer3`` and ``layer4`` have their stride-2 downsampling
    removed, and every 3x3 conv in ``layer4`` is dilated by 2, so the head sees
    features at 1/8 of the input resolution. The head's logits are bilinearly
    upsampled back to the input size.

    Args:
        num_classes: number of output channels of the segmentation map.
        backbone: one of ``'resnet50'``, ``'resnet101'``, ``'resnet152'``.
            All three are bottleneck ResNets with 2048 output channels, so the
            same surgery and head apply to each.
        pretrained_base: if True, initialise the backbone with torchvision's
            default pretrained weights (downloaded on first use). If False the
            backbone is randomly initialised.
        aux: stored as ``self.aux``; no auxiliary head is built here.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """

    # Supported backbone builder name -> torchvision weights-enum name.
    _BACKBONE_WEIGHTS = {
        'resnet50': 'ResNet50_Weights',
        'resnet101': 'ResNet101_Weights',
        'resnet152': 'ResNet152_Weights',
    }

    def __init__(self, num_classes=2, backbone='resnet50', pretrained_base=False, aux=False):
        super(DANet, self).__init__()
        self.aux = aux
        if backbone not in self._BACKBONE_WEIGHTS:
            raise ValueError(
                f"Unsupported backbone {backbone!r}; expected one of "
                f"{sorted(self._BACKBONE_WEIGHTS)}"
            )
        # Previously both `backbone` and `pretrained_base` were silently
        # ignored (always a randomly initialised resnet50).
        weights = getattr(models, self._BACKBONE_WEIGHTS[backbone]).DEFAULT if pretrained_base else None
        resnet = getattr(models, backbone)(weights=weights)
        backbone_out_channels = 2048

        # Output stride 8: drop the stride-2 of layer3/layer4 in both the first
        # block's 3x3 conv and its shortcut projection.
        for layer in (resnet.layer3, resnet.layer4):
            layer[0].conv2.stride = (1, 1)
            layer[0].downsample[0].stride = (1, 1)
        # Dilate layer4's 3x3 convs to widen the receptive field without
        # downsampling; padding == dilation keeps a 3x3 conv size-preserving.
        for block in resnet.layer4:
            block.conv2.dilation = (2, 2)
            block.conv2.padding = (2, 2)

        # Register only the backbone pieces used in forward (avgpool/fc dropped).
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.head = DANetHead(backbone_out_channels, num_classes)

    def forward(self, x):
        """Return per-pixel logits of shape (B, num_classes, H, W) for an image batch x of shape (B, C, H, W)."""
        imsize = x.size()[2:]
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        c3 = self.layer3(x)
        c4 = self.layer4(c3)
        main_out = self.head(c4)
        main_out = F.interpolate(main_out, size=imsize, mode='bilinear', align_corners=True)
        return main_out