#!/usr/bin/env python3
# encoding: utf-8
# @Time    : 2018/9/28 12:13 PM
# @Author  : yuchangqian
# @Contact : changqian_yu@163.com
# @File    : init_func.py.py
import torch.nn as nn


def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                  **kwargs):
    """Initialise conv and normalisation layers inside ``feature`` in place.

    Args:
        feature: an ``nn.Module``; it and all of its submodules are visited.
        conv_init: callable applied to every Conv1d/2d/3d weight as
            ``conv_init(weight, **kwargs)``. Conv biases are left untouched.
        norm_layer: class (or tuple of classes) identifying normalisation
            layers; their ``eps``/``momentum`` are overwritten and their
            affine weight/bias (when present) are set to 1/0.
        bn_eps: value assigned to ``m.eps`` of each matched norm layer.
        bn_momentum: value assigned to ``m.momentum`` of each matched norm
            layer.
        **kwargs: forwarded verbatim to ``conv_init``.
    """
    for m in feature.modules():
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            conv_init(m.weight, **kwargs)
        elif isinstance(m, norm_layer):
            m.eps = bn_eps
            m.momentum = bn_momentum
            # Norm layers created with affine=False register weight/bias as
            # None; initialising them unconditionally would crash.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
                **kwargs):
    """Apply ``__init_weight`` to one module or to each module of a sequence.

    Args:
        module_list: a single ``nn.Module`` or a list/tuple of modules.
        conv_init, norm_layer, bn_eps, bn_momentum, **kwargs: see
            ``__init_weight``.
    """
    if isinstance(module_list, (list, tuple)):
        for feature in module_list:
            __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                          **kwargs)
    else:
        __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
                      **kwargs)


def group_weight(weight_group, module, norm_layer, lr):
    """Split ``module``'s parameters into weight-decay / no-decay groups.

    Grouping rules, applied while walking ``module.modules()``:

    * Linear / Conv / ConvTranspose: weight -> decay, bias -> no decay.
    * Normalisation layers (``norm_layer`` plus the built-in BatchNorm,
      GroupNorm, LayerNorm and InstanceNorm classes): weight and bias
      -> no decay.
    * Any other module: the parameters it owns directly (for example a bare
      ``nn.Parameter`` attribute or ``nn.Embedding.weight``) -> decay.
      Without this rule such parameters would end up in no group and never
      be optimised.

    A parameter shared by several modules (weight tying) is added only once,
    because ``torch.optim`` rejects a parameter that appears in more than one
    group.

    Args:
        weight_group: list of optimizer param-group dicts; two dicts are
            appended to it in place.
        module: the ``nn.Module`` whose parameters are grouped.
        norm_layer: extra normalisation class (or tuple of classes) whose
            parameters are exempt from weight decay.
        lr: learning rate stored in both appended groups.

    Returns:
        ``weight_group``: the same list, now holding
        ``{'params': decay, 'lr': lr}`` followed by
        ``{'params': no_decay, 'weight_decay': 0.0, 'lr': lr}``.
    """
    decay_types = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
                   nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    norm_types = (norm_layer, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                  nn.GroupNorm, nn.LayerNorm, nn.InstanceNorm1d,
                  nn.InstanceNorm2d, nn.InstanceNorm3d)
    group_decay = []
    group_no_decay = []
    seen = set()

    def _add(group, param):
        # Skip absent params (bias=False / affine=False) and tied params that
        # were already placed in a group.
        if param is None or id(param) in seen:
            return
        seen.add(id(param))
        group.append(param)

    for m in module.modules():
        if isinstance(m, decay_types):
            _add(group_decay, m.weight)
            _add(group_no_decay, m.bias)
        elif isinstance(m, norm_types):
            _add(group_no_decay, m.weight)
            _add(group_no_decay, m.bias)
        else:
            for param in m.parameters(recurse=False):
                _add(group_decay, param)

    assert len(list(module.parameters())) >= len(group_decay) + len(group_no_decay)
    weight_group.append(dict(params=group_decay, lr=lr))
    weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
    return weight_group