File size: 4,345 Bytes
beb7843
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch.nn as nn


def get_activation(act_type=None):
    """Build the activation layer selected by ``act_type``.

    Args:
        act_type: Name of the activation. Supported names are ``'relu'``,
            ``'lrelu'`` (LeakyReLU with negative slope 0.1), ``'mish'`` and
            ``'silu'``. ``None`` or an empty string means "no activation".

    Returns:
        A freshly constructed in-place activation module, or ``None`` when
        ``act_type`` is ``None``/empty.

    Raises:
        ValueError: If ``act_type`` is a non-empty value that is not one of
            the supported names.
    """
    # Falsy selector (None, '') keeps the historical "no activation" result.
    if not act_type:
        return None

    if act_type == 'relu':
        return nn.ReLU(inplace=True)
    if act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    if act_type == 'mish':
        return nn.Mish(inplace=True)
    if act_type == 'silu':
        return nn.SiLU(inplace=True)

    # Fail at construction time: a silently returned None would otherwise be
    # stored in an nn.Sequential and only blow up during the forward pass.
    raise ValueError(
        "Unsupported activation type: {!r} "
        "(expected one of 'relu', 'lrelu', 'mish', 'silu').".format(act_type)
    )


# 2D Conv
def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
    """Create an ``nn.Conv2d`` from the short-hand arguments used in this file.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size.
        p: Padding.
        s: Stride.
        d: Dilation.
        g: Number of groups.
        bias: Whether the convolution has a learnable bias.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    return nn.Conv2d(
        in_channels=c1,
        out_channels=c2,
        kernel_size=k,
        stride=s,
        padding=p,
        dilation=d,
        groups=g,
        bias=bias,
    )


def get_norm2d(norm_type, dim):
    """Build the 2D normalization layer selected by ``norm_type``.

    Args:
        norm_type: ``'BN'`` for ``nn.BatchNorm2d`` or ``'IN'`` for
            ``nn.InstanceNorm2d``. ``None`` or an empty string means
            "no normalization".
        dim: Number of channels (features) the layer normalizes.

    Returns:
        The normalization module, or ``None`` when ``norm_type`` is
        ``None``/empty.

    Raises:
        ValueError: If ``norm_type`` is a non-empty value other than
            ``'BN'`` or ``'IN'``.
    """
    # Falsy selector keeps the historical "no normalization" result.
    if not norm_type:
        return None

    if norm_type == 'BN':
        return nn.BatchNorm2d(dim)
    if norm_type == 'IN':
        return nn.InstanceNorm2d(dim)

    # A silent None here would end up inside an nn.Sequential and only fail
    # at forward time, so reject unknown names immediately.
    raise ValueError(
        "Unsupported 2D normalization type: {!r} "
        "(expected 'BN' or 'IN').".format(norm_type)
    )


class Conv2d(nn.Module):
    """2D convolution followed by optional normalization and activation.

    The layers are stacked in an ``nn.Sequential`` stored as ``self.convs``.

    * Regular mode (``depthwise=False``):
      ``conv(c1 -> c2) -> [norm] -> [activation]``.
    * Depthwise-separable mode (``depthwise=True``):
      ``depthwise conv(c1 -> c1, groups=c1) -> [norm] -> [activation] ->
      pointwise 1x1 conv(c1 -> c2) -> [norm] -> [activation]``.
      The ``g`` argument is ignored in this mode. ``c1`` and ``c2`` may
      differ; when they are equal the layer stack is the same as before.

    Convolutions carry a bias only when no normalization is requested.
    """

    def __init__(self,
                 c1,                   # in channels
                 c2,                   # out channels
                 k=1,                  # kernel size
                 p=0,                  # padding
                 s=1,                  # stride
                 d=1,                  # dilation
                 g=1,                  # groups (ignored when depthwise=True)
                 act_type='',          # activation name; falsy -> none
                 norm_type='',         # normalization name; falsy -> none
                 depthwise=False):     # use depthwise + pointwise convs
        super().__init__()
        layers = []
        # A bias is only added when no normalization layer is requested.
        add_bias = not norm_type

        def append_norm_act(dim):
            # Optional normalization then optional activation, in that order.
            if norm_type:
                layers.append(get_norm2d(norm_type, dim))
            if act_type:
                layers.append(get_activation(act_type))

        if depthwise:
            # Depthwise conv: one filter group per input channel, so the
            # channel count stays at c1 (previously enforced via an assert
            # on c1 == c2, which is stripped under ``python -O``).
            layers.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
            append_norm_act(c1)
            # Pointwise 1x1 conv mixes channels and maps c1 -> c2.
            # d is passed through unchanged; dilation has no effect on a 1x1 kernel.
            layers.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
            append_norm_act(c2)
        else:
            layers.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=g, bias=add_bias))
            append_norm_act(c2)

        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers in ``self.convs`` to ``x`` and return the result."""
        return self.convs(x)


