| |
| |
| |
|
|
| import collections |
| import math |
| import re |
| from functools import partial |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from torch.utils import model_zoo |
|
|
| from backbone import MammothBackbone |
|
|
# Download locations of the EfficientNet-PyTorch release-1.0 checkpoints,
# keyed by model name ('efficientnet-b0' .. 'efficientnet-b7'). Each entry
# differs only in the variant index and the file's hash suffix.
url_map = {
    'efficientnet-b%d' % variant:
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/'
        'efficientnet-b%d-%s.pth' % (variant, digest)
    for variant, digest in enumerate((
        '355c32eb',
        'f1951068',
        '8bb594d6',
        '5fb5a3c3',
        '6ed6700e',
        'b6417697',
        'c76e70fd',
        'dcc49843',
    ))
}
|
|
|
|
def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, verbose=True, advprop=False):
    """Loads pretrained weights from a local file or downloads them from ``url_map``.

    EfficientNet-PyTorch checkpoints store the classification head under
    ``_fc.weight`` / ``_fc.bias``. When ``model`` has no ``_fc`` submodule but
    does have a ``classifier`` one (as ``EfficientNet`` in this module does),
    those keys are renamed to ``classifier.*`` before loading so the head is
    matched correctly.

    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet (key into ``url_map`` when downloading).
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool): Whether to load pretrained weights for the fc/classifier layer at the end of the model.
        verbose (bool): Print a confirmation message once loading succeeds.
        advprop (bool): Whether to load pretrained weights trained with advprop
            (only meaningful when weights_path is None). No AdvProp URLs are
            registered in ``url_map``, so requesting a download of them raises.

    Raises:
        NotImplementedError: if ``advprop`` is True and ``weights_path`` is None.
        KeyError: if downloading and ``model_name`` is not in ``url_map``, or if
            ``load_fc`` is False and the checkpoint has no head weights.
        AssertionError: if the checkpoint does not match the model's parameters.
    """
    if isinstance(weights_path, str):
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies the values onto the model's own devices anyway.
        state_dict = torch.load(weights_path, map_location='cpu')
    else:
        if advprop:
            raise NotImplementedError('AdvProp weights are not available for download; '
                                      'pass weights_path to load them from disk.')
        state_dict = model_zoo.load_url(url_map[model_name])

    # Rename the checkpoint's `_fc.*` head keys when the model calls its head `classifier`.
    # The rename is done in place so any `_metadata` attached to the state dict is kept.
    fc_prefix = '_fc.'
    if not hasattr(model, '_fc') and hasattr(model, 'classifier'):
        fc_prefix = 'classifier.'
        for key in [k for k in state_dict if k.startswith('_fc.')]:
            state_dict[fc_prefix + key[len('_fc.'):]] = state_dict.pop(key)
    fc_keys = [fc_prefix + 'weight', fc_prefix + 'bias']

    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    else:
        # Drop the pretrained head so a head of a different size keeps its fresh init.
        for key in fc_keys:
            state_dict.pop(key)
        ret = model.load_state_dict(state_dict, strict=False)
        assert set(ret.missing_keys) == set(fc_keys), \
            'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)

    if verbose:
        print('Loaded pretrained weights for {}'.format(model_name))
