| import torch.nn as nn | |
| import torch.nn.init as init | |
def weight_init(m, spread=1.0):
    """Initialize the parameters of a single module in place.

    Credits to: https://gist.github.com/jeasinema

    Meant to be passed to ``nn.Module.apply``, which calls it on every
    submodule::

        model = Model()
        model.apply(weight_init)
        # with a non-default spread:
        model.apply(functools.partial(weight_init, spread=0.5))

    Per module type:

    * Conv1d, ConvTranspose1d: weight drawn from a normal distribution with
      mean 0 and std ``spread``.
    * Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d, Linear: Xavier-normal
      weight with ``gain=spread``.
    * Biases of all the layers above, when present: normal with mean 0 and
      std ``spread``.
    * BatchNorm1d/2d/3d: weight normal with mean 0 and std ``spread``;
      bias set to 0. Each is skipped when absent (``affine=False``).
    * LSTM, LSTMCell, GRU, GRUCell: every parameter with two or more
      dimensions is orthogonally initialized (``spread`` is not used);
      1-D parameters are normal with mean 0 and std ``spread``.

    Modules of any other type are left untouched.

    Args:
        m: The module to initialize.
        spread: Std of the normal initializers and gain of the Xavier
            initializers.
    """
    # The nn.init functions run under torch.no_grad(), so the parameters can
    # be passed directly instead of going through ``.data``.
    if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
        # 1-D convolutions use a plain normal draw instead of Xavier (as in
        # the credited gist).
        init.normal_(m.weight, mean=0, std=spread)
        if m.bias is not None:  # layers built with bias=False
            init.normal_(m.bias, mean=0, std=spread)
    elif isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d,
                        nn.ConvTranspose3d, nn.Linear)):
        init.xavier_normal_(m.weight, gain=spread)
        if m.bias is not None:  # layers built with bias=False
            init.normal_(m.bias, mean=0, std=spread)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        # With affine=False both weight and bias are None; touching them
        # unguarded would raise AttributeError.
        # NOTE(review): gamma is centred at 0 here (the usual choice is 1);
        # kept as-is on the assumption it is intentional -- confirm.
        if m.weight is not None:
            init.normal_(m.weight, mean=0, std=spread)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
        for param in m.parameters():
            if param.dim() >= 2:
                init.orthogonal_(param)
            else:
                init.normal_(param, mean=0, std=spread)