| """ Attention Factory |
| |
| Hacked together by / Copyright 2021 Ross Wightman |
| """ |
| import torch |
| from functools import partial |
|
|
| from .bottleneck_attn import BottleneckAttn |
| from .cbam import CbamModule, LightCbamModule |
| from .eca import EcaModule, CecaModule |
| from .gather_excite import GatherExcite |
| from .global_context import GlobalContext |
| from .halo_attn import HaloAttn |
| from .involution import Involution |
| from .lambda_layer import LambdaLayer |
| from .non_local_attn import NonLocalAttn, BatNonLocalAttn |
| from .selective_kernel import SelectiveKernel |
| from .split_attn import SplitAttn |
| from .squeeze_excite import SEModule, EffectiveSEModule |
| from .swin_attn import WindowAttention |
|
|
|
|
def get_attn(attn_type):
    """Resolve an attention specifier to an attention module class.

    Args:
        attn_type: Selects the attention layer. Accepted forms:

            * a ``torch.nn.Module`` instance -- returned unchanged;
            * ``None`` -- no attention, ``None`` is returned;
            * a ``str`` key, matched case-insensitively: ``'se'``, ``'ese'``,
              ``'eca'``, ``'ecam'``, ``'ceca'``, ``'ge'``, ``'gc'``, ``'cbam'``,
              ``'lcbam'``, ``'sk'``, ``'splat'``, ``'lambda'``, ``'bottleneck'``,
              ``'halo'``, ``'swin'``, ``'involution'``, ``'nl'``, ``'bat'``;
            * a ``bool`` -- ``True`` selects ``SEModule``, ``False`` yields
              ``None``;
            * anything else -- returned unchanged (presumably a module class
              or factory callable supplied by the caller).

    Returns:
        The selected module class (a ``functools.partial`` for ``'ecam'``),
        the object that was passed through unchanged, or ``None`` when no
        attention is requested.

    Raises:
        AssertionError: If ``attn_type`` is a string that matches no known key.
    """
    if isinstance(attn_type, torch.nn.Module):
        return attn_type
    if attn_type is None:
        return None
    if isinstance(attn_type, bool):
        # Boolean shorthand: True means the default SE module, False means none.
        return SEModule if attn_type else None
    if not isinstance(attn_type, str):
        # Not a string key: hand the object back as given.
        return attn_type

    attn_type = attn_type.lower()
    if attn_type == 'se':
        return SEModule
    if attn_type == 'ese':
        return EffectiveSEModule
    if attn_type == 'eca':
        return EcaModule
    if attn_type == 'ecam':
        # Same class as 'eca', with use_mlp=True pre-bound.
        return partial(EcaModule, use_mlp=True)
    if attn_type == 'ceca':
        return CecaModule
    if attn_type == 'ge':
        return GatherExcite
    if attn_type == 'gc':
        return GlobalContext
    if attn_type == 'cbam':
        return CbamModule
    if attn_type == 'lcbam':
        return LightCbamModule
    if attn_type == 'sk':
        return SelectiveKernel
    if attn_type == 'splat':
        return SplitAttn
    if attn_type == 'lambda':
        return LambdaLayer
    if attn_type == 'bottleneck':
        return BottleneckAttn
    if attn_type == 'halo':
        return HaloAttn
    if attn_type == 'swin':
        return WindowAttention
    if attn_type == 'involution':
        return Involution
    if attn_type == 'nl':
        return NonLocalAttn
    if attn_type == 'bat':
        return BatNonLocalAttn

    # Raised explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would let an unknown key silently
    # resolve to None ("no attention"). The exception type is unchanged.
    raise AssertionError("Invalid attn module (%s)" % attn_type)
|
|
|
|
def create_attn(attn_type, channels, **kwargs):
    """Build the attention layer selected by ``attn_type``.

    Args:
        attn_type: Attention specifier, resolved through ``get_attn``.
        channels: Passed as the first positional argument to the resolved
            class/callable.
        **kwargs: Forwarded unchanged as keyword arguments.

    Returns:
        The result of calling the resolved class/callable, or ``None`` when
        ``get_attn`` resolves ``attn_type`` to ``None``.
    """
    attn_factory = get_attn(attn_type)
    if attn_factory is None:
        return None
    # NOTE(review): whatever get_attn returns is called here, including an
    # already-built nn.Module instance -- confirm that case is intended.
    return attn_factory(channels, **kwargs)
|
|