|
|
""" |
|
|
This EfficientNet implementation comes from: |
|
|
Author: lukemelas (github username) |
|
|
Github repo: https://github.com/lukemelas/EfficientNet-PyTorch |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from .model import EfficientNet |
|
|
from .utils import round_filters, get_same_padding_conv2d |
|
|
|
|
|
|
|
|
|
|
|
class EfficientBackbone(EfficientNet): |
|
|
def __init__(self, blocks_args=None, global_params=None): |
|
|
super(EfficientBackbone, self).__init__(blocks_args, global_params) |
|
|
|
|
|
self.enc_channels = [16, 24, 40, 112, 1280] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
del self._conv_stem |
|
|
del self._bn0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bn_mom = 1 - self._global_params.batch_norm_momentum |
|
|
bn_eps = self._global_params.batch_norm_epsilon |
|
|
|
|
|
in_channels = 4 |
|
|
out_channels = round_filters(32, self._global_params) |
|
|
out_channels = int(out_channels / 2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
image_size = global_params.image_size |
|
|
Conv2d = get_same_padding_conv2d(image_size=image_size) |
|
|
self._conv_fg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) |
|
|
self._bn_fg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) |
|
|
|
|
|
self._conv_bg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) |
|
|
self._bn_bg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) |
|
|
|
|
|
|
|
|
def forward(self, xfg, xbg): |
|
|
xfg = self._swish(self._bn_fg(self._conv_fg(xfg))) |
|
|
xbg = self._swish(self._bn_bg(self._conv_bg(xbg))) |
|
|
|
|
|
x = torch.cat((xfg, xbg), dim=1) |
|
|
|
|
|
block_outputs = [] |
|
|
for idx, block in enumerate(self._blocks): |
|
|
drop_connect_rate = self._global_params.drop_connect_rate |
|
|
drop_connect_rate *= float(idx) / len(self._blocks) |
|
|
x = block(x, drop_connect_rate=drop_connect_rate) |
|
|
block_outputs.append(x) |
|
|
|
|
|
|
|
|
x = self._swish(self._bn1(self._conv_head(x))) |
|
|
|
|
|
return block_outputs[0], block_outputs[2], block_outputs[4], block_outputs[10], x |
|
|
|
|
|
|
|
|
|
|
|
class EfficientBackboneCommon(EfficientNet): |
|
|
def __init__(self, blocks_args=None, global_params=None): |
|
|
super(EfficientBackboneCommon, self).__init__(blocks_args, global_params) |
|
|
|
|
|
self.enc_channels = [16, 24, 40, 112, 1280] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
del self._conv_stem |
|
|
del self._bn0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bn_mom = 1 - self._global_params.batch_norm_momentum |
|
|
bn_eps = self._global_params.batch_norm_epsilon |
|
|
|
|
|
in_channels = 3 |
|
|
out_channels = round_filters(32, self._global_params) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
image_size = global_params.image_size |
|
|
Conv2d = get_same_padding_conv2d(image_size=image_size) |
|
|
self._conv = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) |
|
|
self._bn = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
x = self._swish(self._bn(self._conv(x))) |
|
|
|
|
|
block_outputs = [] |
|
|
for idx, block in enumerate(self._blocks): |
|
|
drop_connect_rate = self._global_params.drop_connect_rate |
|
|
drop_connect_rate *= float(idx) / len(self._blocks) |
|
|
x = block(x, drop_connect_rate=drop_connect_rate) |
|
|
block_outputs.append(x) |
|
|
|
|
|
|
|
|
x = self._swish(self._bn1(self._conv_head(x))) |
|
|
|
|
|
return block_outputs[0], block_outputs[2], block_outputs[4], block_outputs[10], x |
|
|
|