|
|
|
|
| _DEFAULT_BLOCKS_ARGS = [ |
| 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', |
| 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', |
| 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', |
| 'r1_k3_s11_e6_i192_o320_se0.25', |
| ] |
|
|
|
|
def get_width_and_height_from_size(x):
    """Obtain height and width from x.

    Args:
        x (int, tuple or list): Data size. An int means a square size.

    Returns:
        size: ``(x, x)`` for an int; otherwise ``x`` itself, unchanged,
        expected to be ``(H, W)``.

    Raises:
        TypeError: if ``x`` is not an int, list or tuple.
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    raise TypeError('Expected an int or a (height, width) list/tuple, got {}'.format(type(x).__name__))
|
|
|
|
def calculate_output_image_size(input_image_size, stride):
    """Compute the spatial size produced by a 'SAME'-padded conv with ``stride``.

    Needed to size static padding layers ahead of time.

    Args:
        input_image_size (None, int, tuple or list): Size of the input image;
            ``None`` is passed straight through.
        stride (int, tuple or list): Conv2d stride; for a sequence only its
            first element is used, for both dimensions.

    Returns:
        A list ``[H, W]`` with each side equal to ``ceil(side / stride)``, or
        ``None`` when ``input_image_size`` is ``None``.
    """
    if input_image_size is None:
        return None
    height, width = get_width_and_height_from_size(input_image_size)
    step = stride if isinstance(stride, int) else stride[0]
    return [int(math.ceil(side / step)) for side in (height, width)]
|
|
|
|
def drop_connect(inputs, p, training):
    """Drop connect: randomly zero whole samples and rescale the survivors.

    One keep/drop decision is drawn per index of the leading (batch) dimension
    and broadcast over all remaining dimensions. Kept samples are divided by
    ``1 - p`` so the expected value is unchanged.

    Args:
        inputs (tensor): Batch-first input of any rank >= 1.
        p (float: 0.0~1.0): Probability of dropping each sample.
        training (bool): The running mode; when False ``inputs`` is returned as-is.

    Returns:
        output: Output after drop connection (same shape as ``inputs``).
    """
    assert 0 <= p <= 1, 'p must be in range of [0,1]'

    if not training:
        return inputs

    keep_prob = 1 - p
    if keep_prob == 0:
        # Every sample is dropped; the rescaling below would compute 0/0 = NaN.
        return torch.zeros_like(inputs)

    # Shape (batch, 1, ..., 1) so the mask broadcasts over every non-batch dim.
    mask_shape = [inputs.shape[0]] + [1] * (inputs.dim() - 1)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)

    return inputs / keep_prob * binary_tensor
|
|
|
|
def round_repeats(repeats, global_params):
    """Scale a stage's repeat count by the depth multiplier, rounding up.

    Args:
        repeats (int): Baseline number of repeats.
        global_params (namedtuple): Global params; ``depth_coefficient`` is read.

    Returns:
        ``ceil(depth_coefficient * repeats)``, or ``repeats`` unchanged when the
        coefficient is falsy (None or 0).
    """
    depth_mult = global_params.depth_coefficient
    return int(math.ceil(depth_mult * repeats)) if depth_mult else repeats
|
|
|
|
def round_filters(filters, global_params):
    """Scale a filter count by the width multiplier and round to a divisor multiple.

    Reads ``width_coefficient``, ``depth_divisor`` and ``min_depth`` from
    ``global_params``. The scaled count is rounded to the nearest multiple of
    ``depth_divisor`` (never below ``min_depth``, which defaults to the
    divisor), then bumped by one divisor if rounding lost more than 10%.

    Args:
        filters (int): Baseline filter count.
        global_params (namedtuple): Global params of the model.

    Returns:
        The rounded filter count as an int, or ``filters`` unchanged when
        ``width_coefficient`` is falsy.
    """
    width_mult = global_params.width_coefficient
    if not width_mult:
        return filters

    divisor = global_params.depth_divisor
    floor_depth = global_params.min_depth or divisor
    scaled = filters * width_mult

    rounded = max(floor_depth, int(scaled + divisor / 2) // divisor * divisor)
    # Never round down by more than 10% of the scaled value.
    if rounded < 0.9 * scaled:
        rounded += divisor
    return int(rounded)
|
|
|
|
class Conv2dDynamicSamePadding(nn.Conv2d):
    """TensorFlow-style 'SAME' 2D convolution whose padding is computed per call.

    The convolution itself is built with zero padding; ``forward`` measures the
    incoming height/width and zero-pads so the output is ``ceil(input / stride)``
    on each side. Any extra odd pixel goes on the bottom/right.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        # A one-element stride (e.g. [2]) is expanded to cover both dimensions.
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

    def forward(self, x):
        in_h, in_w = x.size()[-2:]
        k_h, k_w = self.weight.size()[-2:]
        s_h, s_w = self.stride
        d_h, d_w = self.dilation

        out_h, out_w = math.ceil(in_h / s_h), math.ceil(in_w / s_w)
        total_h = max((out_h - 1) * s_h + (k_h - 1) * d_h + 1 - in_h, 0)
        total_w = max((out_w - 1) * s_w + (k_w - 1) * d_w + 1 - in_w, 0)

        if total_h > 0 or total_w > 0:
            top, left = total_h // 2, total_w // 2
            x = F.pad(x, [left, total_w - left, top, total_h - top])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
