Thesis / utils /weight_init.py
Ryan-Pham's picture
Upload 103 files
beb7843 verified
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import math
import torch.nn as nn
def constant_init(module, val, bias=0):
    """Fill a module's weight with ``val`` and its bias with ``bias``.

    Parameters that are missing or ``None`` are skipped, so the helper can
    be applied to layers without learnable parameters (for example a norm
    layer built with ``affine=False``).

    Args:
        module (torch.nn.Module): module to initialize in place.
        val (float): value written to every element of ``module.weight``.
        bias (float): value written to every element of ``module.bias``.
    """
    # ``getattr`` with a default covers both "attribute absent" and
    # "attribute registered as None".
    if getattr(module, 'weight', None) is not None:
        nn.init.constant_(module.weight, val)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def xavier_init(module, gain=1, bias=0, distribution='normal'):
    """Apply Xavier (Glorot) initialization to a module's weight.

    The weight is drawn with ``nn.init.xavier_uniform_`` or
    ``nn.init.xavier_normal_`` depending on ``distribution``; the bias is
    set to the constant ``bias``. Parameters that are missing or ``None``
    are skipped.

    Args:
        module (torch.nn.Module): module to initialize in place.
        gain (float): scaling factor forwarded to the Xavier initializer.
        bias (float): value written to every element of ``module.bias``.
        distribution (str): either ``'uniform'`` or ``'normal'``.

    Raises:
        AssertionError: if ``distribution`` is not one of the two
            supported names.
    """
    assert distribution in ['uniform', 'normal']
    # Skip modules whose weight is absent or None instead of crashing
    # inside the initializer.
    if getattr(module, 'weight', None) is not None:
        if distribution == 'uniform':
            nn.init.xavier_uniform_(module.weight, gain=gain)
        else:
            nn.init.xavier_normal_(module.weight, gain=gain)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0):
    """Draw a module's weight from a normal distribution.

    The weight is filled via ``nn.init.normal_(weight, mean, std)`` and the
    bias is set to the constant ``bias``. Parameters that are missing or
    ``None`` are skipped.

    Args:
        module (torch.nn.Module): module to initialize in place.
        mean (float): mean of the normal distribution.
        std (float): standard deviation of the normal distribution.
        bias (float): value written to every element of ``module.bias``.
    """
    # Skip modules whose weight is absent or None instead of crashing
    # inside the initializer.
    if getattr(module, 'weight', None) is not None:
        nn.init.normal_(module.weight, mean, std)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def uniform_init(module, a=0, b=1, bias=0):
    """Draw a module's weight from a uniform distribution.

    The weight is filled via ``nn.init.uniform_(weight, a, b)`` and the
    bias is set to the constant ``bias``. Parameters that are missing or
    ``None`` are skipped.

    Args:
        module (torch.nn.Module): module to initialize in place.
        a (float): lower bound of the uniform distribution.
        b (float): upper bound of the uniform distribution.
        bias (float): value written to every element of ``module.bias``.
    """
    # Skip modules whose weight is absent or None instead of crashing
    # inside the initializer.
    if getattr(module, 'weight', None) is not None:
        nn.init.uniform_(module.weight, a, b)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """Apply Kaiming (He) initialization to a module's weight.

    The weight is drawn with ``nn.init.kaiming_uniform_`` or
    ``nn.init.kaiming_normal_`` depending on ``distribution``; the bias is
    set to the constant ``bias``. Parameters that are missing or ``None``
    are skipped.

    Args:
        module (torch.nn.Module): module to initialize in place.
        a (float): forwarded as ``a`` to the Kaiming initializer.
        mode (str): forwarded as ``mode`` (``'fan_in'`` or ``'fan_out'``).
        nonlinearity (str): forwarded as ``nonlinearity``.
        bias (float): value written to every element of ``module.bias``.
        distribution (str): either ``'uniform'`` or ``'normal'``.

    Raises:
        AssertionError: if ``distribution`` is not one of the two
            supported names.
    """
    assert distribution in ['uniform', 'normal']
    # Skip modules whose weight is absent or None instead of crashing
    # inside the initializer.
    if getattr(module, 'weight', None) is not None:
        if distribution == 'uniform':
            nn.init.kaiming_uniform_(module.weight,
                                     a=a,
                                     mode=mode,
                                     nonlinearity=nonlinearity)
        else:
            nn.init.kaiming_normal_(module.weight,
                                    a=a,
                                    mode=mode,
                                    nonlinearity=nonlinearity)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def caffe2_xavier_init(module, bias=0):
    """Initialize ``module`` with the settings matching Caffe2's XavierFill.

    Delegates to ``kaiming_init`` using a uniform distribution, ``fan_in``
    mode and a ``leaky_relu`` nonlinearity with ``a=1``.

    Args:
        module (torch.nn.Module): module to initialize in place.
        bias (float): bias constant passed through to ``kaiming_init``.
    """
    # Caffe2's `XavierFill` is the counterpart of PyTorch's
    # `kaiming_uniform_` (acknowledgment to FAIR's internal code).
    xavier_fill_settings = dict(
        a=1,
        mode='fan_in',
        nonlinearity='leaky_relu',
        distribution='uniform',
    )
    kaiming_init(module, bias=bias, **xavier_fill_settings)
