import unittest

import torch
import torch.nn as nn
from torch.autograd import Variable

from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
from sync_batchnorm.unittest import TorchTestCase


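# Per-feature variance of `a` along the batch dimension, computed via the
# sum-of-squares identity; unbias=True applies Bessel's correction (n - 1).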
def handy_var(a, unbias=True):
    n = a.size(0)
    asum = a.sum(dim=0)
    as_sum = (a ** 2).sum(dim=0)
    sumvar = as_sum - asum * asum / n
    if unbias:
        return sumvar / (n - 1)
    else:
        return sumvar / n


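# Return the first batch-norm layer (vanilla or synchronized) found inside
# `module`, descending through wrappers such as DataParallelWithCallback.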
def _find_bn(module):
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
            return m


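# Each test pairs a stock torch.nn batch-norm layer with its synchronized
# counterpart and checks that outputs, gradients, and running statistics agree.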
class SyncTestCase(TorchTestCase):
    def _syncParameters(self, bn1, bn2):
        bn1.reset_parameters()
        bn2.reset_parameters()
        if bn1.affine and bn2.affine:
            # Start both modules from identical affine parameters.
            bn2.weight.data.copy_(bn1.weight.data)
            bn2.bias.data.copy_(bn1.bias.data)

    def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
        """Check the forward and backward passes of the customized batch normalization."""
        bn1.train(mode=is_train)
        bn2.train(mode=is_train)

        if cuda:
            input = input.cuda()

        self._syncParameters(_find_bn(bn1), _find_bn(bn2))

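        # Run an identical forward/backward pass through both modules.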
        input1 = Variable(input, requires_grad=True)
        output1 = bn1(input1)
        output1.sum().backward()
        input2 = Variable(input, requires_grad=True)
        output2 = bn2(input2)
        output2.sum().backward()

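        # Outputs, input gradients, and running statistics must all agree.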
        self.assertTensorClose(input1.data, input2.data)
        self.assertTensorClose(output1.data, output2.data)
        self.assertTensorClose(input1.grad, input2.grad)
        self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
        self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)

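    # Single-device tests: without replication, SynchronizedBatchNorm1d should
    # behave exactly like nn.BatchNorm1d.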
    def testSyncBatchNormNormalTrain(self):
        bn = nn.BatchNorm1d(10)
        sync_bn = SynchronizedBatchNorm1d(10)

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)

    def testSyncBatchNormNormalEval(self):
        bn = nn.BatchNorm1d(10)
        sync_bn = SynchronizedBatchNorm1d(10)

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)

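    # Multi-GPU tests (require at least two visible CUDA devices): the
    # synchronized module runs under DataParallelWithCallback on two devices,
    # and should still match the plain single-device reference module.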
    def testSyncBatchNormSyncTrain(self):
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)

    def testSyncBatchNormSyncEval(self):
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)

    def testSyncBatchNorm2DSyncTrain(self):
        bn = nn.BatchNorm2d(10)
        sync_bn = SynchronizedBatchNorm2d(10)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)


if __name__ == '__main__':
    unittest.main()