# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
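
# Use the third-party 'regex' package on Windows, the stdlib 're' elsewhere.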
if platform.system() == 'Windows':
    import regex as re
else:
    import re


class TestSyncBN:
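
    # Bring up a NCCL process group from SLURM environment variables; the
    # master address is reassembled from the numeric parts of SLURM_NODELIST.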
    def dist_init(self):
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NTASKS'])
        local_rank = int(os.environ['SLURM_LOCALID'])
        node_list = str(os.environ['SLURM_NODELIST'])

        node_parts = re.findall('[0-9]+', node_list)
        os.environ['MASTER_ADDR'] = (f'{node_parts[1]}.{node_parts[2]}' +
                                     f'.{node_parts[3]}.{node_parts[4]}')
        os.environ['MASTER_PORT'] = '12341'
        os.environ['WORLD_SIZE'] = str(world_size)
        os.environ['RANK'] = str(rank)

        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)
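
    # Compare mmcv's SyncBatchNorm against torch.nn.BatchNorm2d: statistics
    # synchronized across a group of `size` GPUs must match a single-GPU BN
    # run on the concatenated data of that group.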
    def _test_syncbn_train(self, size=1, half=False):
        if 'SLURM_NTASKS' not in os.environ or int(
                os.environ['SLURM_NTASKS']) != 4:
            print('must be run via slurm with 4 processes!\n'
                  'srun -p test --gres=gpu:4 -n4')
            return
        else:
            print('Running syncbn test')
        from mmcv.ops import SyncBatchNorm

        assert size in (1, 2, 4)
        if not dist.is_initialized():
            self.dist_init()

        rank = dist.get_rank()

        torch.manual_seed(9)
        torch.cuda.manual_seed(9)
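
        # Identical input and upstream gradient on every rank: all ranks seed
        # the same RNG, then rank 0 broadcasts to remove any remaining drift.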
        self.x = torch.rand(16, 3, 2, 3).cuda()
        self.y_bp = torch.rand(16, 3, 2, 3).cuda()
        if half:
            self.x = self.x.half()
            self.y_bp = self.y_bp.half()
        dist.broadcast(self.x, src=0)
        dist.broadcast(self.y_bp, src=0)
        torch.cuda.synchronize()
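
        # Build the process group SyncBatchNorm synchronizes over: size=1
        # degenerates to per-GPU BN, size=2 pairs ranks {0, 1} and {2, 3},
        # size=4 uses the whole world.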
        if size == 1:
            groups = [None, None, None, None]
            groups[0] = dist.new_group([0])
            groups[1] = dist.new_group([1])
            groups[2] = dist.new_group([2])
            groups[3] = dist.new_group([3])
            group = groups[rank]
        elif size == 2:
            groups = [None, None, None, None]
            groups[0] = groups[1] = dist.new_group([0, 1])
            groups[2] = groups[3] = dist.new_group([2, 3])
            group = groups[rank]
        elif size == 4:
            group = dist.group.WORLD
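
        # Give both normalization layers the same non-trivial affine weights
        # so their outputs are directly comparable.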
        syncbn = SyncBatchNorm(3, group=group).cuda()
        syncbn.weight.data[0] = 0.2
        syncbn.weight.data[1] = 0.5
        syncbn.weight.data[2] = 0.7
        syncbn.train()

        bn = nn.BatchNorm2d(3).cuda()
        bn.weight.data[0] = 0.2
        bn.weight.data[1] = 0.5
        bn.weight.data[2] = 0.7
        bn.train()
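
        # Forward/backward through SyncBatchNorm on this rank's 4-sample
        # slice of the shared 16-sample batch.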
        sx = self.x[rank * 4:rank * 4 + 4]
        sx.requires_grad_()
        sy = syncbn(sx)
        sy.backward(self.y_bp[rank * 4:rank * 4 + 4])

        smean = syncbn.running_mean
        svar = syncbn.running_var
        sx_grad = sx.grad
        sw_grad = syncbn.weight.grad
        sb_grad = syncbn.bias.grad
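
        # Reference run: plain BatchNorm2d on a single GPU, fed exactly the
        # samples that this rank's sync group sees in total.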
        if size == 1:
            x = self.x[rank * 4:rank * 4 + 4]
            y_bp = self.y_bp[rank * 4:rank * 4 + 4]
        elif size == 2:
            x = self.x[rank // 2 * 8:rank // 2 * 8 + 8]
            y_bp = self.y_bp[rank // 2 * 8:rank // 2 * 8 + 8]
        elif size == 4:
            x = self.x
            y_bp = self.y_bp
        x.requires_grad_()
        y = bn(x)
        y.backward(y_bp)

        if size == 2:
            y = y[rank % 2 * 4:rank % 2 * 4 + 4]
        elif size == 4:
            y = y[rank * 4:rank * 4 + 4]

        mean = bn.running_mean
        var = bn.running_var
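
        # The single-GPU reference accumulates weight/bias gradients over the
        # whole group batch, while SyncBatchNorm reports a per-rank gradient
        # averaged over the group; dividing by the group size aligns the two.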
        if size == 1:
            x_grad = x.grad
            w_grad = bn.weight.grad
            b_grad = bn.bias.grad
        elif size == 2:
            x_grad = x.grad[rank % 2 * 4:rank % 2 * 4 + 4]
            w_grad = bn.weight.grad / 2
            b_grad = bn.bias.grad / 2
        elif size == 4:
            x_grad = x.grad[rank * 4:rank * 4 + 4]
            w_grad = bn.weight.grad / 4
            b_grad = bn.bias.grad / 4
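
        # Running statistics, outputs, and gradients must agree up to float
        # tolerance (a looser bound for the input gradient).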
        assert np.allclose(mean.data.cpu().numpy(),
                           smean.data.cpu().numpy(), 1e-3)
        assert np.allclose(var.data.cpu().numpy(),
                           svar.data.cpu().numpy(), 1e-3)
        assert np.allclose(y.data.cpu().numpy(), sy.data.cpu().numpy(), 1e-3)
        assert np.allclose(w_grad.data.cpu().numpy(),
                           sw_grad.data.cpu().numpy(), 1e-3)
        assert np.allclose(b_grad.data.cpu().numpy(),
                           sb_grad.data.cpu().numpy(), 1e-3)
        assert np.allclose(x_grad.data.cpu().numpy(),
                           sx_grad.data.cpu().numpy(), 1e-2)
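
    # Same comparison as above, but every rank receives a zero-sample batch.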
    def _test_syncbn_empty_train(self, size=1, half=False):
        if 'SLURM_NTASKS' not in os.environ or int(
                os.environ['SLURM_NTASKS']) != 4:
            print('must be run via slurm with 4 processes!\n'
                  'srun -p test --gres=gpu:4 -n4')
            return
        else:
            print('Running syncbn test')
        from mmcv.ops import SyncBatchNorm

        assert size in (1, 2, 4)
        if not dist.is_initialized():
            self.dist_init()

        rank = dist.get_rank()

        torch.manual_seed(9)
        torch.cuda.manual_seed(9)

        self.x = torch.rand(0, 3, 2, 3).cuda()
        self.y_bp = torch.rand(0, 3, 2, 3).cuda()
        if half:
            self.x = self.x.half()
            self.y_bp = self.y_bp.half()
        dist.broadcast(self.x, src=0)
        dist.broadcast(self.y_bp, src=0)
        torch.cuda.synchronize()

        if size == 1:
            groups = [None, None, None, None]
            groups[0] = dist.new_group([0])
            groups[1] = dist.new_group([1])
            groups[2] = dist.new_group([2])
            groups[3] = dist.new_group([3])
            group = groups[rank]
        elif size == 2:
            groups = [None, None, None, None]
            groups[0] = groups[1] = dist.new_group([0, 1])
            groups[2] = groups[3] = dist.new_group([2, 3])
            group = groups[rank]
        elif size == 4:
            group = dist.group.WORLD
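
        # stats_mode='N' weights each rank's statistics by its sample count,
        # which keeps the reduction well-defined when a rank's batch is empty.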
        syncbn = SyncBatchNorm(3, group=group, stats_mode='N').cuda()
        syncbn.weight.data[0] = 0.2
        syncbn.weight.data[1] = 0.5
        syncbn.weight.data[2] = 0.7
        syncbn.train()

        bn = nn.BatchNorm2d(3).cuda()
        bn.weight.data[0] = 0.2
        bn.weight.data[1] = 0.5
        bn.weight.data[2] = 0.7
        bn.train()

        sx = self.x[rank * 4:rank * 4 + 4]
        sx.requires_grad_()
        sy = syncbn(sx)
        sy.backward(self.y_bp[rank * 4:rank * 4 + 4])

        smean = syncbn.running_mean
        svar = syncbn.running_var
        sx_grad = sx.grad
        sw_grad = syncbn.weight.grad
        sb_grad = syncbn.bias.grad

        if size == 1:
            x = self.x[rank * 4:rank * 4 + 4]
            y_bp = self.y_bp[rank * 4:rank * 4 + 4]
        elif size == 2:
            x = self.x[rank // 2 * 8:rank // 2 * 8 + 8]
            y_bp = self.y_bp[rank // 2 * 8:rank // 2 * 8 + 8]
        elif size == 4:
            x = self.x
            y_bp = self.y_bp
        x.requires_grad_()
        y = bn(x)
        y.backward(y_bp)

        if size == 2:
            y = y[rank % 2 * 4:rank % 2 * 4 + 4]
        elif size == 4:
            y = y[rank * 4:rank * 4 + 4]

        mean = bn.running_mean
        var = bn.running_var
        if size == 1:
            x_grad = x.grad
            w_grad = bn.weight.grad
            b_grad = bn.bias.grad
        elif size == 2:
            x_grad = x.grad[rank % 2 * 4:rank % 2 * 4 + 4]
            w_grad = bn.weight.grad / 2
            b_grad = bn.bias.grad / 2
        elif size == 4:
            x_grad = x.grad[rank * 4:rank * 4 + 4]
            w_grad = bn.weight.grad / 4
            b_grad = bn.bias.grad / 4

        assert np.allclose(mean.data.cpu().numpy(),
                           smean.data.cpu().numpy(), 1e-3)
        assert np.allclose(var.data.cpu().numpy(),
                           svar.data.cpu().numpy(), 1e-3)
        assert np.allclose(y.data.cpu().numpy(), sy.data.cpu().numpy(), 1e-3)
        assert np.allclose(w_grad.data.cpu().numpy(),
                           sw_grad.data.cpu().numpy(), 1e-3)
        assert np.allclose(b_grad.data.cpu().numpy(),
                           sb_grad.data.cpu().numpy(), 1e-3)
        assert np.allclose(x_grad.data.cpu().numpy(),
                           sx_grad.data.cpu().numpy(), 1e-2)

        # 'stats_mode' only allows 'default' and 'N'
        with pytest.raises(AssertionError):
            SyncBatchNorm(3, group=group, stats_mode='X')

    def test_syncbn_1(self):
        self._test_syncbn_train(size=1)

    def test_syncbn_2(self):
        self._test_syncbn_train(size=2)

    def test_syncbn_4(self):
        self._test_syncbn_train(size=4)

    def test_syncbn_1_half(self):
        self._test_syncbn_train(size=1, half=True)

    def test_syncbn_2_half(self):
        self._test_syncbn_train(size=2, half=True)

    def test_syncbn_4_half(self):
        self._test_syncbn_train(size=4, half=True)

    def test_syncbn_empty_1(self):
        self._test_syncbn_empty_train(size=1)

    def test_syncbn_empty_2(self):
        self._test_syncbn_empty_train(size=2)

    def test_syncbn_empty_4(self):
        self._test_syncbn_empty_train(size=4)

    def test_syncbn_empty_1_half(self):
        self._test_syncbn_empty_train(size=1, half=True)

    def test_syncbn_empty_2_half(self):
        self._test_syncbn_empty_train(size=2, half=True)

    def test_syncbn_empty_4_half(self):
        self._test_syncbn_empty_train(size=4, half=True)