""" Split BatchNorm

A PyTorch BatchNorm layer that splits the input batch into N equal parts and passes each
through a separate BN layer. The first split is passed through the parent BN layer, with
weight/bias keys the same as the original BN. All other splits pass through BN sub-layers
under the '.aux_bn' namespace.

This allows easily removing the auxiliary BN layers after training to efficiently
achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
'Disentangled Learning via An Auxiliary BN'

Hacked together by / Copyright 2020 Ross Wightman
"""
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
    """BatchNorm2d that normalizes each of ``num_splits`` equal batch chunks separately.

    The first chunk of a training batch goes through this layer's own
    (inherited) batch norm; each remaining chunk goes through its own entry
    of ``self.aux_bn``. Outside of training mode the whole batch goes through
    the inherited batch norm only and the auxiliary layers are not used.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, num_splits=2):
        """Build the primary BN plus ``num_splits - 1`` auxiliary BN layers.

        Args:
            num_features: forwarded to every ``BatchNorm2d`` that is created.
            eps: forwarded to every ``BatchNorm2d`` that is created.
            momentum: forwarded to every ``BatchNorm2d`` that is created.
            affine: forwarded to every ``BatchNorm2d`` that is created.
            track_running_stats: forwarded to every ``BatchNorm2d`` that is created.
            num_splits: number of chunks the training batch is divided into;
                must be greater than 1.

        Raises:
            AssertionError: if ``num_splits`` is not greater than 1.
        """
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
        self.num_splits = num_splits
        # One auxiliary layer per chunk after the first, configured like the primary BN.
        aux_layers = []
        for _ in range(num_splits - 1):
            aux_layers.append(
                nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats))
        self.aux_bn = nn.ModuleList(aux_layers)

    def forward(self, input: torch.Tensor):
        """Normalize ``input``, chunk-wise when training and as a whole otherwise.

        Raises:
            AssertionError: in training mode, if the size of dim 0 is not an
                exact multiple of ``num_splits``.
        """
        # Evaluation path: only the primary BN is involved.
        if not self.training:
            return super().forward(input)

        batch_size = input.shape[0]
        chunk_size = batch_size // self.num_splits
        assert batch_size == chunk_size * self.num_splits, "batch size must be evenly divisible by num_splits"

        first_chunk, *other_chunks = input.split(chunk_size)
        # Primary BN handles the first chunk; aux layers take the rest in order.
        outputs = [super().forward(first_chunk)]
        for aux_layer, chunk in zip(self.aux_bn, other_chunks):
            outputs.append(aux_layer(chunk))
        return torch.cat(outputs, dim=0)
| |
|
| |
|
def _clone_optional_buffer(buffer):
    """Return a clone of ``buffer``, or ``None`` when the buffer is absent."""
    return None if buffer is None else buffer.clone()


def convert_splitbn_model(module, num_splits=2):
    """
    Recursively traverse module and its children to replace all instances of
    ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`.

    The replacement layer takes over the source layer's running statistics and
    (when affine) a copy of its weight/bias; every auxiliary BN is initialised
    with its own copies of the same values. Source layers whose running
    statistics are ``None`` (as they are with ``track_running_stats=False``)
    are supported: the converted layers keep ``None`` buffers as well.

    Args:
        module (torch.nn.Module): input module
        num_splits: number of separate batchnorm layers to split input across
    Returns:
        The converted module: a new `SplitBatchNorm2d` when ``module`` itself
        is a batch norm layer, otherwise ``module`` with its children
        re-registered after conversion.
    Example::
        >>> # model is an instance of torch.nn.Module
        >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
    """
    # Instance norm layers are checked first and returned untouched, so they
    # are never converted and their children are not visited.
    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
        return module

    mod = module
    # NOTE(review): every _BatchNorm subclass is replaced by the 2d split
    # layer, not only BatchNorm2d -- confirm this is intended for models that
    # contain other batch norm variants.
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        mod = SplitBatchNorm2d(
            module.num_features, module.eps, module.momentum, module.affine,
            module.track_running_stats, num_splits=num_splits)
        # The primary BN takes over the source buffers directly (no copy).
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        mod.num_batches_tracked = module.num_batches_tracked
        if module.affine:
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()
        for aux in mod.aux_bn:
            # Each aux layer gets independent copies; buffers may be None
            # when the source layer does not track running statistics.
            aux.running_mean = _clone_optional_buffer(module.running_mean)
            aux.running_var = _clone_optional_buffer(module.running_var)
            aux.num_batches_tracked = _clone_optional_buffer(module.num_batches_tracked)
            if module.affine:
                aux.weight.data = module.weight.data.clone().detach()
                aux.bias.data = module.bias.data.clone().detach()

    # Convert the children of the original module and register them on the result.
    for name, child in module.named_children():
        mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
    return mod
| |
|