Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| """ | |
| # -------------------------------------------- | |
| # Batch Normalization | |
| # -------------------------------------------- | |
| # Kai Zhang (cskaizhang@gmail.com) | |
| # https://github.com/cszn | |
| # 01/Jan/2019 | |
| # -------------------------------------------- | |
| """ | |
| # -------------------------------------------- | |
| # remove/delete specified layer | |
| # -------------------------------------------- | |
def deleteLayer(model, layer_type=nn.BatchNorm2d):
    ''' Kai Zhang, 11/Jan/2019.
    Recursively strip every child module that is an instance of ``layer_type``
    (BatchNorm2d by default) from ``model``, editing it in place.

    Removed modules are still descended into, so their own matching children
    are stripped as well. Returns None.
    '''
    # Snapshot the children first: we mutate model._modules below.
    snapshot = list(model.named_children())
    doomed = [name for name, child in snapshot if isinstance(child, layer_type)]
    for name in doomed:
        del model._modules[name]
    # Descend into every child of the snapshot, removed or not.
    for _, child in snapshot:
        deleteLayer(child, layer_type)
| # -------------------------------------------- | |
| # merge bn, "conv+bn" --> "conv" | |
| # -------------------------------------------- | |
def merge_bn(model):
    ''' Kai Zhang, 11/Jan/2019.
    merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv')
    based on https://github.com/pytorch/pytorch/pull/901

    Also merges 'Linear+BN' into 'Linear'. Walks ``model`` recursively; whenever
    a BatchNorm1d/BatchNorm2d child directly follows a Conv2d, ConvTranspose2d
    or Linear sibling, the BN's running statistics (and its affine parameters,
    if any) are folded into that layer's weight and bias, and the BN child is
    deleted from its parent. ``model`` is modified in place; returns None.

    The fold uses the BN's *running* statistics, so the merged network matches
    the original in eval mode only. A BN built with
    ``track_running_stats=False`` has no running statistics and is left as is.
    '''
    prev_m = None
    for k, m in list(model.named_children()):
        if (isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d))
                and isinstance(prev_m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d))
                and m.running_var is not None):
            _fold_bn_into(prev_m, m)
            del model._modules[k]
        prev_m = m
        merge_bn(m)


def _fold_bn_into(layer, bn):
    """Fold ``bn``'s running statistics (and affine params) into ``layer`` in place."""
    w = layer.weight.data
    if layer.bias is None:
        # Linear exposes out_features; Conv2d / ConvTranspose2d expose out_channels.
        n_out = layer.out_features if isinstance(layer, nn.Linear) else layer.out_channels
        layer.bias = nn.Parameter(torch.zeros(n_out, dtype=w.dtype, device=w.device))
    b = layer.bias.data

    # Per-output-channel scale: gamma / sqrt(running_var + eps) (gamma = 1 without affine).
    scale = bn.running_var.add(bn.eps).rsqrt()
    if bn.affine:
        scale = scale * bn.weight.data

    if isinstance(layer, nn.ConvTranspose2d):
        # ConvTranspose2d weight is (in, out // groups, kH, kW): output channels are
        # on dim 1, and input row i belongs to group i // (in // groups). Expand the
        # scale to an (in, out // groups, 1, 1) grid so grouped layers fold correctly.
        groups = layer.groups
        per_row = (scale.view(groups, 1, -1)
                   .expand(groups, w.size(0) // groups, -1)
                   .reshape(w.size(0), w.size(1), *[1] * (w.dim() - 2)))
        w.mul_(per_row)
    else:
        # Conv2d (out, in // groups, kH, kW) and Linear (out, in): output channels on dim 0.
        w.mul_(scale.view(-1, *[1] * (w.dim() - 1)))

    # b' = (b - running_mean) * scale + beta
    b.sub_(bn.running_mean).mul_(scale)
    if bn.affine:
        b.add_(bn.bias.data)
| # -------------------------------------------- | |
| # add bn, "conv" --> "conv+bn" | |
| # -------------------------------------------- | |
def add_bn(model):
    ''' Kai Zhang, 11/Jan/2019.
    Wrap every Conv2d / ConvTranspose2d / Linear submodule of ``model``
    (searched recursively) as ``nn.Sequential(layer, bn)``, editing ``model``
    in place. Returns None.

    Conv2d and ConvTranspose2d get a BatchNorm2d over ``out_channels``; Linear
    gets a BatchNorm1d over ``out_features`` (Linear has no ``out_channels``).
    Each new BN has momentum 0.1, affine scale initialised to 1, and is placed
    on the wrapped layer's device and dtype.
    '''
    for k, m in list(model.named_children()):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            bn = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True)
        elif isinstance(m, nn.Linear):
            # NOTE(review): BatchNorm1d expects (N, C) or (N, C, L); a Linear applied
            # to (N, L, C) inputs would need a transpose around this BN.
            bn = nn.BatchNorm1d(m.out_features, momentum=0.1, affine=True)
        else:
            bn = None
        if bn is not None:
            bn.weight.data.fill_(1)
            # Keep the new BN alongside its layer (e.g. a model already moved to GPU).
            bn.to(device=m.weight.device, dtype=m.weight.dtype)
            model._modules[k] = nn.Sequential(m, bn)
        add_bn(m)
| # -------------------------------------------- | |
| # tidy model after removing bn | |
| # -------------------------------------------- | |
def tidy_sequential(model):
    ''' Kai Zhang, 11/Jan/2019.
    Replace every single-element ``nn.Sequential`` child inside ``model``
    (searched recursively) with the module it wraps, editing ``model`` in place.
    Useful after BN layers have been removed from ``Sequential(conv, bn)``
    pairs. ``model`` itself is never replaced, since it has no parent here.
    Returns None.
    '''
    for k, m in list(model.named_children()):
        # Tidy the subtree first (post-order) so chains of nested single-element
        # wrappers collapse all the way down to the innermost module.
        tidy_sequential(m)
        if isinstance(m, nn.Sequential) and len(m) == 1:
            model._modules[k] = m[0]