|
|
|
|
|
|
|
|
import math |
|
|
|
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
def constant_init(module, val, bias=0):
    """Fill a module's weight with ``val`` and its bias with ``bias``.

    Either parameter is skipped when the module does not have it or when it
    is ``None``. Examples are ``nn.BatchNorm2d(..., affine=False)``, whose
    ``weight`` and ``bias`` are ``None``, and ``nn.Conv2d(..., bias=False)``.

    Args:
        module (torch.nn.Module): module whose parameters are initialized in place.
        val (float): constant written into ``module.weight``.
        bias (float): constant written into ``module.bias``. Defaults to 0.
    """
    # Guard the weight the same way the bias is guarded: non-affine norm
    # layers register ``weight`` as None and would otherwise crash here.
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
|
|
|
|
|
|
|
|
def xavier_init(module, gain=1, bias=0, distribution='normal'):
    """Apply Xavier/Glorot initialization to ``module.weight``.

    Args:
        module (torch.nn.Module): module whose parameters are initialized in place.
        gain (float): scaling factor forwarded to the Xavier initializer.
        bias (float): constant written into ``module.bias`` if the module has one.
        distribution (str): ``'uniform'`` or ``'normal'``; selects
            ``nn.init.xavier_uniform_`` or ``nn.init.xavier_normal_``.
    """
    assert distribution in ['uniform', 'normal']
    # Choose the initializer once, then call it.
    if distribution == 'uniform':
        xavier_fn = nn.init.xavier_uniform_
    else:
        xavier_fn = nn.init.xavier_normal_
    xavier_fn(module.weight, gain=gain)

    # getattr(..., None) covers both "no bias attribute" and "bias is None".
    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
|
|
|
|
|
|
|
|
def normal_init(module, mean=0, std=1, bias=0):
    """Draw ``module.weight`` from N(mean, std**2) and set the bias to a constant.

    Args:
        module (torch.nn.Module): module whose parameters are initialized in place.
        mean (float): mean of the normal distribution.
        std (float): standard deviation of the normal distribution.
        bias (float): constant written into ``module.bias`` if the module has one.
    """
    weight = module.weight
    nn.init.normal_(weight, mean=mean, std=std)

    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
|
|
|
|
|
|
|
|
def uniform_init(module, a=0, b=1, bias=0):
    """Draw ``module.weight`` from U(a, b) and set the bias to a constant.

    Args:
        module (torch.nn.Module): module whose parameters are initialized in place.
        a (float): lower bound of the uniform distribution.
        b (float): upper bound of the uniform distribution.
        bias (float): constant written into ``module.bias`` if the module has one.
    """
    weight = module.weight
    nn.init.uniform_(weight, a=a, b=b)

    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
|
|
|
|
|
|
|
|
def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """Apply Kaiming/He initialization to ``module.weight``.

    Args:
        module (torch.nn.Module): module whose parameters are initialized in place.
        a (float): negative slope forwarded to the Kaiming initializer
            (used by the ``'leaky_relu'`` gain).
        mode (str): ``'fan_in'`` or ``'fan_out'``, forwarded as-is.
        nonlinearity (str): nonlinearity name used to compute the gain.
        bias (float): constant written into ``module.bias`` if the module has one.
        distribution (str): ``'uniform'`` or ``'normal'``; selects
            ``nn.init.kaiming_uniform_`` or ``nn.init.kaiming_normal_``.
    """
    assert distribution in ['uniform', 'normal']
    kaiming_fns = {
        'uniform': nn.init.kaiming_uniform_,
        'normal': nn.init.kaiming_normal_,
    }
    kaiming_fns[distribution](module.weight,
                              a=a,
                              mode=mode,
                              nonlinearity=nonlinearity)

    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
|
|
|
|
|
|
|
|
def caffe2_xavier_init(module, bias=0):
    """Caffe2-style "Xavier" init, expressed through ``kaiming_init``.

    The weight is drawn uniformly with ``a=1``, ``mode='fan_in'`` and
    ``nonlinearity='leaky_relu'``. With ``a=1`` the leaky_relu gain is
    ``sqrt(2 / (1 + 1**2)) == 1``, so the bound is ``sqrt(3 / fan_in)``.

    Args:
        module (torch.nn.Module): module whose parameters are initialized in place.
        bias (float): constant written into ``module.bias`` if the module has one.
    """
    caffe2_settings = dict(a=1,
                           mode='fan_in',
                           nonlinearity='leaky_relu',
                           distribution='uniform')
    kaiming_init(module, bias=bias, **caffe2_settings)
|
|
|
|
|
|
|
|
def c2_xavier_fill(module: nn.Module):
    """
    Initialize `module.weight` using the "XavierFill" implemented in Caffe2,
    and set `module.bias` to 0 when the module has a (non-None) bias.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    weight = module.weight
    # The Caffe2 XavierFill scheme is reproduced here with kaiming_uniform_
    # and a=1. The leaky_relu gain sqrt(2 / (1 + a**2)) is then 1, giving
    # U(-sqrt(3 / fan_in), sqrt(3 / fan_in)). The mode/nonlinearity values
    # below are PyTorch's defaults, spelled out for clarity.
    nn.init.kaiming_uniform_(weight, a=1, mode="fan_in", nonlinearity="leaky_relu")
    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
|
|
|
|
|
|
|
def c2_msra_fill(module: nn.Module):
    """
    Initialize `module.weight` using the "MSRAFill" implemented in Caffe2,
    and set `module.bias` to 0 when the module has a (non-None) bias.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    weight = module.weight
    # He-normal with fan_out scaling and the ReLU gain; a=0 is PyTorch's
    # default and is spelled out for clarity.
    nn.init.kaiming_normal_(weight, a=0, mode="fan_out", nonlinearity="relu")
    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
|
|
|
|
|
|
|
def init_weights(m: nn.Module, zero_init_final_gamma=False):
    """Performs ResNet-style weight initialization.

    Intended for use with ``model.apply(init_weights)``. Modules that are
    not ``Conv2d``/``BatchNorm2d``/``Linear`` are left untouched.

    - ``nn.Conv2d``: weight ~ N(0, 2 / fan_out) with
      ``fan_out = k_h * k_w * out_channels``. The conv bias (if any) is not
      modified.
    - ``nn.BatchNorm2d``: gamma is set to 1 and beta to 0. Gamma is set to 0
      instead when the module carries a truthy ``final_bn`` attribute and
      ``zero_init_final_gamma`` is True.
    - ``nn.Linear``: weight ~ N(0, 0.01**2) and bias = 0.

    Optional parameters that are ``None`` are skipped instead of crashing.
    Examples are ``nn.Linear(..., bias=False)`` and
    ``nn.BatchNorm2d(..., affine=False)``.

    Args:
        m (torch.nn.Module): module to initialize in place.
        zero_init_final_gamma (bool): zero the gamma of BN layers flagged with
            ``final_bn``.
    """
    if isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, nn.BatchNorm2d):
        zero_init_gamma = (
            hasattr(m, "final_bn") and m.final_bn and zero_init_final_gamma
        )
        # affine=False BN layers register weight/bias as None.
        if m.weight is not None:
            m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(mean=0.0, std=0.01)
        # Linear(..., bias=False) has bias None.
        if m.bias is not None:
            m.bias.data.zero_()
|
|
|