import collections
import contextlib

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm

try:
    from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
    ReduceAddCoalesced = Broadcast = None

try:
    from jactorch.parallel.comm import SyncMaster
    from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
except ImportError:
    from .comm import SyncMaster
    from .replicate import DataParallelWithCallback

__all__ = [
    'set_sbn_eps_mode',
    'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
    'patch_sync_batchnorm', 'convert_model'
]


# Controls how eps enters the inverse standard deviation in _compute_mean_std:
# 'clamp' uses var.clamp(min=eps) ** -0.5, 'plus' uses (var + eps) ** -0.5.
SBN_EPS_MODE = 'clamp'


def set_sbn_eps_mode(mode):
    """Set the global eps mode ('clamp' or 'plus') used by all synchronized batchnorm layers."""
    global SBN_EPS_MODE
    assert mode in ('clamp', 'plus')
    SBN_EPS_MODE = mode
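

# A minimal usage sketch of the mode switch above; the import path is an
# assumption and depends on where this module lives in your project:
#
#   from sync_batchnorm.batchnorm import set_sbn_eps_mode
#   set_sbn_eps_mode('plus')   # inv_std = (var + eps) ** -0.5
#   set_sbn_eps_mode('clamp')  # inv_std = var.clamp(min=eps) ** -0.5 (the default)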


def _sum_ft(tensor):
    """Sum over the first and the last dimension."""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """Add new dimensions at the front and the tail."""
    return tensor.unsqueeze(0).unsqueeze(-1)
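

# Shape sketch for the two helpers above, assuming the (B, C, -1) layout used in
# _SynchronizedBatchNorm.forward below: for x of shape (B, C, L),
#   _sum_ft(x)       -> shape (C,), a per-channel reduction;
#   _unsqueeze_ft(v) -> shape (1, C, 1), broadcastable back over (B, C, L).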


# Messages exchanged between the replicas and the master copy during synchronization.
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        assert ReduceAddCoalesced is not None, 'Cannot use Synchronized Batch Normalization without CUDA support.'

        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
                                                     track_running_stats=track_running_stats)

        if not self.track_running_stats:
            import warnings
            warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If this is not a parallel computation, or we are in evaluation mode,
        # fall back to the built-in PyTorch implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics across all copies.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # Fuse the multiplication by inv_std and weight into a single term for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape back to the original input shape.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means the master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always use the same device order so that the ReduceAdd operation is
        # deterministic (and slightly faster).
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        # Pair each copy's (mean, inv_std) back with its original id.
        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs
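
    # Worked sketch of the reduction above: with per-device sums s_k = sum(x_k)
    # and square-sums q_k = sum(x_k ** 2) over n_k elements each,
    #   mean = (s_1 + ... + s_K) / (n_1 + ... + n_K)
    #   var  = (q_1 + ... + q_K) / (n_1 + ... + n_K) - mean ** 2
    # which is exactly what _compute_mean_std derives from (sum_, ssum, sum_size).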

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        # Var[x] = E[x^2] - E[x]^2, so the (unnormalized) variance is ssum - sum_ * mean.
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        # torch.no_grad was introduced in PyTorch 0.4; fall back to a plain update on older versions.
        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        else:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        if SBN_EPS_MODE == 'clamp':
            return mean, bias_var.clamp(self.eps) ** -0.5
        elif SBN_EPS_MODE == 'plus':
            return mean, (bias_var + self.eps) ** -0.5
        else:
            raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))
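

# The synchronization above is driven by a callback-aware DataParallel wrapper,
# which invokes __data_parallel_replicate__ on every module copy. A minimal
# sketch, assuming two visible GPUs (the toy model is hypothetical):
#
#   import torch.nn as nn
#   bn = SynchronizedBatchNorm2d(16)
#   net = DataParallelWithCallback(nn.Sequential(nn.Conv2d(3, 16, 3), bn),
#                                  device_ids=[0, 1]).cuda()
#   out = net(torch.randn(8, 3, 32, 32).cuda())  # batch stats reduced over both GPUs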


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device, which is fast and easy to implement, but
    makes the statistics less accurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimates are kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it is common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 4d input that is seen as a
    mini-batch of 3d inputs.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device, which is fast and easy to implement, but
    makes the statistics less accurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimates are kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it is common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 5d input that is seen as a
    mini-batch of 4d inputs.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device, which is fast and easy to implement, but
    makes the statistics less accurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimates are kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it is common terminology to call this Volumetric
    BatchNorm or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))


@contextlib.contextmanager
def patch_sync_batchnorm():
    """Context manager that temporarily replaces nn.BatchNorm*N*d with the synchronized versions."""
    import torch.nn as nn

    backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

    nn.BatchNorm1d = SynchronizedBatchNorm1d
    nn.BatchNorm2d = SynchronizedBatchNorm2d
    nn.BatchNorm3d = SynchronizedBatchNorm3d

    try:
        yield
    finally:
        # Restore the original classes even if the body raised.
        nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
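

# A minimal usage sketch of the context manager above; the tiny model here is
# hypothetical, and any module that constructs BatchNorm layers inside the
# block would work the same way:
#
#   import torch.nn as nn
#   with patch_sync_batchnorm():
#       model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
#   assert isinstance(model[1], SynchronizedBatchNorm2d)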


def convert_model(module):
    """Traverse the input module and its children recursively, replacing every
    instance of torch.nn.modules.batchnorm.BatchNorm*N*d with the corresponding
    SynchronizedBatchNorm*N*d.

    Args:
        module: the input module to be converted to a SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> m = nn.DataParallel(m)
        >>> # after conversion, m uses SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        # Unwrap, convert the inner module, and re-wrap with the callback-aware DataParallel.
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
        return mod

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            # Copy the configuration and the running statistics into the new module.
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod