import logging

import torch
from torch.testing._internal import common_utils

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

# Silence INFO-level chatter from torch and apex so test output stays readable.
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
class MappingTestBase:
    """Tests for the tensor-parallel communication primitives in
    ``apex.transformer.tensor_parallel.mappings``.

    Intended to be mixed into a distributed test base class (NCCL or UCC)
    that provides ``self.world_size``, ``self.rank``, and the unittest
    assertion methods. Each test sweeps every tensor-model-parallel world
    size that evenly divides the total world size.
    """

    def test_reduce(self):
        """``_reduce`` must all-reduce (sum) a tensor across the tensor-model-parallel group."""
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            # Only group sizes that evenly divide the world size are valid.
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}")
            # Every rank contributes 50, so the all-reduced value is 50 * group size.
            expected = torch.full(
                (10, 10, 10, 10),
                50 * tensor_model_parallel_world_size,
                device=f"cuda:{self.rank}",
            )
            self.assertTrue(
                torch.equal(mappings._reduce(t), expected),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()

    def test_split(self):
        """``_split_along_last_dim`` must give each rank its own last-dim shard."""
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            # One (10, 1) shard per rank, concatenated along the last dim;
            # splitting must return exactly this rank's original shard.
            tensors = [
                torch.randn(10, 1)
                for _ in range(tensor_model_parallel_world_size)
            ]
            x = torch.cat(tensors, 1)
            out = mappings._split_along_last_dim(x)
            self.assertTrue(
                torch.equal(
                    out, tensors[parallel_state.get_tensor_model_parallel_rank()]
                ),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()

    def test_gather(self):
        """``_gather_along_last_dim`` must concatenate per-rank tensors in rank order."""
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            device = f"cuda:{self.rank}"
            # Each rank contributes its own rank id; gathering yields [0, 1, ..., size-1].
            gathered = mappings._gather_along_last_dim(
                torch.tensor(
                    [parallel_state.get_tensor_model_parallel_rank()], device=device
                )
            )
            expected = torch.tensor(
                list(range(tensor_model_parallel_world_size)),
                device=device,
            )
            self.assertTrue(
                torch.equal(gathered, expected),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()
class NcclMappingTest(MappingTestBase, NcclDistributedTestBase):
    """Run the mapping tests over the NCCL distributed backend."""
class UccMappingTest(MappingTestBase, UccDistributedTestBase):
    """Run the mapping tests over the UCC distributed backend."""
| if __name__ == "__main__": | |
| common_utils.run_tests() | |