| | |
| | """Initialize modules for espnet2 neural networks.""" |
| | import torch |
| | from typeguard import check_argument_types |
| |
|
| |
|
def initialize(model: torch.nn.Module, init: str) -> None:
    """Initialize weights of a neural network module.

    Every parameter of ``model`` with more than one dimension is
    re-initialized in place using the method named by ``init``. Afterwards
    every one-dimensional bias parameter is reset to zero, including biases
    owned directly by ``model`` itself (not only those of its submodules).

    NOTE(review): this docstring previously advertised per-module
    ``espnet_initialization_fn`` hooks, but nothing in this function calls
    them -- confirm whether that support is expected here.

    Args:
        model: Target module; its parameters are modified in place.
        init: Method of initialization. One of ``"xavier_uniform"``,
            ``"xavier_normal"``, ``"kaiming_uniform"`` or
            ``"kaiming_normal"`` (the Kaiming variants are called with
            ``nonlinearity="relu"``).

    Raises:
        ValueError: If ``init`` is not one of the known methods and
            ``model`` has at least one parameter with more than one
            dimension.
    """
    assert check_argument_types()
    print("init with", init)

    # Method name -> (in-place initializer, extra keyword arguments).
    initializers = {
        "xavier_uniform": (torch.nn.init.xavier_uniform_, {}),
        "xavier_normal": (torch.nn.init.xavier_normal_, {}),
        "kaiming_uniform": (
            torch.nn.init.kaiming_uniform_,
            {"nonlinearity": "relu"},
        ),
        "kaiming_normal": (
            torch.nn.init.kaiming_normal_,
            {"nonlinearity": "relu"},
        ),
    }
    chosen = initializers.get(init)

    # Weight init: only parameters with more than one dimension are touched.
    for p in model.parameters():
        if p.dim() <= 1:
            continue
        if chosen is None:
            # Reported lazily (on the first multi-dimensional parameter) so
            # that modules without such parameters behave as they did before.
            raise ValueError("Unknown initialization: " + init)
        init_fn, extra_kwargs = chosen
        init_fn(p.data, **extra_kwargs)

    # Bias init: zero every one-dimensional bias.
    for name, p in model.named_parameters():
        # named_parameters() yields names relative to `model`, so a bias owned
        # by `model` itself is called "bias" with no leading dot. Prefixing a
        # dot makes root-level parameters match the same way nested ones do.
        if p.dim() == 1 and ".bias" in "." + name:
            p.data.zero_()
| |
|