""" NormAct (Normalization + Activation Layer) Factory

Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
instances in models. Where these are used, it will be possible to swap separate BN + act layers with
combined modules like IABN or EvoNorms.

Hacked together by / Copyright 2020 Ross Wightman
"""
| | import types |
| | import functools |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from .evo_norm import EvoNormBatch2d, EvoNormSample2d |
| | from .norm_act import BatchNormAct2d, GroupNormAct |
| | from .inplace_abn import InplaceAbn |
| |
|
# Norm + act layer types that convert_norm_act binds an ``act_layer`` keyword argument onto.
_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn}
# Every combined norm + act layer type; convert_norm_act passes instances of these through unchanged.
_NORM_ACT_TYPES = _NORM_ACT_REQUIRES_ARG | {EvoNormBatch2d, EvoNormSample2d}
| |
|
| |
|
def get_norm_act_layer(layer_class):
    """Map a norm layer name to the matching combined norm + act layer class.

    Matching ignores case and underscores:
      * names starting with ``'batchnorm'``  -> ``BatchNormAct2d``
      * names starting with ``'groupnorm'``  -> ``GroupNormAct``
      * exactly ``'evonormbatch'``           -> ``EvoNormBatch2d``
      * exactly ``'evonormsample'``          -> ``EvoNormSample2d``
      * exactly ``'iabn'`` or ``'inplaceabn'`` -> ``InplaceAbn``

    Args:
        layer_class (str): norm layer name, e.g. ``'batch_norm'`` or ``'InplaceABN'``.

    Returns:
        The matching layer class (not an instance).

    Raises:
        AssertionError: if the name matches no known layer. This is raised explicitly rather
            than via ``assert False``. With ``assert False``, running under ``python -O`` would
            strip the check and surface an ``UnboundLocalError`` instead. ``AssertionError`` is
            kept as the type for backwards compatibility with existing callers.
    """
    layer_class = layer_class.replace('_', '').lower()
    if layer_class.startswith("batchnorm"):
        return BatchNormAct2d
    if layer_class.startswith("groupnorm"):
        return GroupNormAct
    if layer_class == "evonormbatch":
        return EvoNormBatch2d
    if layer_class == "evonormsample":
        return EvoNormSample2d
    if layer_class in ("iabn", "inplaceabn"):
        return InplaceAbn
    raise AssertionError("Invalid norm_act layer (%s)" % layer_class)
| |
|
| |
|
def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs):
    """Create a combined norm + act layer instance from a string spec.

    Args:
        layer_type (str): ``'<norm>'`` or ``'<norm>-<act>'``, e.g. ``'batchnorm'`` or
            ``'batchnorm-leaky_relu'``. The ``<norm>`` part selects the layer class via
            ``get_norm_act_layer``. NOTE: the optional ``<act>`` suffix is accepted but currently
            ignored; it does not select the activation.
        num_features: passed as the first positional argument of the layer constructor
            (presumably the channel count).
        apply_act (bool): forwarded to the layer constructor as ``apply_act``.
        jit (bool): if True, the created layer is compiled with ``torch.jit.script``.
        **kwargs: extra keyword arguments forwarded to the layer constructor.

    Returns:
        The layer instance (the scripted module when ``jit`` is True).

    Raises:
        AssertionError: if ``layer_type`` contains more than one ``'-'`` separator. This is
            raised explicitly so the check still runs under ``python -O``. Errors for an
            unknown ``<norm>`` name propagate from ``get_norm_act_layer``.
    """
    layer_parts = layer_type.split('-')
    if len(layer_parts) not in (1, 2):
        raise AssertionError(
            f"Invalid norm_act layer_type ({layer_type}), expected '<norm>' or '<norm>-<act>'")
    layer = get_norm_act_layer(layer_parts[0])
    # FIXME(review): layer_parts[1] (the activation name), when present, is not used yet.
    layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
    if jit:
        layer_instance = torch.jit.script(layer_instance)
    return layer_instance
| |
|
| |
|
def convert_norm_act(norm_layer, act_layer):
    """Convert a norm layer spec into an equivalent combined norm + act layer constructor.

    Args:
        norm_layer: one of
            * a ``str`` name understood by ``get_norm_act_layer`` (e.g. ``'batchnorm'``, ``'iabn'``);
            * a class in ``_NORM_ACT_TYPES``, which is used as-is;
            * a plain function / lambda, which is used as-is (it is assumed to already build
              a norm + act layer);
            * any other class whose lowercased ``__name__`` starts with ``'batchnorm'`` or
              ``'groupnorm'``, which maps to ``BatchNormAct2d`` / ``GroupNormAct``;
            * a ``functools.partial`` of any of the above. Its keyword arguments are carried
              over and re-bound onto the result; its positional arguments are dropped.
        act_layer: ``None`` or an activation spec (type, str, function or partial). For layer
            types in ``_NORM_ACT_REQUIRES_ARG`` it is bound as the ``act_layer`` keyword, even
            when ``None``, unless the partial already supplied ``act_layer``.

    Returns:
        The norm + act layer class or callable. It is wrapped in ``functools.partial`` when any
        keyword arguments need to be bound.

    Raises:
        AssertionError: if an argument has an unsupported type, or no norm + act equivalent
            exists for ``norm_layer``. This is raised explicitly rather than via ``assert``, so
            validation still happens under ``python -O``.
    """
    valid_types = (type, str, types.FunctionType, functools.partial)
    if not isinstance(norm_layer, valid_types):
        raise AssertionError(
            f"Unsupported norm_layer type ({type(norm_layer).__name__}), "
            f"expected a type, str, function or functools.partial")
    if act_layer is not None and not isinstance(act_layer, valid_types):
        raise AssertionError(
            f"Unsupported act_layer type ({type(act_layer).__name__}), "
            f"expected None, a type, str, function or functools.partial")
    norm_act_kwargs = {}

    # Unbind a partial so its keyword args can be re-bound onto the norm + act layer below.
    # NOTE: positional args stored in the partial (norm_layer.args) are not carried over.
    if isinstance(norm_layer, functools.partial):
        norm_act_kwargs.update(norm_layer.keywords)
        norm_layer = norm_layer.func

    if isinstance(norm_layer, types.FunctionType):
        # A plain function / lambda is assumed to already build a norm + act layer. It can never
        # be one of the classes in _NORM_ACT_REQUIRES_ARG, so only the partial's kwargs (if any)
        # need re-binding.
        return functools.partial(norm_layer, **norm_act_kwargs) if norm_act_kwargs else norm_layer

    if isinstance(norm_layer, str):
        norm_act_layer = get_norm_act_layer(norm_layer)
    elif norm_layer in _NORM_ACT_TYPES:
        norm_act_layer = norm_layer
    else:
        # Map a plain norm layer class to its norm + act counterpart by class name.
        type_name = norm_layer.__name__.lower()
        if type_name.startswith('batchnorm'):
            norm_act_layer = BatchNormAct2d
        elif type_name.startswith('groupnorm'):
            norm_act_layer = GroupNormAct
        else:
            raise AssertionError(f"No equivalent norm_act layer for {type_name}")

    if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
        # Pass `act_layer` through even when it is None. This is presumably so that
        # `act_layer=None` can mean "no activation" (confirm against the layer constructors).
        # An `act_layer` already bound by the partial takes precedence.
        norm_act_kwargs.setdefault('act_layer', act_layer)
    if norm_act_kwargs:
        norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs)  # bind / rebind args
    return norm_act_layer
| |
|