| | |
| | from torch import nn |
| |
|
| | from .registry import CONV_LAYERS |
| |
|
# Register the stock PyTorch convolution classes under their string aliases.
# Order matches the original sequence of calls; the generic alias 'Conv'
# resolves to the 2D convolution.
for _alias, _conv_cls in (
    ('Conv1d', nn.Conv1d),
    ('Conv2d', nn.Conv2d),
    ('Conv3d', nn.Conv3d),
    ('Conv', nn.Conv2d),
):
    CONV_LAYERS.register_module(_alias, module=_conv_cls)
# Drop the loop variables so the module namespace is not polluted.
del _alias, _conv_cls
| |
|
| |
|
def build_conv_layer(cfg, *args, **kwargs):
    """Build convolution layer.

    Args:
        cfg (None or dict): The conv layer config. ``None`` is treated as
            ``dict(type='Conv2d')``. A dict should contain:
            - type (str): Layer type, looked up in ``CONV_LAYERS``.
            - layer args: Args needed to instantiate a conv layer.
        args (argument list): Arguments passed to the `__init__`
            method of the corresponding conv layer.
        kwargs (keyword arguments): Keyword arguments passed to the `__init__`
            method of the corresponding conv layer.

    Returns:
        nn.Module: Created conv layer.

    Raises:
        TypeError: If ``cfg`` is neither ``None`` nor a dict.
        KeyError: If ``cfg`` has no ``"type"`` key, or the requested type
            is not found in ``CONV_LAYERS``.
    """
    if cfg is None:
        cfg_ = dict(type='Conv2d')
    else:
        if not isinstance(cfg, dict):
            raise TypeError('cfg must be a dict')
        if 'type' not in cfg:
            raise KeyError('the cfg dict must contain the key "type"')
        # Work on a shallow copy so that popping 'type' below does not
        # mutate the caller's config dict.
        cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in CONV_LAYERS:
        raise KeyError(f'Unrecognized conv layer type {layer_type}')
    conv_layer = CONV_LAYERS.get(layer_type)

    # Remaining cfg entries are forwarded as keyword arguments; a key present
    # in both ``kwargs`` and ``cfg`` makes Python raise TypeError (duplicate
    # keyword argument).
    return conv_layer(*args, **kwargs, **cfg_)
| |
|