class Conv2dStaticSamePadding(nn.Conv2d):
    """TensorFlow-style 'SAME' 2D convolution for a fixed, known input size.

    The padding is computed once in the constructor from ``image_size`` and
    stored as a ``nn.ZeroPad2d`` (or ``nn.Identity`` when none is needed), which
    ``forward`` applies before the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        # A one-element stride (e.g. [2]) is expanded to cover both dimensions.
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

        # The padding depends on the input size, so it must be known up front.
        assert image_size is not None
        if isinstance(image_size, int):
            in_h = in_w = image_size
        else:
            in_h, in_w = image_size
        k_h, k_w = self.weight.size()[-2:]
        s_h, s_w = self.stride
        d_h, d_w = self.dilation

        out_h, out_w = math.ceil(in_h / s_h), math.ceil(in_w / s_w)
        total_h = max((out_h - 1) * s_h + (k_h - 1) * d_h + 1 - in_h, 0)
        total_w = max((out_w - 1) * s_w + (k_w - 1) * d_w + 1 - in_w, 0)

        if total_h > 0 or total_w > 0:
            top, left = total_h // 2, total_w // 2
            self.static_padding = nn.ZeroPad2d((left, total_w - left, top, total_h - top))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        padded = self.static_padding(x)
        return F.conv2d(padded, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
def get_same_padding_conv2d(image_size=None):
    """Pick the 'SAME'-padding convolution class for the given input size.

    Static padding (fixed at construction) is chosen when ``image_size`` is
    known, which is necessary for ONNX export; otherwise the padding is
    computed on every forward call.

    Args:
        image_size (None, int or tuple): Size of the image, or None if unknown.

    Returns:
        ``Conv2dDynamicSamePadding`` when ``image_size`` is None, otherwise
        ``Conv2dStaticSamePadding`` with ``image_size`` pre-bound via ``partial``.
    """
    if image_size is not None:
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding
|
|
| |
| |
| |
| |
| |
|
|
|
|
# Model-wide hyper-parameters. Every field defaults to None except the last,
# `include_top`, which defaults to True.
_GLOBAL_PARAMS_FIELDS = (
    'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'data_format',
    'num_classes', 'width_coefficient', 'depth_coefficient', 'depth_divisor',
    'min_depth', 'survival_prob', 'relu_fn', 'batch_norm', 'use_se',
    'local_pooling', 'condconv_num_experts', 'clip_projection_output',
    'blocks_args', 'image_size', 'drop_connect_rate', 'include_top',
)
GlobalParams = collections.namedtuple(
    'GlobalParams', _GLOBAL_PARAMS_FIELDS,
    defaults=(None,) * (len(_GLOBAL_PARAMS_FIELDS) - 1) + (True,),
)

# Per-stage block parameters; every field defaults to None. Both `strides`
# (TF-style, two values) and `stride` (set by BlockDecoder as a one-element
# list) are kept.
_BLOCK_ARGS_FIELDS = (
    'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
    'expand_ratio', 'id_skip', 'strides', 'se_ratio', 'conv_type', 'fused_conv',
    'super_pixel', 'condconv', 'stride',
)
BlockArgs = collections.namedtuple(
    'BlockArgs', _BLOCK_ARGS_FIELDS,
    defaults=(None,) * len(_BLOCK_ARGS_FIELDS),
)
|
|
| |
try:
    # Newer PyTorch releases ship a native x * sigmoid(x) module; reuse it when present.
    Swish = nn.SiLU
except AttributeError:
    class Swish(nn.Module):
        """Fallback Swish activation (x * sigmoid(x)) for builds without nn.SiLU."""

        def forward(self, x):
            return x * torch.sigmoid(x)
|
|
| |
|
|
|
|
class SwishImplementation(torch.autograd.Function):
    """Swish (x * sigmoid(x)) with a hand-written backward pass.

    Only the input tensor is saved for backward; the sigmoid is recomputed
    there. The gradient is sigmoid(x) * (1 + x * (1 - sigmoid(x))).
    """

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        sig = torch.sigmoid(saved_input)
        local_grad = sig * (1 + saved_input * (1 - sig))
        return grad_output * local_grad
|
|
|
|
class MemoryEfficientSwish(nn.Module):
    """Swish activation module backed by the custom autograd in SwishImplementation."""

    def forward(self, x):
        swish_fn = SwishImplementation.apply
        return swish_fn(x)
|
|
|
|
class BlockDecoder(object):
    """Block Decoder for readability,
    straight from the official TensorFlow repository.

    Converts between block strings such as 'r1_k3_s11_e1_i32_o16_se0.25' and
    BlockArgs namedtuples.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: The namedtuple defined at the top of this file. Only
            ``stride`` (a one-element list) is set, not ``strides``.

        Raises:
            AssertionError: if ``block_string`` is not a str, or its stride
                option is missing or is not one digit / two equal digits.
        """
        assert isinstance(block_string, str)

        ops = block_string.split('_')
        options = {}
        for op in ops:
            # Split each token at its first digit: 'se0.25' -> key 'se', value '0.25'.
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Only square strides are supported ('s1', 's11', 's22', ...).
        stride_str = options.get('s', '')
        assert (len(stride_str) == 1 or
                (len(stride_str) == 2 and stride_str[0] == stride_str[1])), \
            'stride must be one digit or two equal digits: {}'.format(block_string)

        return BlockArgs(
            num_repeat=int(options['r']),
            kernel_size=int(options['k']),
            stride=[int(options['s'][0])],
            expand_ratio=int(options['e']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            se_ratio=float(options['se']) if 'se' in options else None,
            id_skip=('noskip' not in block_string))

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Accepts blocks carrying either ``strides`` (two values) or ``stride``
        (an int or a one-element list, as produced by ``_decode_block_string``).

        Args:
            block (namedtuple): A BlockArgs type argument.

        Returns:
            block_string: A String form of BlockArgs.
        """
        strides = block.strides if block.strides is not None else block.stride
        if isinstance(strides, int):
            strides = [strides]
        if len(strides) == 1:
            strides = [strides[0], strides[0]]

        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (strides[0], strides[1]),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters,
        ]
        # Decoded blocks without an 'se' option carry se_ratio=None.
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.

        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        return [BlockDecoder._decode_block_string(block_string) for block_string in string_list]

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.

        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        return [BlockDecoder._encode_block_string(block) for block in blocks_args]
