Spaces:
Runtime error
Runtime error
Upload utils/sync_batchnorm/batchnorm.py
Browse files
utils/sync_batchnorm/batchnorm.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# File : batchnorm.py
|
| 3 |
+
# Author : Jiayuan Mao
|
| 4 |
+
# Email : maojiayuan@gmail.com
|
| 5 |
+
# Date : 27/01/2018
|
| 6 |
+
#
|
| 7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
| 8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
| 9 |
+
# Distributed under MIT License.
|
| 10 |
+
|
| 11 |
+
import collections
|
| 12 |
+
import contextlib
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
| 21 |
+
except ImportError:
|
| 22 |
+
ReduceAddCoalesced = Broadcast = None
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from jactorch.parallel.comm import SyncMaster
|
| 26 |
+
from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
|
| 27 |
+
except ImportError:
|
| 28 |
+
from .comm import SyncMaster
|
| 29 |
+
from .replicate import DataParallelWithCallback
|
| 30 |
+
|
| 31 |
+
# Public API of this module.
__all__ = [
    'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
    'patch_sync_batchnorm', 'convert_model'
]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _sum_ft(tensor):
|
| 38 |
+
"""sum over the first and last dimention"""
|
| 39 |
+
return tensor.sum(dim=0).sum(dim=-1)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _unsqueeze_ft(tensor):
|
| 43 |
+
"""add new dimensions at the front and the tail"""
|
| 44 |
+
return tensor.unsqueeze(0).unsqueeze(-1)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
|
| 48 |
+
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class _SynchronizedBatchNorm(_BatchNorm):
    """Batch normalization whose training statistics are reduced across replicas.

    Outside of parallel training (or in evaluation mode) ``forward`` defers to
    ``F.batch_norm``. In parallel training every replica sends its per-channel
    sum and square-sum to the master replica (``_parallel_id == 0``), which
    reduces them, computes the mean and inverse standard deviation, updates the
    running statistics and broadcasts the result back.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'

        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        # The master callback performs the reduction once all replicas reported.
        self._sync_master = SyncMaster(self._data_parallel_master)

        # Filled in by ``__data_parallel_replicate__`` when the module is replicated.
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        """Normalize ``input`` using statistics shared by all replicas.

        Args:
            input: tensor that can be viewed as ``(B, num_features, -1)``.

        Returns:
            The normalized tensor, with the same shape as ``input``.
        """
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Record this replica's id and connect it to the master's ``SyncMaster``.

        Args:
            ctx: object shared between the replicas of one module; the master
                stores its ``SyncMaster`` on it as ``ctx.sync_master``.
            copy_id: index of this replica; ``0`` designates the master.
        """
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it.

        Args:
            intermediates: sequence of pairs whose second item is a
                ``_ChildMessage``; the first item is passed back unchanged
                (presumably the replica identifier -- confirm against ``SyncMaster``).

        Returns:
            A list of ``(first_item, _MasterMessage)`` pairs, one per input pair,
            ordered by the device index of the message tensors.
        """

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        # ``broadcasted`` is flat: two consecutive tensors (mean, inv_std) per device.
        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _update_running_stats(self, mean, unbias_var):
        """Blend the batch statistics into the running mean and variance."""
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device.

        Args:
            sum_: per-channel sum of the inputs.
            ssum: per-channel sum of the squared inputs.
            size: number of elements that contributed to each channel.

        Returns:
            A ``(mean, inv_std)`` pair; ``inv_std`` is derived from the biased variance.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        # sum(x^2) - sum(x) * mean equals the sum of squared deviations from the mean.
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        # ``torch.no_grad`` is used when the installed PyTorch provides it.
        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self._update_running_stats(mean, unbias_var)
        else:
            self._update_running_stats(mean, unbias_var)

        # NOTE(review): the variance is clamped from below at ``eps`` rather than
        # having ``eps`` added to it -- confirm this difference is intended.
        return mean, bias_var.clamp(self.eps) ** -0.5
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 2D or 3D."""
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        try:
            super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
        except NotImplementedError:
            # Some PyTorch releases leave the base-class hook abstract; the
            # dimensionality check above is then the only validation available.
            pass
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 4d input that is seen as a
    mini-batch of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 4D."""
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        try:
            super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
        except NotImplementedError:
            # Some PyTorch releases leave the base-class hook abstract; the
            # dimensionality check above is then the only validation available.
            pass
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 5d input that is seen as a
    mini-batch of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 5D."""
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        try:
            super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
        except NotImplementedError:
            # Some PyTorch releases leave the base-class hook abstract; the
            # dimensionality check above is then the only validation available.
            pass
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
@contextlib.contextmanager
def patch_sync_batchnorm():
    """Temporarily replace ``torch.nn.BatchNorm{1,2,3}d`` with the synchronized versions.

    While the ``with`` block is active, ``nn.BatchNorm1d``, ``nn.BatchNorm2d`` and
    ``nn.BatchNorm3d`` refer to the ``SynchronizedBatchNorm*`` classes of this
    module. The original classes are restored when the block exits, including
    when it exits because of an exception.
    """
    import torch.nn as nn

    backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

    nn.BatchNorm1d = SynchronizedBatchNorm1d
    nn.BatchNorm2d = SynchronizedBatchNorm2d
    nn.BatchNorm3d = SynchronizedBatchNorm3d

    try:
        yield
    finally:
        # Restore even if the body raised, so ``torch.nn`` is never left patched.
        nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def convert_model(module):
    """Traverse the input module and its children recursively
    and replace all instances of torch.nn.modules.batchnorm.BatchNorm*N*d
    with SynchronizedBatchNorm*N*d

    A ``torch.nn.DataParallel`` wrapper is replaced by a
    ``DataParallelWithCallback`` around the converted inner module; only
    ``device_ids`` is carried over from the original wrapper. Modules that are
    not batch-norm layers are kept and have their children replaced in place.

    Args:
        module: the input module that needs to be converted to a SyncBN model

    Returns:
        The converted module.

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
        return mod

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            # The running statistics are shared with the original module, not copied.
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                # The affine parameters, in contrast, are cloned.
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    # Re-register every child under its original name with its converted version.
    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod
|