Inmental's picture
Upload folder using huggingface_hub
4c62147 verified
"""
This EfficientNet implementation comes from:
Author: lukemelas (github username)
Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
"""
import torch
import torch.nn as nn
from .model import EfficientNet
from .utils import round_filters, get_same_padding_conv2d
# EfficientNet backbone with two input stems (xfg, xbg)
class EfficientBackbone(EfficientNet):
    """EfficientNet encoder with two input stems that returns multi-scale features.

    The stock ``_conv_stem`` / ``_bn0`` layers created by ``EfficientNet`` are
    removed and replaced by two parallel stems (``_conv_fg``/``_bn_fg`` and
    ``_conv_bg``/``_bn_bg``).  Each stem takes a 4-channel input and produces
    half of ``round_filters(32, global_params)`` channels; the two results are
    concatenated along ``dim=1`` before being fed to the inherited blocks.

    NOTE(review): ``fg`` / ``bg`` presumably stand for foreground / background
    -- confirm against the caller.

    Args:
        blocks_args: Block configuration forwarded unchanged to ``EfficientNet``.
        global_params: Global configuration forwarded unchanged to
            ``EfficientNet``; ``batch_norm_momentum``, ``batch_norm_epsilon``,
            ``image_size`` and ``drop_connect_rate`` are read from it here.
    """

    # Indices into ``self._blocks`` whose outputs ``forward`` returns.
    # NOTE(review): together with ``enc_channels`` these look tied to one
    # specific EfficientNet variant -- confirm before using another variant.
    _feature_block_ids = (0, 2, 4, 10)

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__(blocks_args, global_params)

        # Channel counts exposed for consumers of the five returned features.
        self.enc_channels = [16, 24, 40, 112, 1280]

        # Drop the parent's single-input stem; it is replaced below.
        del self._conv_stem
        del self._bn0

        # Batch-norm settings for the replacement stems.
        # presumably converts the config's momentum convention to the one
        # nn.BatchNorm2d expects -- TODO confirm
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        in_channels = 4
        # Each stem yields half of the stem width so that the concatenation
        # in ``forward`` has ``round_filters(32, ...)`` channels in total.
        out_channels = round_filters(32, self._global_params)
        out_channels = int(out_channels / 2)

        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        self._conv_fg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn_fg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        self._conv_bg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn_bg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

    def forward(self, xfg, xbg):
        """Encode two inputs and return five feature maps.

        Args:
            xfg: Input for the ``_conv_fg`` stem (4 channels).
            xbg: Input for the ``_conv_bg`` stem (4 channels).

        Returns:
            A 5-tuple: the outputs of blocks 0, 2, 4 and 10, followed by the
            output of the head (``_conv_head`` -> ``_bn1`` -> swish).

        Raises:
            IndexError: If the network has fewer than 11 blocks.
        """
        xfg = self._swish(self._bn_fg(self._conv_fg(xfg)))
        xbg = self._swish(self._bn_bg(self._conv_bg(xbg)))
        x = torch.cat((xfg, xbg), dim=1)

        base_rate = self._global_params.drop_connect_rate
        num_blocks = len(self._blocks)

        block_outputs = []
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = base_rate
            # Scale the rate linearly with block depth.  Guard against a
            # disabled rate (None or 0): multiplying None would raise TypeError.
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / num_blocks
            x = block(x, drop_connect_rate=drop_connect_rate)
            block_outputs.append(x)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return tuple(block_outputs[i] for i in self._feature_block_ids) + (x,)
# EfficientNet backbone with a single 3-channel input stem
class EfficientBackboneCommon(EfficientNet):
    """EfficientNet encoder with a single input stem that returns multi-scale features.

    The stock ``_conv_stem`` / ``_bn0`` layers created by ``EfficientNet`` are
    removed and replaced by ``_conv`` / ``_bn``, a stride-2 3x3 convolution on
    a 3-channel input followed by batch normalisation.

    Args:
        blocks_args: Block configuration forwarded unchanged to ``EfficientNet``.
        global_params: Global configuration forwarded unchanged to
            ``EfficientNet``; ``batch_norm_momentum``, ``batch_norm_epsilon``,
            ``image_size`` and ``drop_connect_rate`` are read from it here.
    """

    # Indices into ``self._blocks`` whose outputs ``forward`` returns.
    # NOTE(review): together with ``enc_channels`` these look tied to one
    # specific EfficientNet variant -- confirm before using another variant.
    _feature_block_ids = (0, 2, 4, 10)

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__(blocks_args, global_params)

        # Channel counts exposed for consumers of the five returned features.
        self.enc_channels = [16, 24, 40, 112, 1280]

        # Drop the parent's stem; it is replaced below.
        del self._conv_stem
        del self._bn0

        # Batch-norm settings for the replacement stem.
        # presumably converts the config's momentum convention to the one
        # nn.BatchNorm2d expects -- TODO confirm
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        in_channels = 3
        out_channels = round_filters(32, self._global_params)

        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        self._conv = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

    def forward(self, x):
        """Encode one input and return five feature maps.

        Args:
            x: Input for the ``_conv`` stem (3 channels).

        Returns:
            A 5-tuple: the outputs of blocks 0, 2, 4 and 10, followed by the
            output of the head (``_conv_head`` -> ``_bn1`` -> swish).

        Raises:
            IndexError: If the network has fewer than 11 blocks.
        """
        x = self._swish(self._bn(self._conv(x)))

        base_rate = self._global_params.drop_connect_rate
        num_blocks = len(self._blocks)

        block_outputs = []
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = base_rate
            # Scale the rate linearly with block depth.  Guard against a
            # disabled rate (None or 0): multiplying None would raise TypeError.
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / num_blocks
            x = block(x, drop_connect_rate=drop_connect_rate)
            block_outputs.append(x)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return tuple(block_outputs[i] for i in self._feature_block_ids) + (x,)