|
|
|
|
def efficientnet_params(model_name):
    """Look up the scaling coefficients of an EfficientNet variant.

    Args:
        model_name (str): Model name to be queried, e.g. 'efficientnet-b0'.

    Returns:
        A ``(width_coefficient, depth_coefficient, resolution, dropout_rate)`` tuple.

    Raises:
        KeyError: if ``model_name`` is not in the table.
    """
    rows = (
        # name,             width, depth, res, dropout
        ('efficientnet-b0', 1.0, 1.0, 224, 0.2),
        ('efficientnet-b1', 1.0, 1.1, 240, 0.2),
        ('efficientnet-b2', 1.1, 1.2, 260, 0.3),
        ('efficientnet-b3', 1.2, 1.4, 300, 0.3),
        ('efficientnet-b4', 1.4, 1.8, 380, 0.4),
        ('efficientnet-b5', 1.6, 2.2, 456, 0.4),
        ('efficientnet-b6', 1.8, 2.6, 528, 0.5),
        ('efficientnet-b7', 2.0, 3.1, 600, 0.5),
        ('efficientnet-b8', 2.2, 3.6, 672, 0.5),
        ('efficientnet-l2', 4.3, 5.3, 800, 0.5),
    )
    lookup = {name: tuple(coefficients) for name, *coefficients in rows}
    return lookup[model_name]
|
|
|
|
def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
                 dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
    """Create BlockArgs and GlobalParams for efficientnet model.

    Args:
        width_coefficient (float): Channel multiplier (see round_filters).
        depth_coefficient (float): Repeat multiplier (see round_repeats).
        image_size (int): Input resolution; None selects dynamic padding.
        dropout_rate (float): Dropout before the classifier.
        drop_connect_rate (float): Base drop-connect rate for the blocks.
        num_classes (int): Size of the classifier output.
        include_top (bool): Whether the model builds the dropout + classifier head.

    Returns:
        blocks_args, global_params: the decoded default stage list and the
        GlobalParams (``blocks_args`` is not stored in it).
    """
    # The stage strings are the module-level defaults; decode them rather than keeping a second copy.
    blocks_args = BlockDecoder.decode(_DEFAULT_BLOCKS_ARGS)

    global_params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,
        num_classes=num_classes,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=include_top,
    )

    return blocks_args, global_params
