File size: 2,507 Bytes
46b9840
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn

# Module-wide default device: CUDA when torch reports it is available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Res_Net(nn.Module):
    """Residual block that applies one shared 3x3 conv + batch norm twice.

    ``forward`` computes ``conv1 -> bn1 -> LeakyReLU -> conv1 -> bn1``, adds
    the block input back onto the result and applies a final LeakyReLU.
    ``conv1`` keeps the channel count and uses padding 1 with a 3x3 kernel,
    so the output has the same shape as the input.

    NOTE(review): ``conv2``, ``conv3``, ``cbamBlock`` and ``relu1`` are
    created here but never called by ``forward``, and both stages reuse the
    same ``conv1``/``bn1`` weights.  Confirm whether that is intended before
    changing it; the attributes are left in place so the module's parameter
    names stay as they are.

    Args:
        input_cha: Number of channels of the input (and output) feature map.
    """

    def __init__(self,input_cha):
        super().__init__()
        # Channel-preserving convolutions; padding = kernel_size // 2 keeps
        # the spatial size unchanged.
        self.conv1 = nn.Conv2d(input_cha, input_cha, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(input_cha, input_cha, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(input_cha, input_cha, kernel_size=7, padding=3)

        self.cbamBlock = CBAMBlock(input_cha)

        self.bn1 = nn.BatchNorm2d(input_cha)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.LeakyReLU()

    def forward(self,x):
        """Run the residual block on ``x`` and return a tensor of equal shape."""
        shortcut = x

        # Stage 1: conv -> batch norm -> LeakyReLU.
        features = self.relu2(self.bn1(self.conv1(x)))

        # Stage 2: the same conv and batch norm again, no activation yet.
        features = self.bn1(self.conv1(features))

        # Skip connection (in-place on the intermediate), then activation.
        features += shortcut
        return self.relu2(features)

class CBAMBlock(nn.Module):
    """Channel attention followed by spatial attention (CBAM-style block).

    Channel attention: the input is globally average- and max-pooled to one
    value per channel, both descriptors go through the same two-layer
    bottleneck MLP, the results are summed and squashed with a sigmoid, and
    the input is scaled per channel.

    Spatial attention: the channel-wise mean and max of the rescaled tensor
    are stacked into a 2-channel map, passed through a 7x7 convolution and a
    sigmoid, and used to scale every spatial position.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Bottleneck ratio of the channel MLP.  The hidden width is
            ``channel // reduction`` but never less than 1.
    """

    def __init__(self, channel, reduction=16):
        super(CBAMBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # ``channel // reduction`` is 0 whenever channel < reduction.  That
        # would build zero-width Linear layers whose output is always zero,
        # i.e. a constant, untrainable channel gate of sigmoid(0) = 0.5.
        # Clamp to one hidden unit so small channel counts still get a
        # working attention MLP.
        hidden = max(1, int(channel // reduction))

        self.channel_excitation = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

        # 2 input channels = (mean, max) over the channel axis; padding 3
        # with a 7x7 kernel keeps the spatial size unchanged.
        self.spatial_excitation = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3, bias=False),
        )

    def forward(self, x):
        """Apply channel then spatial attention to a 4-D tensor ``x``.

        Returns a tensor with the same shape as ``x``.
        """
        # x is unpacked as (batch, channels, _, _); the original author's
        # example size was 16 16 24 42.
        bahs, chs, _, _ = x.size()

        # --- channel attention -------------------------------------------
        # Pooling yields (batch, channels, 1, 1); flatten to (batch, channels)
        # for the Linear layers, then restore the trailing 1x1 dims so the
        # gate broadcasts over the spatial axes.
        chn_avg = self.avg_pool(x).view(bahs, chs)
        chn_avg = self.channel_excitation(chn_avg).view(bahs, chs, 1, 1)
        chn_max = self.max_pool(x).view(bahs, chs)
        chn_max = self.channel_excitation(chn_max).view(bahs, chs, 1, 1)
        chn_add = self.sigmoid(chn_avg + chn_max)

        chn_cbam = torch.mul(x, chn_add)

        # --- spatial attention -------------------------------------------
        # Collapse the channel axis two ways and stack into a 2-channel map.
        avg_out = torch.mean(chn_cbam, dim=1, keepdim=True)
        max_out, _ = torch.max(chn_cbam, dim=1, keepdim=True)
        cat = torch.cat([avg_out, max_out], dim=1)

        spa_add = self.sigmoid(self.spatial_excitation(cat))
        spa_cbam = torch.mul(chn_cbam, spa_add)

        return spa_cbam