File size: 2,451 Bytes
5c783e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#-*- coding: utf-8 -*-
import torch
import torch.nn as nn


# Momentum value passed to the nn.BatchNorm2d layers built by
# point_wise_block and conv_block.
BN_MOMENTUM = 0.1


def point_wise_block(inplanes, outplanes):
    """Build a 1x1 convolution -> batch norm -> ReLU stack.

    Args:
        inplanes: Number of input channels of the 1x1 convolution.
        outplanes: Number of output channels of the convolution; also the
            feature count of the batch-norm layer.

    Returns:
        An ``nn.Sequential`` containing the three layers in that order.
    """
    # The convolution carries no bias; it is followed directly by batch norm.
    pointwise = nn.Conv2d(
        in_channels=inplanes,
        out_channels=outplanes,
        kernel_size=1,
        stride=1,
        padding=0,
        bias=False,
    )
    norm = nn.BatchNorm2d(outplanes, momentum=BN_MOMENTUM)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(pointwise, norm, activation)


def conv_block(inplanes, outplanes, kernel_size, stride=1, padding=0):
    """Build a convolution -> batch norm -> ReLU stack.

    Args:
        inplanes: Number of input channels of the convolution.
        outplanes: Number of output channels of the convolution; also the
            feature count of the batch-norm layer.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        stride: Stride forwarded to ``nn.Conv2d`` (default 1).
        padding: Padding forwarded to ``nn.Conv2d`` (default 0).

    Returns:
        An ``nn.Sequential`` containing the three layers in that order.
    """
    # The convolution carries no bias; it is followed directly by batch norm.
    conv = nn.Conv2d(
        in_channels=inplanes,
        out_channels=outplanes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    norm = nn.BatchNorm2d(outplanes, momentum=BN_MOMENTUM)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)


def conv3x3(in_planes, out_planes, stride=1):
    """Create a bias-free 3x3 convolution with one pixel of padding.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (default 1).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class InceptionBlock(nn.Module):
    """Four-branch block whose branch outputs are concatenated on channels.

    Branches (each producing ``outplanes // 4`` channels):
        1. pointwise block
        2. pointwise block -> 3x3 conv block
        3. pointwise block -> 5x5 conv block
        4. max pool -> pointwise block

    Args:
        inplanes: Number of channels of the input tensor.
        outplanes: Nominal number of output channels. The actual output has
            ``4 * (outplanes // 4)`` channels, so this should be a multiple
            of 4 to get exactly ``outplanes`` channels.
        stride: Stride of the max-pool branch only. The other three branches
            always run at stride 1, so with ``stride != 1`` the branch
            outputs generally differ in spatial size and the concatenation
            in ``forward`` fails.
        pool_size: Kernel size of the max-pool branch (int or tuple). The
            pool is padded by ``pool_size // 2`` so that, at stride 1, an odd
            kernel keeps the spatial size equal to the other branches.
    """

    def __init__(self, inplanes, outplanes, stride=1, pool_size=3):
        # Initialise nn.Module before touching any attribute.
        super(InceptionBlock, self).__init__()
        self.inplanes = inplanes
        self.outplanes = outplanes
        self.stride = stride
        self.pool_size = pool_size

        # Padding follows the kernel instead of being fixed at 1; for the
        # default pool_size=3 this is still 1.
        if isinstance(pool_size, int):
            pool_padding = pool_size // 2
        else:
            pool_padding = tuple(k // 2 for k in pool_size)

        branch_planes = self.outplanes // 4

        # Submodules are registered in the same order and under the same
        # names as before, so existing state dicts keep loading.
        self.pw_block = point_wise_block(self.inplanes, branch_planes)
        self.mp_layer = nn.MaxPool2d(kernel_size=self.pool_size, stride=stride, padding=pool_padding)
        self.conv3_block = conv_block(branch_planes, branch_planes, kernel_size=3, stride=1, padding=1)
        self.conv5_block = conv_block(branch_planes, branch_planes, kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate them along dim 1.

        NOTE(review): every branch goes through the same ``self.pw_block``
        instance, so the pointwise weights are shared across branches. The
        separate calls are kept as-is so that the number of batch-norm
        running-statistic updates per forward pass stays unchanged.
        """
        # Branch 1: pointwise only.
        x1 = self.pw_block(x)

        # Branch 2: pointwise followed by the 3x3 conv block.
        x2 = self.pw_block(x)
        x2 = self.conv3_block(x2)

        # Branch 3: pointwise followed by the 5x5 conv block.
        x3 = self.pw_block(x)
        x3 = self.conv5_block(x3)

        # Branch 4: max pool first, then pointwise.
        x4 = self.mp_layer(x)
        x4 = self.pw_block(x4)

        x = torch.cat((x1, x2, x3, x4), dim=1)
        return x


class SELayer(nn.Module):
    """Channel-wise gating layer (squeeze-and-excitation pattern).

    Each channel of a 4-D input is globally average-pooled to one value, the
    resulting vector goes through a two-layer bias-free bottleneck ending in
    a sigmoid, and the input is multiplied channel-wise by that gate.

    Args:
        channel: Number of channels of the input tensor.
        reduction: Divisor for the bottleneck width. The hidden width is
            ``channel // reduction``, clamped to at least 1 so that
            ``channel < reduction`` does not produce a zero-width layer
            (which would make the gate a constant 0.5).
    """

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        # Never let the bottleneck collapse to zero features.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled per channel by the learned gate.

        Args:
            x: 4-D tensor; its second dimension must equal ``channel``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        b, c, _, _ = x.size()
        # Pool each channel to a single value and flatten to (b, c).
        y = self.avg_pool(x).view(b, c)
        # Gate values in (0, 1), reshaped so they broadcast over the last two dims.
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)