|
|
|
|
def efficientnet_tf(width_coefficient=None,
                    depth_coefficient=None,
                    dropout_rate=0.2,
                    survival_prob=0.8):
    """Create the GlobalParams of an EfficientNet in the TensorFlow-reference layout.

    Unlike ``efficientnet()``, the undecoded default block strings are stored
    in ``blocks_args`` and only the GlobalParams is returned.

    Args:
        width_coefficient (float): Channel multiplier.
        depth_coefficient (float): Repeat multiplier.
        dropout_rate (float): Dropout before the classifier.
        survival_prob (float): Stored as ``survival_prob``.

    Returns:
        GlobalParams with 1000 classes, channels_last data format and SE enabled.
    """
    fixed_settings = dict(
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        data_format='channels_last',
        num_classes=1000,
        depth_divisor=8,
        min_depth=None,
        use_se=True,
        clip_projection_output=False,
    )
    return GlobalParams(
        blocks_args=_DEFAULT_BLOCKS_ARGS,
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        dropout_rate=dropout_rate,
        survival_prob=survival_prob,
        **fixed_settings)
|
|
|
|
def get_model_params_tf(model_name, override_params):
    """Get the block args and global params for a given model (TF-reference layout).

    Args:
        model_name (str): Model name, must start with 'efficientnet'.
        override_params (dict or None): Fields to replace in the GlobalParams.

    Returns:
        blocks_args, global_params: the BlockArgs decoded from
        ``global_params.blocks_args`` and the (possibly overridden) GlobalParams.

    Raises:
        NotImplementedError: if ``model_name`` does not start with 'efficientnet'.
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: %s' % model_name)

    width_coefficient, depth_coefficient, _, dropout_rate = efficientnet_params(model_name)
    # efficientnet_tf (not efficientnet) returns a single GlobalParams whose
    # blocks_args holds the block strings decoded below; its positional
    # parameters are (width, depth, dropout).
    global_params = efficientnet_tf(width_coefficient, depth_coefficient, dropout_rate)

    if override_params:
        # ValueError will be raised here if override_params has fields not included in global_params.
        global_params = global_params._replace(**override_params)

    blocks_args = BlockDecoder.decode(global_params.blocks_args)

    print('EFFNET LOGGING: global_params= %s' % (global_params,))
    return blocks_args, global_params
|
|
|
|
def get_model_params(model_name, override_params):
    """Build the block args and global params for a named EfficientNet variant.

    Args:
        model_name (str): Model's name; must start with 'efficientnet'.
        override_params (dict): Fields to replace in global_params (ignored if empty/None).

    Returns:
        blocks_args, global_params

    Raises:
        NotImplementedError: if ``model_name`` does not start with 'efficientnet'.
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))

    width, depth, resolution, dropout = efficientnet_params(model_name)
    blocks_args, global_params = efficientnet(
        width_coefficient=width, depth_coefficient=depth,
        dropout_rate=dropout, image_size=resolution)

    if override_params:
        # namedtuple._replace raises ValueError for unknown field names.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params
|
|
|
|
def efficientnet_params(model_name):
    """Get efficientnet params based on model name.

    NOTE(review): this redefines ``efficientnet_params`` from earlier in the
    module with the same table. Because it comes later, this is the binding in
    effect after import; one of the two could be removed.

    Returns:
        A ``(width_coefficient, depth_coefficient, resolution, dropout_rate)`` tuple.

    Raises:
        KeyError: if ``model_name`` is not in the table.
    """
    names = (
        'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
        'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
        'efficientnet-b8', 'efficientnet-l2',
    )
    coefficients = (
        (1.0, 1.0, 224, 0.2),
        (1.0, 1.1, 240, 0.2),
        (1.1, 1.2, 260, 0.3),
        (1.2, 1.4, 300, 0.3),
        (1.4, 1.8, 380, 0.4),
        (1.6, 2.2, 456, 0.4),
        (1.8, 2.6, 528, 0.5),
        (2.0, 3.1, 600, 0.5),
        (2.2, 3.6, 672, 0.5),
        (4.3, 5.3, 800, 0.5),
    )
    return dict(zip(names, coefficients))[model_name]
