#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2018/9/28 12:13 PM
# @Author : yuchangqian
# @Contact : changqian_yu@163.com
# @File : init_func.py
import torch.nn as nn


def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                  **kwargs):
    # Apply `conv_init` to every conv weight, and reset every `norm_layer`
    # instance: override eps/momentum and set affine scale to 1, shift to 0.
    for name, m in feature.named_modules():
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            conv_init(m.weight, **kwargs)
        elif isinstance(m, norm_layer):
            m.eps = bn_eps
            m.momentum = bn_momentum
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
                **kwargs):
    # Accept either a single module or a list of modules.
    if isinstance(module_list, list):
        for feature in module_list:
            __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                          **kwargs)
    else:
        __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
                      **kwargs)
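

# Usage sketch (illustrative, not part of the original file): `conv_init` is
# any in-place initializer from nn.init, and extra keyword arguments are
# forwarded to it, e.g.
#
#   init_weight(model, nn.init.kaiming_normal_, nn.BatchNorm2d,
#               bn_eps=1e-5, bn_momentum=0.1, mode='fan_out')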


def group_weight(weight_group, module, norm_layer, lr):
    # Partition parameters into two optimizer groups: conv/linear weights get
    # weight decay; biases and normalization affine parameters do not.
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, nn.Linear):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                            nn.ConvTranspose2d, nn.ConvTranspose3d)):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, (norm_layer, nn.BatchNorm1d, nn.BatchNorm2d,
                            nn.BatchNorm3d, nn.GroupNorm, nn.LayerNorm)):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)

    # Parameters of unmatched layer types are left out of both groups, hence
    # >= rather than ==; callers must handle such parameters themselves.
    assert len(list(module.parameters())) >= len(group_decay) + len(group_no_decay)
    weight_group.append(dict(params=group_decay, lr=lr))
    weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
    return weight_group
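

if __name__ == '__main__':
    # Smoke test (an illustrative sketch; the toy model below is hypothetical
    # and not part of the original file). It shows the intended optimizer
    # wiring: decayed weights in one group, biases/norm parameters in a
    # no-decay group, both sharing the same learning rate.
    import torch

    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(inplace=True),
        nn.Flatten(),
        nn.Linear(16 * 8 * 8, 10),
    )
    init_weight(model, nn.init.kaiming_normal_, nn.BatchNorm2d,
                bn_eps=1e-5, bn_momentum=0.1, mode='fan_out')

    params = group_weight([], model, nn.BatchNorm2d, lr=1e-2)
    # The first group inherits weight_decay=1e-4 from the optimizer defaults;
    # the second group overrides it with 0.
    optimizer = torch.optim.SGD(params, lr=1e-2, momentum=0.9,
                                weight_decay=1e-4)

    out = model(torch.randn(2, 3, 8, 8))
    print(out.shape)  # torch.Size([2, 10])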