def c2_xavier_fill(module: nn.Module):
    """
    Initialize `module.weight` using the "XavierFill" implemented in Caffe2.
    Also initializes `module.bias` to 0 when the module has a bias.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    # Caffe2 implementation of XavierFill in fact
    # corresponds to kaiming_uniform_ in PyTorch
    nn.init.kaiming_uniform_(module.weight, a=1)
    # Modules may have no `bias` attribute at all (not just `bias=None`),
    # so look it up defensively like the other helpers in this file do.
    if getattr(module, "bias", None) is not None:
        nn.init.constant_(module.bias, 0)
def c2_msra_fill(module: nn.Module):
    """
    Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
    Also initializes `module.bias` to 0 when the module has a bias.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    # Modules may have no `bias` attribute at all (not just `bias=None`),
    # so look it up defensively like the other helpers in this file do.
    if getattr(module, "bias", None) is not None:
        nn.init.constant_(module.bias, 0)
def init_weights(m: nn.Module, zero_init_final_gamma=False):
    """Performs ResNet-style weight initialization.

    Behavior by module type:

    * ``nn.Conv2d``: weight drawn from a zero-mean normal with
      ``std = sqrt(2 / fan_out)``, where ``fan_out`` is
      ``kernel_h * kernel_w * out_channels``. The bias is left untouched.
    * ``nn.BatchNorm2d``: weight (gamma) set to 1, or to 0 when the module
      carries a truthy ``final_bn`` attribute and ``zero_init_final_gamma``
      is true; bias set to 0.
    * ``nn.Linear``: weight drawn from a zero-mean normal with
      ``std = 0.01``; bias set to 0.

    Any other module type is ignored. Parameters that are ``None``
    (``BatchNorm2d(affine=False)``, ``Linear(bias=False)``) are skipped.

    Args:
        m (torch.nn.Module): module to initialize in place.
        zero_init_final_gamma (bool): whether to zero the gamma of
            BatchNorm layers flagged with ``final_bn``.
    """
    if isinstance(m, nn.Conv2d):
        # Note that there is no bias due to BN
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, nn.BatchNorm2d):
        # affine=False leaves both weight and bias as None.
        if m.weight is not None:
            zero_init_gamma = (
                getattr(m, "final_bn", False) and zero_init_final_gamma
            )
            nn.init.constant_(m.weight, 0.0 if zero_init_gamma else 1.0)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=0.01)
        # bias=False leaves the bias as None.
        if m.bias is not None:
            nn.init.zeros_(m.bias)