|
|
|
|
# Model names accepted by EfficientNet.from_name / get_image_size: b0 through b8.
VALID_MODELS = tuple('efficientnet-b{}'.format(idx) for idx in range(9))
|
|
|
|
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    Args:
        block_args (namedtuple): BlockArgs for this block (see BlockDecoder).
        global_params (namedtuple): GlobalParams shared across blocks.
        image_size (tuple or list): [image_height, image_width] of this block's
            input. When given, static-padding convolutions are used; when None,
            dynamic-padding ones are used.

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        # PyTorch BatchNorm momentum weights the new batch statistic, so the
        # decay-style value stored in global_params is inverted here.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip

        # Expansion phase: 1x1 conv to input_filters * expand_ratio channels (skipped when the ratio is 1).
        inp = self._block_args.input_filters
        oup = self._block_args.input_filters * self._block_args.expand_ratio
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Depthwise convolution phase (groups == channels); this is where the stride is applied.
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze-and-excitation on the globally pooled (1x1) map.
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Pointwise projection back to output_filters.
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    @staticmethod
    def _is_unit_stride(stride):
        """Return True if ``stride`` is 1, given as an int or as a list/tuple of values."""
        # BlockDecoder stores strides as one-element lists (e.g. [1]), for which `== 1` is False.
        if isinstance(stride, (list, tuple)):
            return all(value == 1 for value in stride)
        return stride == 1

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float or None): Drop connect rate (between 0 and 1),
                applied to the residual branch only when the skip connection is used.

        Returns:
            Output of this block after processing.
        """
        # Expansion and depthwise convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)

        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and excitation: rescale channels by a gate computed from the pooled map.
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        # Pointwise projection (no activation)
        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection and drop connect, only when the shape is preserved.
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._is_unit_stride(self._block_args.stride) and input_filters == output_filters:
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
|
|
|
|
class EfficientNet(MammothBackbone):
    """EfficientNet model.
    Most easily loaded with the .from_name or .from_pretrained methods.

    Args:
        blocks_args (list[namedtuple]): A non-empty list of BlockArgs to construct blocks.
        global_params (namedtuple): A set of GlobalParams shared between blocks.
        hookme (bool): Stored as ``self.hookme``; not read elsewhere in this class.

    References:
        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)

    Example:
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> model = EfficientNet.from_pretrained('efficientnet-b0')
        >>> model.eval()
        >>> outputs = model(inputs)
    """

    def __init__(self, blocks_args=None, global_params=None, hookme=False):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self.hookme = hookme
        self._blocks_args = blocks_args

        # Batch norm parameters (PyTorch momentum is 1 minus the stored decay-style value).
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Static padding when image_size is known, dynamic padding otherwise.
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem: 3x3 stride-2 conv from RGB.
        in_channels = 3
        out_channels = round_filters(32, self._global_params)
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Scale input/output filters and repeats by the width/depth multipliers.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            # The first block of a stage applies the stage's stride and channel change.
            self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:
                # Remaining repeats keep the channel count and use stride 1.
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))

        # Head: 1x1 conv to round_filters(1280) channels.
        in_channels = block_args.output_filters
        out_channels = round_filters(1280, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        if self._global_params.include_top:
            self._dropout = nn.Dropout(self._global_params.dropout_rate)
            self.classifier = nn.Linear(out_channels, self._global_params.num_classes)

        # set activation to memory efficient swish by default
        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_endpoints(self, inputs):
        """Use convolution layer to extract features
        from reduction levels i in [1, 2, 3, 4, 5].

        An endpoint is recorded just before each block that shrinks the
        spatial size, after the last block, and after the head conv.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Dictionary of last intermediate features
            with reduction levels i in [1, 2, 3, 4, 5].
            Example:
                >>> inputs = torch.rand(1, 3, 224, 224)
                >>> model = EfficientNet.from_pretrained('efficientnet-b0')
                >>> endpoints = model.extract_endpoints(inputs)
                >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
                >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
                >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
                >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
                >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 320, 7, 7])
                >>> print(endpoints['reduction_6'].shape)  # torch.Size([1, 1280, 7, 7])
        """
        endpoints = {}

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks, with drop-connect rate growing linearly with depth.
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
            if prev_x.size(2) > x.size(2):
                endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
            elif idx == len(self._blocks) - 1:
                endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
            prev_x = x

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        endpoints['reduction_{}'.format(len(endpoints) + 1)] = x

        return endpoints

    def extract_features(self, inputs):
        """use convolution layer to extract feature .

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Output of the final convolution
            layer in the efficientnet model.
        """
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks, with drop-connect rate growing linearly with depth.
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def activations_hook(self, grad):
        """Store ``grad`` on ``self.gradients`` (intended for use as a tensor hook)."""
        self.gradients = grad

    def forward(self, inputs, returnt='out'):
        """EfficientNet's forward function.
        Calls extract_features to extract features, applies final linear layer, and returns logits.

        Args:
            inputs (tensor): Input tensor.
            returnt (str): 'out' for the classifier output, 'features' for the
                pooled, flattened features, or 'all' for ``(output, features)``.

        Returns:
            Output of this model after processing.

        Raises:
            NotImplementedError: for any other ``returnt`` value.
        """
        # Convolution layers
        x = self.extract_features(inputs)

        # Pooling and final linear layer
        feats = self._avg_pooling(x).flatten(start_dim=1)
        if returnt == 'features':
            return feats
        if self._global_params.include_top:
            x = self._dropout(feats)
            x = self.classifier(x)
        if returnt == 'out':
            return x
        elif returnt == 'all':
            return (x, feats)

        raise NotImplementedError("Unknown return type")

    @classmethod
    def from_name(cls, model_name, in_channels=3, **override_params):
        """Create an efficientnet model according to name.

        Args:
            model_name (str): Name for efficientnet; must be in VALID_MODELS.
            in_channels (int): Input data's channel number.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'num_classes', 'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            An efficientnet model.
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        model = cls(blocks_args, global_params)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                        in_channels=3, num_classes=1000, **override_params):
        """Create an efficientnet model according to name.

        Args:
            model_name (str): Name for efficientnet.
            weights_path (None or str):
                str: path to pretrained weights file on the local disk.
                None: use pretrained weights downloaded from the Internet.
            advprop (bool):
                Whether to load pretrained weights
                trained with advprop (valid when weights_path is None).
            in_channels (int): Input data's channel number.
            num_classes (int):
                Number of categories for classification.
                It controls the output size for final linear layer; the
                pretrained head is loaded only when it equals 1000.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            A pretrained efficientnet model.

        Raises:
            NotImplementedError: if ``advprop`` is True and ``weights_path`` is None,
                since no AdvProp download URLs are registered in ``url_map``.
        """
        if advprop and weights_path is None:
            raise NotImplementedError('AdvProp weights are not available for download; '
                                      'pass weights_path to load them from disk.')
        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
        # advprop is fully handled above, so it is not forwarded to the loader.
        load_pretrained_weights(model, model_name, weights_path=weights_path,
                                load_fc=(num_classes == 1000))
        # Replacing the stem after loading means a non-RGB stem starts from fresh weights.
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Get the input image size for a given efficientnet model.

        Args:
            model_name (str): Name for efficientnet.

        Returns:
            Input image size (resolution).
        """
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validates model name.

        Args:
            model_name (str): Name for efficientnet.

        Raises:
            ValueError: if ``model_name`` is not in VALID_MODELS.
        """
        if model_name not in VALID_MODELS:
            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        The replacement stem conv is newly initialized.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def mammoth_efficientnet(nclasses: int, model_name: str, pretrained=False):
    """
    Instantiates an EfficientNet backbone.

    :param nclasses: number of output classes (size of the final linear layer)
    :param model_name: EfficientNet variant, e.g. 'efficientnet-b0'; EfficientNet.from_name
        rejects names outside VALID_MODELS with a ValueError
    :param pretrained: if True, build via EfficientNet.from_pretrained (loads pretrained
        weights); otherwise via EfficientNet.from_name (random init)
    :return: EfficientNet network
    """
    print(model_name)
    if pretrained:
        return EfficientNet.from_pretrained(model_name=model_name, num_classes=nclasses)
    return EfficientNet.from_name(model_name=model_name, num_classes=nclasses)
|
|