| | |
| | import torch.nn as nn |
| |
|
| | from .registry import PADDING_LAYERS |
| |
|
# Register the built-in torch padding modules under short config names so
# that ``build_padding_layer`` can look them up via ``cfg['type']``.
for _padding_name, _padding_cls in (
    ('zero', nn.ZeroPad2d),
    ('reflect', nn.ReflectionPad2d),
    ('replicate', nn.ReplicationPad2d),
):
    PADDING_LAYERS.register_module(_padding_name, module=_padding_cls)
# Drop the loop variables so they do not leak into the module namespace.
del _padding_name, _padding_cls
| |
|
| |
|
def build_padding_layer(cfg, *args, **kwargs):
    """Build a padding layer from a config dict.

    The ``type`` entry of ``cfg`` selects a class registered in
    ``PADDING_LAYERS``. Every other entry of ``cfg`` is forwarded as a
    keyword argument to that class, together with ``*args`` and ``**kwargs``.

    Args:
        cfg (dict): The padding layer config, which must contain:
            - type (str): Registered padding type name.
            - layer args: Args needed to instantiate a padding layer.
        *args: Positional arguments forwarded to the padding layer.
        **kwargs: Keyword arguments forwarded to the padding layer. They must
            not repeat a key that is also present in ``cfg``; Python rejects
            the duplicate keyword with a ``TypeError``.

    Returns:
        nn.Module: Created padding layer.

    Raises:
        TypeError: If ``cfg`` is not a dict (``None`` is rejected too).
        KeyError: If ``cfg`` has no ``"type"`` key, or its type is not
            registered in ``PADDING_LAYERS``.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    # Work on a shallow copy so popping 'type' leaves the caller's cfg intact.
    cfg_ = cfg.copy()
    padding_type = cfg_.pop('type')
    if padding_type not in PADDING_LAYERS:
        raise KeyError(f'Unrecognized padding type {padding_type}.')
    padding_layer = PADDING_LAYERS.get(padding_type)

    # Remaining cfg entries are merged with the caller's kwargs as keywords.
    return padding_layer(*args, **kwargs, **cfg_)
| |
|