File size: 4,873 Bytes
4c62147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
This EfficientNet implementation comes from:
    Author: lukemelas (github username)
    Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
"""

import torch
import torch.nn as nn

from .model import EfficientNet
from .utils import round_filters, get_same_padding_conv2d


# for EfficientNet
class EfficientBackbone(EfficientNet):
    """EfficientNet backbone with twin 4-channel input stems.

    Replaces the stock single 3-channel stem with two parallel stems
    (foreground and background). Each stem emits half the original stem
    width, and their outputs are concatenated channel-wise so the combined
    tensor matches the channel count the MBConv blocks expect.
    """

    def __init__(self, blocks_args=None, global_params=None):
        """Build the backbone.

        Args:
            blocks_args: per-block arguments forwarded to ``EfficientNet``.
            global_params: global network parameters forwarded to
                ``EfficientNet``; also supplies batch-norm settings and
                ``image_size`` for the replacement stems.
        """
        super().__init__(blocks_args, global_params)

        # Channel widths of the five feature maps returned by forward().
        self.enc_channels = [16, 24, 40, 112, 1280]

        # ------------------------------------------------------------
        # delete the stock stem layers; they are replaced by the twin
        # foreground/background stems defined below
        # ------------------------------------------------------------
        del self._conv_stem
        del self._bn0
        # ------------------------------------------------------------

        # ------------------------------------------------------------
        # parameters for the input layers
        # (PyTorch BN momentum = 1 - TF-style momentum)
        # ------------------------------------------------------------
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Each stem consumes a 4-channel input (presumably image + an extra
        # mask/alpha channel — TODO confirm against the caller) and emits
        # half the stem width so that concatenating both stems restores the
        # full round_filters(32) channels.
        in_channels = 4
        out_channels = round_filters(32, self._global_params) // 2
        # ------------------------------------------------------------

        # ------------------------------------------------------------
        # define the input layers
        # NOTE: read image_size from self._global_params (set by
        # super().__init__) for consistency with the other parameters
        # read above, rather than from the raw argument.
        # ------------------------------------------------------------
        image_size = self._global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_fg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn_fg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        self._conv_bg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn_bg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        # ------------------------------------------------------------

    def forward(self, xfg, xbg):
        """Encode a foreground/background pair.

        Args:
            xfg: 4-channel foreground input tensor.
            xbg: 4-channel background input tensor.

        Returns:
            Tuple of five feature maps: the outputs of blocks 0, 2, 4 and 10
            (indices assume the efficientnet-b0 block layout — TODO confirm
            for other variants) plus the head output.
        """
        xfg = self._swish(self._bn_fg(self._conv_fg(xfg)))
        xbg = self._swish(self._bn_bg(self._conv_bg(xbg)))

        # Fuse the two stems channel-wise before the shared MBConv trunk.
        x = torch.cat((xfg, xbg), dim=1)

        # Hoist loop invariants; drop-connect scales linearly with depth.
        base_rate = self._global_params.drop_connect_rate
        n_blocks = len(self._blocks)

        block_outputs = []
        for idx, block in enumerate(self._blocks):
            x = block(x, drop_connect_rate=base_rate * float(idx) / n_blocks)
            block_outputs.append(x)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return block_outputs[0], block_outputs[2], block_outputs[4], block_outputs[10], x


# for EfficientNet
class EfficientBackboneCommon(EfficientNet):
    """EfficientNet backbone with a standard single 3-channel input stem.

    Rebuilds the stem explicitly (deleting the parent's) so its convolution
    padding is derived from ``image_size``, and exposes five intermediate
    feature maps from forward() for use as an encoder.
    """

    def __init__(self, blocks_args=None, global_params=None):
        """Build the backbone.

        Args:
            blocks_args: per-block arguments forwarded to ``EfficientNet``.
            global_params: global network parameters forwarded to
                ``EfficientNet``; also supplies batch-norm settings and
                ``image_size`` for the replacement stem.
        """
        super().__init__(blocks_args, global_params)

        # Channel widths of the five feature maps returned by forward().
        self.enc_channels = [16, 24, 40, 112, 1280]

        # ------------------------------------------------------------
        # delete the stock stem layers; they are rebuilt below
        # ------------------------------------------------------------
        del self._conv_stem
        del self._bn0
        # ------------------------------------------------------------

        # ------------------------------------------------------------
        # parameters for the input layers
        # (PyTorch BN momentum = 1 - TF-style momentum)
        # ------------------------------------------------------------
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        in_channels = 3
        out_channels = round_filters(32, self._global_params)
        # ------------------------------------------------------------

        # ------------------------------------------------------------
        # define the input layers
        # NOTE: read image_size from self._global_params (set by
        # super().__init__) for consistency with the other parameters
        # read above, rather than from the raw argument.
        # ------------------------------------------------------------
        image_size = self._global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        # ------------------------------------------------------------

    def forward(self, x):
        """Encode a 3-channel image.

        Args:
            x: 3-channel input tensor.

        Returns:
            Tuple of five feature maps: the outputs of blocks 0, 2, 4 and 10
            (indices assume the efficientnet-b0 block layout — TODO confirm
            for other variants) plus the head output.
        """
        x = self._swish(self._bn(self._conv(x)))

        # Hoist loop invariants; drop-connect scales linearly with depth.
        base_rate = self._global_params.drop_connect_rate
        n_blocks = len(self._blocks)

        block_outputs = []
        for idx, block in enumerate(self._blocks):
            x = block(x, drop_connect_rate=base_rate * float(idx) / n_blocks)
            block_outputs.append(x)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return block_outputs[0], block_outputs[2], block_outputs[4], block_outputs[10], x