File size: 3,357 Bytes
3155589
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two 3x3 conv -> BatchNorm -> ReLU stages with channel dropout in between.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both convolutions.
        dropout: drop probability of the ``Dropout2d`` between the stages.
    """

    def __init__(self, in_channels, out_channels, dropout=0.1):
        super().__init__()
        # NOTE: the conv bias is redundant in front of BatchNorm, but it is kept
        # so parameter names/shapes (and therefore checkpoints) stay the same.
        first_stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                      padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                      padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Resulting indices: 0-2 first stage, 3 dropout, 4-6 second stage.
        self.conv = nn.Sequential(*first_stage, nn.Dropout2d(dropout), *second_stage)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        out = self.conv(x)
        return out

class AttentionGate(nn.Module):
    """Additive attention gate that re-weights a skip feature map.

    Both inputs are projected to ``F_int`` channels with 1x1 convolutions,
    summed, passed through ReLU, reduced to a single channel and squashed with
    a sigmoid. The resulting per-pixel weight in (0, 1) multiplies ``x``.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the skip feature map ``x``.
        F_int: channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(cin, cout):
            return nn.Conv2d(cin, cout, kernel_size=1, stride=1, padding=0, bias=True)

        # Creation order (W_g, W_x, psi) matches the parameter init order.
        self.W_g = pointwise(F_g, F_int)
        self.W_x = pointwise(F_l, F_int)
        self.psi = nn.Sequential(pointwise(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by attention weights computed from ``g`` and ``x``.

        ``W_g(g)`` and ``W_x(x)`` are added elementwise, so ``g`` and ``x`` must
        share batch and spatial dimensions.
        """
        # The sum is a fresh tensor, so the in-place ReLU cannot clobber inputs.
        weights = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        return x * weights

class AttentionUNet(nn.Module):
    """U-Net with an attention gate on every skip connection.

    The encoder has four ``DoubleConv`` stages (64, 128, 256, 512 channels),
    each followed by 2x2 max pooling, and a 1024-channel bottleneck. Each
    decoder stage does three things: it upsamples by 2 with a transposed
    convolution, gates the matching encoder map with an ``AttentionGate``, and
    refines the concatenation ``[gated skip, upsampled]`` with a ``DoubleConv``.

    Args:
        img_ch: number of channels of the input image.
        output_ch: number of channels produced by the final 1x1 convolution.
            No activation is applied, so these are raw per-pixel scores.
    """

    def __init__(self, img_ch=1, output_ch=4):
        super().__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.downs = nn.ModuleList([
            DoubleConv(img_ch, 64),
            DoubleConv(64, 128),
            DoubleConv(128, 256),
            DoubleConv(256, 512)
        ])

        self.bottleneck = DoubleConv(512, 1024)

        # Listed deepest-first, matching the decoder order in forward().
        self.ups = nn.ModuleList([
            nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        ])

        self.attention_gates = nn.ModuleList([
            AttentionGate(F_g=512, F_l=512, F_int=256),
            AttentionGate(F_g=256, F_l=256, F_int=128),
            AttentionGate(F_g=128, F_l=128, F_int=64),
            AttentionGate(F_g=64, F_l=64, F_int=32)
        ])

        # Input channels are doubled by concatenating gated skip + upsampled map.
        self.up_convs = nn.ModuleList([
            DoubleConv(1024, 512),
            DoubleConv(512, 256),
            DoubleConv(256, 128),
            DoubleConv(128, 64)
        ])

        self.final_conv = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _match_spatial(d, skip):
        """Zero-pad (or crop) ``d`` so its last two dims equal those of ``skip``.

        Max pooling floors odd sizes, so a transposed-conv output can be one
        pixel smaller than its skip tensor; without this the attention gate's
        addition and the channel concatenation would fail.
        """
        diff_h = skip.size(-2) - d.size(-2)
        diff_w = skip.size(-1) - d.size(-1)
        if diff_h == 0 and diff_w == 0:
            return d
        # F.pad takes (left, right, top, bottom) for the last two dims;
        # negative amounts crop instead of pad.
        return F.pad(d, [diff_w // 2, diff_w - diff_w // 2,
                         diff_h // 2, diff_h - diff_h // 2])

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: input batch, assumed to be (N, img_ch, H, W). H and W no longer
                need to be multiples of 16, but each must be at least 16 so the
                fourth pooling step has something to pool.

        Returns:
            Tensor of shape (N, output_ch, H, W) with unnormalised scores.
        """
        skips = []
        for down in self.downs:
            x = down(x)
            skips.append(x)
            x = self.Maxpool(x)

        d = self.bottleneck(x)

        # Decoder walks the encoder outputs deepest-first (e4, e3, e2, e1).
        stages = zip(self.ups, self.attention_gates, self.up_convs, reversed(skips))
        for up, gate, refine, skip in stages:
            d = self._match_spatial(up(d), skip)
            gated = gate(g=d, x=skip)
            d = refine(torch.cat((gated, d), dim=1))

        return self.final_conv(d)