|
|
"""
|
|
|
Original author: lukemelas (github username)
|
|
|
Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
|
|
|
With adjustments and added comments by workingcoder (github username).
|
|
|
|
|
|
Reimplemented: Min Seok Lee and Wooseok Shin
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
from torch import nn
|
|
|
from torch.nn import functional as F
|
|
|
from util.effi_utils import (
|
|
|
get_model_shape,
|
|
|
round_filters,
|
|
|
round_repeats,
|
|
|
drop_connect,
|
|
|
get_same_padding_conv2d,
|
|
|
get_model_params,
|
|
|
efficientnet_params,
|
|
|
load_pretrained_weights,
|
|
|
Swish,
|
|
|
MemoryEfficientSwish,
|
|
|
calculate_output_image_size
|
|
|
)
|
|
|
from modules.att_modules import Frequency_Edge_Module
|
|
|
from config import getConfig
|
|
|
|
|
|
# Project-wide configuration object, created at import time; EfficientNet.__init__ reads
# cfg.frequency_radius from it to build the edge module.
cfg = getConfig()
|
|
|
|
|
|
# Model names accepted by EfficientNet._check_model_name_is_valid (used by from_name,
# from_pretrained via from_name, and get_image_size): the b0..b8 variants plus l2.
VALID_MODELS = tuple('efficientnet-b{}'.format(i) for i in range(9)) + ('efficientnet-l2',)
|
|
|
|
|
|
|
|
|
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck (MBConv) block.

    Stages, in order: optional 1x1 expansion conv (only when ``expand_ratio != 1``),
    depthwise conv, optional squeeze-and-excitation gate, 1x1 projection conv. An identity
    shortcut (optionally with drop connect) is added when ``id_skip`` is set, the stride
    equals 1 and the input/output filter counts match.

    Args:
        block_args (namedtuple): BlockArgs, defined in utils.py.
        global_params (namedtuple): GlobalParam, defined in utils.py.
        image_size (tuple or list): [image_height, image_width].

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        args = block_args
        self._block_args = args
        # BatchNorm momentum is stored as 1 - value, presumably because global_params holds a
        # TensorFlow-style momentum (TODO confirm against utils.py).
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        ratio = args.se_ratio
        self.has_se = (ratio is not None) and (0 < ratio <= 1)
        self.id_skip = args.id_skip

        in_ch = args.input_filters
        mid_ch = args.input_filters * args.expand_ratio

        # 1x1 expansion; skipped entirely when the block does not widen its input.
        if args.expand_ratio != 1:
            self._expand_conv = get_same_padding_conv2d(image_size=image_size)(
                in_channels=in_ch, out_channels=mid_ch, kernel_size=1, bias=False)
            self._bn0 = self._make_bn(mid_ch)

        # Depthwise conv: groups == channels, so each channel is filtered independently.
        stride = args.stride
        self._depthwise_conv = get_same_padding_conv2d(image_size=image_size)(
            in_channels=mid_ch, out_channels=mid_ch, groups=mid_ch,
            kernel_size=args.kernel_size, stride=stride, bias=False)
        self._bn1 = self._make_bn(mid_ch)
        image_size = calculate_output_image_size(image_size, stride)

        # Squeeze-and-excitation operates on the pooled 1x1 map, hence image_size=(1, 1).
        if self.has_se:
            se_conv = get_same_padding_conv2d(image_size=(1, 1))
            squeezed = max(1, int(args.input_filters * ratio))
            self._se_reduce = se_conv(in_channels=mid_ch, out_channels=squeezed, kernel_size=1)
            self._se_expand = se_conv(in_channels=squeezed, out_channels=mid_ch, kernel_size=1)

        # 1x1 projection down to the block's output width (no activation afterwards).
        out_ch = args.output_filters
        self._project_conv = get_same_padding_conv2d(image_size=image_size)(
            in_channels=mid_ch, out_channels=out_ch, kernel_size=1, bias=False)
        self._bn2 = self._make_bn(out_ch)
        self._swish = MemoryEfficientSwish()

    def _make_bn(self, num_features):
        """Build a BatchNorm2d with this block's momentum and epsilon."""
        return nn.BatchNorm2d(num_features=num_features, momentum=self._bn_mom, eps=self._bn_eps)

    def forward(self, inputs, drop_connect_rate=None):
        """Apply the block to ``inputs``.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float or None): Drop connect rate between 0 and 1; any falsy
                value disables drop connect. Only used when the shortcut is active.

        Returns:
            Output of this block after processing.
        """
        args = self._block_args

        out = inputs
        if args.expand_ratio != 1:
            out = self._swish(self._bn0(self._expand_conv(inputs)))
        out = self._swish(self._bn1(self._depthwise_conv(out)))

        if self.has_se:
            gate = F.adaptive_avg_pool2d(out, 1)
            gate = self._se_expand(self._swish(self._se_reduce(gate)))
            out = torch.sigmoid(gate) * out

        out = self._bn2(self._project_conv(out))

        use_shortcut = (self.id_skip and args.stride == 1
                        and args.input_filters == args.output_filters)
        if not use_shortcut:
            return out
        # Shortcut + drop connect together give stochastic depth.
        if drop_connect_rate:
            out = drop_connect(out, p=drop_connect_rate, training=self.training)
        return out + inputs

    def set_swish(self, memory_efficient=True):
        """Choose the swish implementation: memory efficient (training) or standard (export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        if memory_efficient:
            self._swish = MemoryEfficientSwish()
        else:
            self._swish = Swish()
|
|
|
|
|
|
|
|
|
class EfficientNet(nn.Module):
    """EfficientNet backbone used as this project's encoder.

    Only the stem and the MBConv blocks are built (``__init__`` creates no convolutional
    head, pooling, dropout or classifier layer), plus a ``Frequency_Edge_Module`` that
    ``get_blocks`` applies right after block ``block_idx[0]``.

    Args:
        blocks_args (list[namedtuple]): Per-stage BlockArgs, defined in utils.py.
        global_params (namedtuple): GlobalParams, defined in utils.py.
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args

        # block_idx: indices of the blocks whose outputs get_blocks returns. channels[0] is the
        # channel count given to the edge module (presumably channels[i] matches block_idx[i]).
        self.block_idx, self.channels = get_model_shape()
        self.Frequency_Edge_Module1 = Frequency_Edge_Module(radius=cfg.frequency_radius,
                                                            channel=self.channels[0])

        # BatchNorm parameters. Momentum is stored as 1 - value, presumably because
        # global_params holds a TensorFlow-style momentum (TODO confirm against utils.py).
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Convolution class selected for the current input image size.
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem: 3x3 stride-2 convolution from a 3-channel input
        # (_change_in_channels rebuilds it for other channel counts).
        in_channels = 3
        out_channels = round_filters(32, self._global_params)
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        image_size = calculate_output_image_size(image_size, 2)

        # MBConv blocks. Filter and repeat counts are rounded via round_filters/round_repeats
        # against global_params. Only the first block of a stage uses the stage stride and
        # input width; its repeats run at stride 1 with input width == output width.
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )
            self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))

        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Applies to the stem activation and to every MBConv block.

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def _drop_connect_rate(self, idx):
        """Return the drop-connect rate for block ``idx``, scaled by ``idx / len(self._blocks)``.

        A falsy global rate (e.g. None or 0) is returned unchanged.
        """
        rate = self._global_params.drop_connect_rate
        if rate:
            rate *= float(idx) / len(self._blocks)
        return rate

    def extract_endpoints(self, inputs):
        """Run the stem and all blocks, collecting features just before each size reduction.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            dict: Keys ``'reduction_1'``, ``'reduction_2'``, ... in insertion order. Every
            entry except the last is the feature map immediately preceding a decrease of
            ``size(2)`` (height, assuming NCHW input); the last entry is the output of the
            final block.
        """
        endpoints = {}
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x
        for idx, block in enumerate(self._blocks):
            x = block(x, drop_connect_rate=self._drop_connect_rate(idx))
            if prev_x.size(2) > x.size(2):
                endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
            prev_x = x

        # __init__ builds no convolutional head (_conv_head/_bn1 do not exist), so the final
        # endpoint is the last block's output rather than a head projection.
        endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
        return endpoints

    def initial_conv(self, inputs):
        """Apply the stem: conv -> BatchNorm -> swish."""
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        return x

    def get_blocks(self, x, H, W):
        """Run all MBConv blocks on the stem output and collect four intermediate features.

        After block ``block_idx[0]`` the edge module is applied. It replaces ``x`` and yields
        an edge map, which is bilinearly resized to ``(H, W)``.

        Args:
            x (tensor): Output of :meth:`initial_conv`.
            H (int): Target height of the returned edge map.
            W (int): Target width of the returned edge map.

        Returns:
            tuple: ``((x1, x2, x3, x4), edge)``, where ``x{i}`` is a copy of the feature map
            after block ``block_idx[i-1]``.

        Raises:
            RuntimeError: If some entry of ``block_idx`` matches no block index, so the
                corresponding feature would never be produced.
        """
        x1 = x2 = x3 = x4 = edge = None
        for idx, block in enumerate(self._blocks):
            x = block(x, drop_connect_rate=self._drop_connect_rate(idx))

            if idx == self.block_idx[0]:
                x, edge = self.Frequency_Edge_Module1(x)
                edge = F.interpolate(edge, size=(H, W), mode='bilinear')
                x1 = x.clone()
            if idx == self.block_idx[1]:
                x2 = x.clone()
            if idx == self.block_idx[2]:
                x3 = x.clone()
            if idx == self.block_idx[3]:
                x4 = x.clone()

        # Previously a mismatched block_idx surfaced as an opaque UnboundLocalError.
        if any(feat is None for feat in (x1, x2, x3, x4)):
            raise RuntimeError(
                'get_blocks: block_idx {} must only contain indices of the {} blocks'.format(
                    list(self.block_idx[:4]), len(self._blocks)))
        return (x1, x2, x3, x4), edge

    @classmethod
    def from_name(cls, model_name, in_channels=3, **override_params):
        """Create an EfficientNet model according to name.

        Args:
            model_name (str): Name for efficientnet.
            in_channels (int): Input data's channel number.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'num_classes', 'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            An efficientnet model.

        Raises:
            ValueError: If ``model_name`` is not in VALID_MODELS.
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        model = cls(blocks_args, global_params)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                        in_channels=3, num_classes=1000, **override_params):
        """Create an EfficientNet model according to name and load pretrained weights.

        Args:
            model_name (str): Name for efficientnet.
            weights_path (None or str):
                str: path to pretrained weights file on the local disk.
                None: let load_pretrained_weights fetch the published weights.
            advprop (bool):
                Whether to load pretrained weights
                trained with advprop (valid when weights_path is None).
            in_channels (int): Input data's channel number. The stem is rebuilt after the
                weights are loaded when this is not 3.
            num_classes (int):
                Forwarded to get_model_params as a global-params override. This backbone
                builds no classifier layer, so it does not change any layer built here.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            A pretrained TRACER-EfficientNet model.
        """
        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
        load_pretrained_weights(model, model_name, weights_path=weights_path, advprop=advprop)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Get the input image size for a given efficientnet model.

        Args:
            model_name (str): Name for efficientnet.

        Returns:
            Input image size (resolution).

        Raises:
            ValueError: If ``model_name`` is not in VALID_MODELS.
        """
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validate a model name.

        Args:
            model_name (str): Name for efficientnet.

        Raises:
            ValueError: If ``model_name`` is not one of VALID_MODELS. Returns None otherwise.
        """
        if model_name not in VALID_MODELS:
            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))

    def _change_in_channels(self, in_channels):
        """Rebuild the stem convolution for ``in_channels`` inputs (no-op when it is 3).

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
|
|
|
|