File size: 3,956 Bytes
8d7921b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class MergeAndConv(nn.Module):
    """Two stacked 3x3 convolutions followed by a sigmoid.

    The pipeline is ``conv1 -> ReLU -> BatchNorm -> conv2 -> sigmoid``. Both
    convolutions use kernel size 3, stride 1 and padding 1.

    Args:
        ic: number of input channels.
        oc: number of output channels.
        inner: number of channels between the two convolutions.
    """

    def __init__(self, ic, oc, inner=32):
        super().__init__()

        # Both convolutions share the same 3x3 / stride 1 / padding 1 geometry.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)

        self.conv1 = nn.Conv2d(ic, inner, **conv_kwargs)
        self.bn = nn.BatchNorm2d(inner)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(inner, oc, **conv_kwargs)

    def forward(self, x):
        """Return the sigmoid of the two-layer convolutional stack applied to ``x``."""
        hidden = self.conv1(x)
        # The activation is applied before batch normalisation here.
        hidden = self.relu(hidden)
        hidden = self.bn(hidden)
        logits = self.conv2(hidden)
        return torch.sigmoid(logits)


class SideClassifer(nn.Module):
    """A bank of ``M`` independent convolutional heads over the same input.

    Args:
        ic: number of input channels.
        n_class: number of output channels of every head.
        M: number of heads.
        kernel_size: kernel size of every head's convolution.
    """

    def __init__(self, ic, n_class=1, M=2, kernel_size=1):
        super().__init__()

        self.sides = nn.ModuleList(
            [nn.Conv2d(ic, n_class, kernel_size=kernel_size) for _ in range(M)]
        )

    def forward(self, x):
        """Apply every head to ``x`` and return the outputs as a list, in head order."""
        return [head(x) for head in self.sides]


class UpsampleSKConv(nn.Module):
    """Upsample by a factor of 2, then apply conv -> SKSPP -> 1x1 conv.

    Args:
        ic: number of input channels.
        oc: number of output channels.
        reduce: the intermediate channel width is ``ic // reduce``.
    """

    def __init__(self, ic, oc, reduce=4):
        super().__init__()

        # Channel width used between the first convolution and the 1x1 projection.
        mid = ic // reduce

        self.relu = nn.ReLU(inplace=True)
        self.prev = nn.Conv2d(ic, mid, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(mid)

        self.next = nn.Conv2d(mid, oc, kernel_size=1, stride=1)
        self.bn2 = nn.BatchNorm2d(oc)

        self.sk = SKSPP(mid, mid, M=4)

    def forward(self, x):
        """Return the upsampled and transformed feature map for ``x``."""
        # F.interpolate is called with its default interpolation mode.
        upsampled = F.interpolate(x, scale_factor=2)

        # In both conv stages the ReLU is applied before batch normalisation.
        reduced = self.prev(upsampled)
        reduced = self.bn(self.relu(reduced))

        fused = self.sk(reduced)

        projected = self.next(fused)
        return self.bn2(self.relu(projected))


class SKSPP(nn.Module):
    """Attention-weighted fusion of an input and a cascade of dilated convolutions.

    The module builds ``M`` branches: the unchanged input, followed by the
    outputs of ``M - 1`` dilated convolution blocks applied one after another
    (each block consumes the previous block's output). Per-channel attention
    weights are computed from the spatially averaged sum of all branches,
    normalised with a softmax across the branches, and used to blend them.
    """

    def __init__(self, features, WH, M=2, G=1, r=16, stride=1, L=32):
        """ Constructor
        Args:
            features: input channel dimensionality.
            WH: input spatial dimensionality. Currently unused; kept so
                existing callers keep working.
            M: the number of branches, including the identity branch
                (must be >= 1).
            G: num of convolution groups.
            r: the ratio used to compute d, the length of z.
            stride: stride of the branch convolutions, default 1. Values
                other than 1 generally change the spatial size of the branch
                outputs, which then cannot be stacked with the input.
            L: the minimum dim of the vector z in paper, default 32.
        Raises:
            ValueError: if ``M`` is smaller than 1.
        """
        super(SKSPP, self).__init__()
        if M < 1:
            # Without at least one branch, forward() has nothing to attend over.
            raise ValueError("M must be >= 1, got {}".format(M))

        d = max(int(features / r), L)
        self.M = M
        self.features = features
        self.convs = nn.ModuleList([])

        # Branch i (i = 1 .. M-1) uses kernel size k = 2i + 1 with dilation k.
        # Its dilated extent is k * (k - 1) + 1, and padding by half of that
        # keeps the spatial size unchanged when stride == 1.
        for i in range(1, M):
            k = 1 + i * 2
            self.convs.append(nn.Sequential(
                nn.Conv2d(features, features, kernel_size=k, dilation=k, stride=stride,
                          padding=(k * (k - 1) + 1) // 2, groups=G),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=False)
            ))

        self.fc = nn.Linear(features, d)
        # One projection per branch (identity branch included).
        self.fcs = nn.ModuleList([])
        for _ in range(M):
            self.fcs.append(
                nn.Linear(d, features)
            )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Blend the input and the cascaded branch outputs with learned attention.

        Args:
            x: feature map; expected layout is (batch, features, H, W).
        Returns:
            The attention-weighted sum of all branches.
        """
        # Branch 0 is the input itself; each conv consumes the previous output.
        branches = [x]
        for conv in self.convs:
            x = conv(x)
            branches.append(x)
        # Stack once along a new branch axis instead of concatenating per step.
        feas = torch.stack(branches, dim=1)

        fea_U = torch.sum(feas, dim=1)
        # Average over the last two (spatial) dimensions.
        fea_s = fea_U.mean(-1).mean(-1)
        fea_z = self.fc(fea_s)

        attention_vectors = torch.stack([fc(fea_z) for fc in self.fcs], dim=1)
        # Softmax over dim 1, i.e. the weights of the branches compete per channel.
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        fea_v = (feas * attention_vectors).sum(dim=1)
        return fea_v