Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn.parallel import DataParallel, DistributedDataParallel | |
| from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel, | |
| MMDistributedDataParallel, is_module_wrapper) | |
| from mmcv.parallel._functions import Scatter, get_input_device, scatter | |
| from mmcv.parallel.distributed_deprecated import \ | |
| MMDistributedDataParallel as DeprecatedMMDDP | |
| from mmcv.utils import Registry | |
def mock(*args, **kwargs):
    """No-op stand-in for helpers that the tests stub out.

    Accepts any positional/keyword arguments, ignores them and returns None.
    """
    return None
def test_is_module_wrapper():
    """Check ``is_module_wrapper`` on plain modules, torch/mmcv wrappers and
    classes registered in ``MODULE_WRAPPERS`` or one of its child registries.
    """
    import contextlib

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    assert not is_module_wrapper(model)

    dp = DataParallel(model)
    assert is_module_wrapper(dp)

    mmdp = MMDataParallel(model)
    assert is_module_wrapper(mmdp)

    # The distributed wrappers are built against a MagicMock process group,
    # so the communication / initialisation helpers they invoke during
    # construction are replaced by the no-op ``mock``. ``patch`` restores the
    # real attributes when the ``with`` block exits, instead of permanently
    # overwriting them on ``torch.distributed`` and leaking the stubs into
    # every test that runs afterwards.
    with contextlib.ExitStack() as stack:
        for target in ('torch.distributed._broadcast_coalesced',
                       'torch.distributed.broadcast',
                       'torch.nn.parallel.DistributedDataParallel'
                       '._ddp_init_helper'):
            stack.enter_context(patch(target, mock))
        # _verify_model_across_ranks is added in torch1.9.0 and
        # _verify_params_across_processes is added in torch1.11.0, so only
        # stub the ones that exist in the installed torch.
        for name in ('_verify_model_across_ranks',
                     '_verify_params_across_processes'):
            if hasattr(torch.distributed, name):
                stack.enter_context(
                    patch.object(torch.distributed, name, mock))

        ddp = DistributedDataParallel(model, process_group=MagicMock())
        mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
        deprecated_mmddp = DeprecatedMMDDP(model)

    assert is_module_wrapper(ddp)
    assert is_module_wrapper(mmddp)
    assert is_module_wrapper(deprecated_mmddp)

    # test module wrapper registry: an arbitrary class only counts as a
    # module wrapper once it is registered in MODULE_WRAPPERS.
    @MODULE_WRAPPERS.register_module()
    class ModuleWrapper:

        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    module_wrapper = ModuleWrapper(model)
    assert is_module_wrapper(module_wrapper)

    # test module wrapper registry in downstream repo: classes registered in
    # a child registry of MODULE_WRAPPERS must be recognised as well.
    MMRAZOR_MODULE_WRAPPERS = Registry(
        'mmrazor module wrapper', parent=MODULE_WRAPPERS, scope='mmrazor')
    MMPOSE_MODULE_WRAPPERS = Registry(
        'mmpose module wrapper', parent=MODULE_WRAPPERS, scope='mmpose')

    @MMRAZOR_MODULE_WRAPPERS.register_module()
    class ModuleWrapperInRazor:

        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    @MMPOSE_MODULE_WRAPPERS.register_module()
    class ModuleWrapperInPose:

        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    wrapped_module = ModuleWrapperInRazor(model)
    assert is_module_wrapper(wrapped_module)

    wrapped_module = ModuleWrapperInPose(model)
    assert is_module_wrapper(wrapped_module)
def test_get_input_device():
    """``get_input_device`` reports -1 for CPU data, the CUDA index for GPU
    data, and rejects anything that is not a tensor or list of tensors."""
    # CPU data -- a lone tensor or a list of tensors -- maps to device -1.
    for cpu_sample in (torch.zeros([1, 3, 3, 3]),
                       [torch.zeros([1, 3, 3, 3]),
                        torch.zeros([1, 4, 4, 4])]):
        assert get_input_device(cpu_sample) == -1

    # GPU data placed with ``.cuda()`` lands on device 0.
    if torch.cuda.is_available():
        gpu_tensor = torch.zeros([1, 3, 3, 3]).cuda()
        assert get_input_device(gpu_tensor) == 0

        gpu_tensors = [
            torch.zeros([1, 3, 3, 3]).cuda(),
            torch.zeros([1, 4, 4, 4]).cuda(),
        ]
        assert get_input_device(gpu_tensors) == 0

    # A bare integer is neither a tensor nor a list of tensors.
    with pytest.raises(Exception):
        get_input_device(5)
def test_scatter():
    """``scatter`` returns CPU data for device -1, copies it to GPU 0 for
    device 0, and rejects input that is not a tensor or list of tensors."""
    # Device -1 keeps everything on the CPU.
    cpu_tensor = torch.zeros([1, 3, 3, 3])
    kept = scatter(input=cpu_tensor, devices=[-1])
    assert torch.allclose(cpu_tensor, kept)

    cpu_tensors = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    kept_list = scatter(input=cpu_tensors, devices=[-1])
    for expected, actual in zip(cpu_tensors, kept_list):
        assert torch.allclose(expected, actual)

    # Device 0 moves the host data onto the first GPU.
    if torch.cuda.is_available():
        host_tensor = torch.zeros([1, 3, 3, 3])
        moved = scatter(input=host_tensor, devices=[0])
        assert torch.allclose(host_tensor.cuda(), moved)

        host_tensors = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        moved_list = scatter(input=host_tensors, devices=[0])
        for expected, actual in zip(host_tensors, moved_list):
            assert torch.allclose(expected.cuda(), actual)

    # Only tensors or lists of tensors are accepted.
    with pytest.raises(Exception):
        scatter(5, [-1])
def test_Scatter():
    """Check that ``Scatter.forward`` returns a tuple of outputs.

    With target ``[-1]`` the CPU inputs are returned unchanged; with target
    ``[0]`` (only exercised when CUDA is available) they are copied to GPU 0.
    """
    # if the device is CPU, just return the input
    target_gpus = [-1]
    tensor = torch.zeros([1, 3, 3, 3])
    outputs = Scatter.forward(target_gpus, tensor)
    assert isinstance(outputs, tuple)
    assert torch.allclose(tensor, outputs[0])

    target_gpus = [-1]
    tensors = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = Scatter.forward(target_gpus, tensors)
    assert isinstance(outputs, tuple)
    for src, dst in zip(tensors, outputs):
        assert torch.allclose(src, dst)

    # if the device is GPU, copy the input from CPU to GPU
    if torch.cuda.is_available():
        target_gpus = [0]
        tensor = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_gpus, tensor)
        assert isinstance(outputs, tuple)
        assert torch.allclose(tensor.cuda(), outputs[0])

        target_gpus = [0]
        tensors = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_gpus, tensors)
        assert isinstance(outputs, tuple)
        for src, dst in zip(tensors, outputs):
            # Compare against the whole scattered tensor, as the CPU branch
            # does. Indexing with ``dst[0]`` would compare only the first
            # slice and rely on allclose broadcasting, hiding shape errors.
            assert torch.allclose(src.cuda(), dst)