import logging

import torch.testing
from torch.testing._internal import common_utils

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import data as data_utils
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

# Silence torch's info-level logging noise during distributed test runs.
# (The original file set this twice, before and after the apex imports;
# once is sufficient.)
logging.getLogger("torch").setLevel(logging.WARNING)
class BroadcastDataTestBase:
    """Mixin exercising ``tensor_parallel.data.broadcast_data``.

    Intended to be combined with a distributed test base (NCCL or UCC)
    that supplies ``self.world_size`` and per-backend process-group
    setup/teardown.
    """

    def test_broadcast_data(self) -> None:
        # Use a tensor-model-parallel group of half the world when more than
        # one rank is available, otherwise a group of one.
        # NOTE: the parentheses around the comparison are required. The
        # original spelling ``1 + self.world_size > 1`` parses as
        # ``(1 + self.world_size) > 1`` (always True == 1), making the
        # divisor 1 for every world size.
        tensor_model_parallel_world_size: int = self.world_size // (
            1 + (self.world_size > 1)
        )
        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size
        )

        # Ensure model-parallel state is torn down even if an assertion
        # fails, so later tests start from a clean slate.
        try:
            target_key_size = {
                "key1": [7, 11],
                "key2": [8, 2, 1],
                "key3": [13],
                "key4": [5, 1, 2],
                "key5": [5, 12],
            }
            keys = list(target_key_size)

            data = {}
            data_t = {}
            with torch.no_grad():
                for key, size in target_key_size.items():
                    data[key] = torch.randint(0, 1000, size=size)
                    data_t[key] = data[key].clone()
                # "key_x" is deliberately absent from ``keys``: the broadcast
                # helpers must ignore entries that were not requested.
                data["key_x"] = torch.rand(5)
                data_t["key_x"] = data["key_x"].clone()
            # Only tensor-parallel rank 0 supplies the data; all other ranks
            # receive it through the broadcast.
            if parallel_state.get_tensor_model_parallel_rank() != 0:
                data = None

            data_utils._check_data_types(keys, data_t, torch.int64)
            key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data)

            for key in keys:
                self.assertEqual(target_key_size[key], key_size[key])

            broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64)
            for key in keys:
                self.assertEqual(broadcasted_data[key], data_t[key].cuda())
        finally:
            parallel_state.destroy_model_parallel()
| |
|
| |
|
class NcclBroadcastDataTest(BroadcastDataTestBase, NcclDistributedTestBase):
    """Broadcast-data test run over the NCCL backend."""
class UccBroadcastDataTest(BroadcastDataTestBase, UccDistributedTestBase):
    """Broadcast-data test run over the UCC backend."""
| |
|
| |
|
# Entry point: delegate test discovery/execution to torch's test harness.
if __name__ == "__main__":
    common_utils.run_tests()
| |
|