peacock-data-public-evaluation / Megatron-DeepSpeed /tests /unit_tests /tensor_parallel /test_mappings.py
| from megatron.core.tensor_parallel import mappings | |
| from tests.unit_tests.test_utilities import Utils | |
| import torch | |
| def test_CopyToModelParallelRegion(): | |
| Utils.initialize_model_parallel(4,2) | |
| input_data = torch.ones((1)).cuda()*Utils.rank | |
| output_data = mappings._CopyToModelParallelRegion.backward(None, input_data) | |
| result = torch.ones(1).cuda() | |
| result = result * 22 if Utils.rank >= 4 else result * 6 | |
| assert(torch.equal(output_data, result)) | |
| assert(torch.equal(input_data, mappings.copy_to_tensor_model_parallel_region(input_data))) | |
| assert(torch.equal(input_data, mappings._CopyToModelParallelRegion.symbolic(None, input_data))) | |
| Utils.destroy_model_parallel() | |
| def test_ReduceFromModelParallelRegion(): | |
| Utils.initialize_model_parallel(4,2) | |
| input_data = torch.ones((1)).cuda()*Utils.rank | |
| output_data = mappings._ReduceFromModelParallelRegion.symbolic(None, input_data) | |
| result = torch.ones(1).cuda() | |
| result = result * 22 if Utils.rank >= 4 else result * 6 | |
| assert(torch.equal(output_data, result)) | |
| input_data = torch.ones((1)).cuda()*Utils.rank | |
| assert(torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), result)) | |
| assert(torch.equal(input_data, mappings._ReduceFromModelParallelRegion.backward(None, input_data))) | |
| Utils.destroy_model_parallel() | |
| def test_ScatterToModelParallelRegion(): | |
| Utils.initialize_model_parallel(4,2) | |
| input_data = torch.rand((8,4)).cuda() | |
| output_data = mappings.scatter_to_tensor_model_parallel_region(input_data) | |
| req_dim = int(Utils.rank%(Utils.world_size/2)) | |
| assert(torch.equal(output_data, input_data[:,req_dim].reshape((8,1)))) | |
| output_data = mappings._ScatterToModelParallelRegion.symbolic(None, input_data) | |
| assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1)))) | |
| input_data = torch.ones(8).cuda() * Utils.rank | |
| actual_output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data) | |
| expected_output = torch.cat(( | |
| torch.ones(8)*0, | |
| torch.ones(8)*1, | |
| torch.ones(8)*2, | |
| torch.ones(8)*3)).cuda() | |
| if (Utils.rank >= 4): | |
| expected_output = expected_output + 4 | |
| assert(torch.equal(actual_output_data, expected_output)) | |
| Utils.destroy_model_parallel() | |
| def test_GatherFromModelParallelRegion(): | |
| Utils.initialize_model_parallel(4,2) | |
| input_data = torch.rand((8,4)).cuda() | |
| req_dim = int(Utils.rank%(Utils.world_size/2)) | |
| output_data = mappings._GatherFromModelParallelRegion.backward(None, input_data) | |
| assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1)))) | |
| input_data = torch.ones(8).cuda() * Utils.rank | |
| actual_output_data = mappings.gather_from_tensor_model_parallel_region(input_data) | |
| expected_output = torch.cat(( | |
| torch.ones(8)*0, | |
| torch.ones(8)*1, | |
| torch.ones(8)*2, | |
| torch.ones(8)*3)).cuda() | |
| if (Utils.rank >= 4): | |
| expected_output = expected_output + 4 | |
| assert(torch.equal(actual_output_data, expected_output)) | |
| assert(torch.equal(mappings._GatherFromModelParallelRegion.symbolic(None, input_data), expected_output)) | |
| Utils.destroy_model_parallel() | |
| def test_ScatterToSequenceParallelRegion(): | |
| Utils.initialize_model_parallel(4,2) | |
| input_data = torch.rand((8,4)).cuda() | |
| req_dim = int(Utils.rank%(Utils.world_size/2))*2 | |
| output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data) | |
| assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :])) | |
| output_data = mappings.scatter_to_sequence_parallel_region(input_data) | |
| assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :])) | |
| input_data = torch.ones(4).cuda() * Utils.rank | |
| output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data) | |
| expected_output = torch.concat(( | |
| torch.ones(4)*0, | |
| torch.ones(4)*1, | |
| torch.ones(4)*2, | |
| torch.ones(4)*3)).cuda() | |
| if (Utils.rank >= 4): | |
| expected_output = expected_output + 4 | |
| assert(torch.equal(output_data, expected_output)) | |
| Utils.destroy_model_parallel() | |
| def test_GatherFromSequenceParallelRegion(): | |
| Utils.initialize_model_parallel(4,2) | |
| input_data = torch.ones(4).cuda() * Utils.rank | |
| output_data = mappings.gather_from_sequence_parallel_region(input_data) | |
| expected_output = torch.concat(( | |
| torch.ones(4)*0, | |
| torch.ones(4)*1, | |
| torch.ones(4)*2, | |
| torch.ones(4)*3)).cuda() | |
| if (Utils.rank >= 4): | |
| expected_output = expected_output + 4 | |
| assert(torch.equal(output_data, expected_output)) | |
| assert(torch.equal(mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data), expected_output)) | |
| input_data = torch.vstack(( | |
| torch.ones(4)*0, | |
| torch.ones(4)*1, | |
| torch.ones(4)*2, | |
| torch.ones(4)*3)).cuda() | |
| class Ctx: | |
| tensor_parallel_output_grad = True | |
| output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data) | |
| expected_output = torch.ones((1,4)).cuda() * 4 * int(Utils.rank % 4) | |
| assert(torch.equal(output_data[0], expected_output)) | |
| Utils.destroy_model_parallel() | |
| def test_ReduceScatterToSequenceParallelRegion(): | |
| Utils.initialize_model_parallel(4,2) | |
| input_data = torch.vstack(( | |
| torch.ones(4)*0, | |
| torch.ones(4)*1, | |
| torch.ones(4)*2, | |
| torch.ones(4)*3)).cuda() | |
| output_data = mappings.reduce_scatter_to_sequence_parallel_region(input_data) | |
| expected_output = torch.ones(4).cuda() * 4 * int(Utils.rank % 4) | |
| assert(torch.equal(output_data[0], expected_output)) | |
| assert(torch.equal(mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data) , expected_output.reshape((1,4)))) | |
| input_data = torch.ones(4).cuda() * Utils.rank | |
| output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(None,input_data) | |
| expected_output = torch.concat(( | |
| torch.ones(4)*0, | |
| torch.ones(4)*1, | |
| torch.ones(4)*2, | |
| torch.ones(4)*3)).cuda() | |
| if (Utils.rank >= 4): | |
| expected_output = expected_output + 4 | |
| assert(torch.equal(output_data, expected_output)) | |
| Utils.destroy_model_parallel() | |