Spaces:
Running
Running
| import torch | |
| from torch.nn import init | |
| from scipy.special import gamma as Gamma | |
| from scipy.stats import gennorm | |
| import numpy as np | |
def gg_init(model, shape=2, xi=2):
    """Initialize *model*'s parameters in place with a generalized Gaussian.

    Every 2-D parameter (e.g. a linear layer's weight matrix) is drawn from
    a zero-mean generalized normal distribution whose scale is chosen so
    that the sample variance equals ``xi / param.shape[0]``.  Non-2-D
    parameters are set to all-ones if their name contains ``"weight"`` and
    all-zeros if it contains ``"bias"``; anything else is left untouched.
    Each parameter keeps its original dtype and device.

    Args:
        model: torch module whose ``named_parameters()`` are re-initialized.
        shape: shape (beta) parameter of the generalized normal; 2 recovers
            the ordinary Gaussian.
        xi: activation gain — 1 for sigmoid / no activation, 2 for ReLU,
            2 / (1 + k**2) for LeakyReLU with negative slope k.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            device, dtype = param.device, param.dtype
            if param.dim() == 2:
                # NOTE(review): fan dimension is shape[0], which for an
                # nn.Linear weight is fan-out — confirm this matches the
                # reference the scheme was taken from.
                fan = param.shape[0]
                # Var[gennorm(beta, scale=a)] = a^2 * Gamma(3/beta) / Gamma(1/beta),
                # so this scale yields variance xi / fan.
                scale = np.sqrt(xi / fan * Gamma(1 / shape) / Gamma(3 / shape))
                drawn = gennorm.rvs(shape, loc=0, scale=scale, size=param.shape)
                param.data = torch.from_numpy(drawn)
            elif "weight" in name:
                param.data = torch.ones(param.shape)
            elif "bias" in name:
                param.data = torch.zeros(param.shape)
            # Restore the parameter's original dtype and device (sampling
            # above produced a float64 CPU tensor).
            param.data = param.data.to(dtype).to(device)