| """Utility functions for initializing weights and biases.""" |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
|
|
| import torch |
|
|
|
|
| def _calculate_fan(linear_weight_shape, fan="fan_in"): |
| fan_out, fan_in = linear_weight_shape |
|
|
| if fan == "fan_in": |
| f = fan_in |
| elif fan == "fan_out": |
| f = fan_out |
| elif fan == "fan_avg": |
| f = (fan_in + fan_out) / 2 |
| else: |
| raise ValueError("Invalid fan option") |
|
|
| return f |
|
|
|
|
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    """In-place variance-scaled truncated-normal initialization.

    The variance is ``scale / max(1, fan)`` and samples are truncated to
    two standard deviations on either side of zero.

    Args:
        weights: 2-D weight tensor to initialize in place.
        scale: variance scaling factor.
        fan: fan mode forwarded to ``_calculate_fan``.
    """
    fan_value = _calculate_fan(weights.shape, fan)
    # max(1, fan) guards against division by zero for degenerate shapes.
    std = math.sqrt(scale / max(1, fan_value))
    with torch.no_grad():
        torch.nn.init.trunc_normal_(
            weights, mean=0.0, std=std, a=-2 * std, b=2 * std
        )
|
|
|
|
def lecun_normal_init_(weights):
    """LeCun normal initialization: truncated normal with variance 1/fan_in."""
    trunc_normal_init_(weights, scale=1.0)
|
|
|
|
def he_normal_init_(weights):
    """He (Kaiming) normal initialization: truncated normal with variance 2/fan_in."""
    trunc_normal_init_(weights, scale=2.0)
|
|
|
|
def glorot_uniform_init_(weights):
    """Glorot (Xavier) uniform initialization, in place, with gain 1."""
    torch.nn.init.xavier_uniform_(weights, gain=1)
|
|
|
|
def final_init_(weights):
    """Zero-initialize *weights* in place (used for final output projections)."""
    with torch.no_grad():
        weights.zero_()
|
|
|
|
def gating_init_(weights):
    """Zero-initialize gating weights in place.

    With a zero weight matrix, a sigmoid-gated layer starts fully open
    (sigmoid(0 + bias)); presumably paired with a bias of one — TODO confirm
    against callers.
    """
    with torch.no_grad():
        weights.zero_()
|
|
|
|
def bias_init_zero_(bias):
    """Set every element of *bias* to zero, in place."""
    with torch.no_grad():
        torch.nn.init.zeros_(bias)
|
|
|
|
def bias_init_one_(bias):
    """Set every element of *bias* to one, in place."""
    with torch.no_grad():
        torch.nn.init.ones_(bias)
|
|
|
|
def normal_init_(weights):
    """Kaiming-normal initialization with a linear nonlinearity (gain 1).

    With ``nonlinearity="linear"`` this draws from N(0, 1/fan_in), i.e. a
    plain (untruncated) LeCun-style normal initialization.
    """
    torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
|
|
|
|
def ipa_point_weights_init_(weights):
    """Initialize IPA point-attention head weights in place.

    Fills with softplus⁻¹(1) = ln(e − 1) ≈ 0.5413, so that a subsequent
    softplus yields exactly 1. Presumably the caller applies softplus to
    these weights — confirm against the invariant-point-attention module.
    """
    with torch.no_grad():
        # softplus_inverse(1.0)
        weights.fill_(0.541324854612918)
|
|