# 3D Conv
def get_conv3d(c1, c2, k, p, s, d, g, bias=False):
    """Create an ``nn.Conv3d`` from the short-hand arguments used in this file.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size.
        p: Padding.
        s: Stride.
        d: Dilation.
        g: Number of groups.
        bias: Whether the convolution has a learnable bias.

    Returns:
        The configured ``nn.Conv3d`` module.
    """
    return nn.Conv3d(
        in_channels=c1,
        out_channels=c2,
        kernel_size=k,
        stride=s,
        padding=p,
        dilation=d,
        groups=g,
        bias=bias,
    )


def get_norm3d(norm_type, dim):
    """Build the 3D normalization layer selected by ``norm_type``.

    Args:
        norm_type: ``'BN'`` for ``nn.BatchNorm3d`` or ``'IN'`` for
            ``nn.InstanceNorm3d``. ``None`` or an empty string means
            "no normalization".
        dim: Number of channels (features) the layer normalizes.

    Returns:
        The normalization module, or ``None`` when ``norm_type`` is
        ``None``/empty.

    Raises:
        ValueError: If ``norm_type`` is a non-empty value other than
            ``'BN'`` or ``'IN'``.
    """
    # Falsy selector keeps the historical "no normalization" result.
    if not norm_type:
        return None

    if norm_type == 'BN':
        return nn.BatchNorm3d(dim)
    if norm_type == 'IN':
        return nn.InstanceNorm3d(dim)

    # A silent None here would end up inside an nn.Sequential and only fail
    # at forward time, so reject unknown names immediately.
    raise ValueError(
        "Unsupported 3D normalization type: {!r} "
        "(expected 'BN' or 'IN').".format(norm_type)
    )


class Conv3d(nn.Module):
    """3D convolution followed by optional normalization and activation.

    The layers are stacked in an ``nn.Sequential`` stored as ``self.convs``.

    * Regular mode (``depthwise=False``):
      ``conv(c1 -> c2) -> [norm] -> [activation]``.
    * Depthwise-separable mode (``depthwise=True``):
      ``depthwise conv(c1 -> c1, groups=c1) -> [norm] -> [activation] ->
      pointwise 1x1x1 conv(c1 -> c2) -> [norm] -> [activation]``.
      The ``g`` argument is ignored in this mode. ``c1`` and ``c2`` may
      differ; when they are equal the layer stack is the same as before.

    Convolutions carry a bias only when no normalization is requested.
    """

    def __init__(self,
                 c1,                   # in channels
                 c2,                   # out channels
                 k=1,                  # kernel size
                 p=0,                  # padding
                 s=1,                  # stride
                 d=1,                  # dilation
                 g=1,                  # groups (ignored when depthwise=True)
                 act_type='',          # activation name; falsy -> none
                 norm_type='',         # normalization name; falsy -> none
                 depthwise=False):     # use depthwise + pointwise convs
        super().__init__()
        layers = []
        # A bias is only added when no normalization layer is requested.
        add_bias = not norm_type

        def append_norm_act(dim):
            # Optional normalization then optional activation, in that order.
            if norm_type:
                layers.append(get_norm3d(norm_type, dim))
            if act_type:
                layers.append(get_activation(act_type))

        if depthwise:
            # Depthwise conv: one filter group per input channel, so the
            # channel count stays at c1 (previously enforced via an assert
            # on c1 == c2, which is stripped under ``python -O``).
            layers.append(get_conv3d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
            append_norm_act(c1)
            # Pointwise 1x1x1 conv mixes channels and maps c1 -> c2.
            # d is passed through unchanged; dilation has no effect on a size-1 kernel.
            layers.append(get_conv3d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
            append_norm_act(c2)
        else:
            layers.append(get_conv3d(c1, c2, k=k, p=p, s=s, d=d, g=g, bias=add_bias))
            append_norm_act(c2)

        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers in ``self.convs`` to ``x`` and return the result."""
        return self.convs(x)