Spaces:
Running on Zero
Running on Zero
| import pytest | |
| import torch | |
| import megatron.core.parallel_state as ps | |
| from tests.unit_tests.test_utilities import Utils | |
# Process rank and total process count, both read from the shared test utilities.
rank, world_size = Utils.rank, Utils.world_size

# Dimension-order strings; presumably the values fed to the `order` test
# argument -- the parametrization itself is not visible in this view.
test_parallel_order = ['tp-cp-ep-dp-pp', 'tp-cp-pp-ep-dp']
def test_initialize_and_destroy_model_parallel(order):
    """Exercise rejection paths, a successful setup, and teardown of model parallelism.

    Args:
        order: Forwarded as the ``order`` keyword of every initialization call.
            Where the value comes from (fixture or parametrization) is not
            visible in this view.
    """
    # This call is made before Utils.initialize_distributed().
    # NOTE(review): the leading ``assert`` also raises AssertionError when the
    # call merely returns a falsy value -- confirm that this is intended.
    with pytest.raises(AssertionError):
        assert ps.initialize_model_parallel(order=order)

    Utils.initialize_distributed()

    # Every keyword set below is expected to make initialization raise RuntimeError.
    too_large = 2 * world_size
    rejected_configs = (
        {"tensor_model_parallel_size": too_large},
        {"pipeline_model_parallel_size": too_large},
        {
            "pipeline_model_parallel_size": world_size,
            "tensor_model_parallel_size": world_size,
        },
        {"virtual_pipeline_model_parallel_size": 2},
    )
    for bad_kwargs in rejected_configs:
        with pytest.raises(RuntimeError):
            assert ps.initialize_model_parallel(**bad_kwargs, order=order)

    # A configuration that is expected to initialize successfully.
    Utils.initialize_model_parallel(
        order=order, tensor_model_parallel_size=2, pipeline_model_parallel_size=4
    )
    assert ps.model_parallel_is_initialized()

    # Each group accessor must hand back something other than None.
    group_getters = (
        ps.get_model_parallel_group,
        ps.get_tensor_model_parallel_group,
        ps.get_pipeline_model_parallel_group,
        ps.get_data_parallel_group,
        ps.get_expert_model_parallel_group,
        ps.get_expert_tensor_parallel_group,
        ps.get_expert_data_parallel_group,
        ps.get_expert_tensor_model_pipeline_parallel_group,
    )
    for fetch_group in group_getters:
        assert fetch_group() is not None

    # After teardown the module-level group handle must be cleared.
    Utils.destroy_model_parallel()
    assert ps._MODEL_PARALLEL_GROUP is None
def test_pipeline_parallel_initializations(order):
    """Check pipeline first/next/prev ranks and the data-parallel source rank.

    Initialization uses ``tensor_model_parallel_size=2`` and
    ``pipeline_model_parallel_size=4``.

    Args:
        order: Forwarded as the ``order`` keyword of the initialization call;
            its source is not visible in this view.
    """
    Utils.initialize_model_parallel(
        order=order, tensor_model_parallel_size=2, pipeline_model_parallel_size=4
    )

    first_rank = ps.get_pipeline_model_parallel_first_rank()
    assert first_rank == rank % 2

    dp_src_rank = ps.get_data_parallel_src_rank()
    assert dp_src_rank == rank

    # Neighbouring pipeline ranks are expected two ranks away, wrapping at world_size.
    next_rank = ps.get_pipeline_model_parallel_next_rank()
    assert next_rank == ((rank + 2) % world_size)

    prev_rank = ps.get_pipeline_model_parallel_prev_rank()
    assert prev_rank == ((rank - 2) % world_size)

    Utils.destroy_model_parallel()
def test_data_parallel_initializations(order):
    """Check data-parallel queries when the pipeline size equals ``world_size``.

    Args:
        order: Forwarded as the ``order`` keyword of the initialization call;
            its source is not visible in this view.
    """
    Utils.initialize_model_parallel(order=order, pipeline_model_parallel_size=world_size)

    # (accessor, expected value) pairs, evaluated in this order.
    expectations = (
        (ps.get_data_parallel_src_rank, rank),
        (ps.get_data_parallel_world_size, 1),
        (ps.get_data_parallel_rank, 0),
    )
    for query, expected in expectations:
        assert query() == expected

    Utils.destroy_model_parallel()
def test_tensor_model_parellel_world_size(order):
    """Check the tensor-parallel world size before and after the setter gets ``None``.

    Args:
        order: Forwarded as the ``order`` keyword of the initialization call;
            its source is not visible in this view.
    """
    Utils.initialize_model_parallel(order=order, tensor_model_parallel_size=world_size)

    initial_size = ps.get_tensor_model_parallel_world_size()
    assert initial_size == world_size

    # Passing None to the setter must leave the reported size unchanged.
    ps.set_tensor_model_parallel_world_size(None)
    size_after_none = ps.get_tensor_model_parallel_world_size()
    assert size_after_none == world_size

    Utils.destroy_model_parallel()
def test_expert_tensor_parellel_world_size(order):
    """Check the expert-tensor-parallel world size before and after the setter gets ``None``.

    Args:
        order: Forwarded as the ``order`` keyword of the initialization call;
            its source is not visible in this view.
    """
    Utils.initialize_model_parallel(order=order, expert_tensor_parallel_size=world_size)

    initial_size = ps.get_expert_tensor_parallel_world_size()
    assert initial_size == world_size

    # Passing None to the setter must leave the reported size unchanged.
    ps.set_expert_tensor_parallel_world_size(None)
    size_after_none = ps.get_expert_tensor_parallel_world_size()
    assert size_after_none == world_size

    Utils.destroy_model_parallel()
def test_pipeline_model_parallel_world_size(order):
    """Check the pipeline-parallel world size before and after the setter gets ``None``.

    Args:
        order: Forwarded as the ``order`` keyword of the initialization call;
            its source is not visible in this view.
    """
    Utils.initialize_model_parallel(order=order, pipeline_model_parallel_size=world_size)

    initial_size = ps.get_pipeline_model_parallel_world_size()
    assert initial_size == world_size

    # Passing None to the setter must leave the reported size unchanged.
    ps.set_pipeline_model_parallel_world_size(None)
    size_after_none = ps.get_pipeline_model_parallel_world_size()
    assert size_after_none == world_size

    Utils.destroy_model_parallel()
def test_tensor_model_parallel_rank(order):
    """Check the tensor-parallel rank before and after the setter gets ``None``.

    With ``tensor_model_parallel_size=world_size`` the reported rank is
    compared against the module-level ``rank``.

    Args:
        order: Forwarded as the ``order`` keyword of the initialization call;
            its source is not visible in this view.
    """
    Utils.initialize_model_parallel(order=order, tensor_model_parallel_size=world_size)

    initial_rank = ps.get_tensor_model_parallel_rank()
    assert initial_rank == rank

    # Passing None to the setter must leave the reported rank unchanged.
    ps.set_tensor_model_parallel_rank(None)
    rank_after_none = ps.get_tensor_model_parallel_rank()
    assert rank_after_none == rank

    Utils.destroy_model_parallel()
def test_moe_tensor_model_parellel_rank(order):
    """Check the expert-tensor-parallel rank before and after the setter gets ``None``.

    With ``expert_tensor_parallel_size=world_size`` the reported rank is
    compared against the module-level ``rank``.

    Args:
        order: Forwarded as the ``order`` keyword of the initialization call;
            its source is not visible in this view.
    """
    Utils.initialize_model_parallel(order=order, expert_tensor_parallel_size=world_size)

    initial_rank = ps.get_expert_tensor_parallel_rank()
    assert initial_rank == rank

    # Passing None to the setter must leave the reported rank unchanged.
    ps.set_expert_tensor_parallel_rank(None)
    rank_after_none = ps.get_expert_tensor_parallel_rank()
    assert rank_after_none == rank

    Utils.destroy_model_parallel()
def test_pipeline_model_parallel_rank(order):
    """Check the pipeline-parallel rank before and after the setter gets ``None``.

    With ``pipeline_model_parallel_size=world_size`` the reported rank is
    compared against the module-level ``rank``.

    Args:
        order: Forwarded as the ``order`` keyword of the initialization call;
            its source is not visible in this view.
    """
    Utils.initialize_model_parallel(order=order, pipeline_model_parallel_size=world_size)

    initial_rank = ps.get_pipeline_model_parallel_rank()
    assert initial_rank == rank

    # Passing None to the setter must leave the reported rank unchanged.
    ps.set_pipeline_model_parallel_rank(None)
    rank_after_none = ps.get_pipeline_model_parallel_rank()
    assert rank_after_none == rank

    Utils.destroy_model_parallel()
def test_context_parallel_rank():
    """Check the context-parallel rank when ``context_parallel_size=world_size``."""
    Utils.initialize_model_parallel(context_parallel_size=world_size)

    cp_rank = ps.get_context_parallel_rank()
    assert cp_rank == rank

    Utils.destroy_model_parallel()
def test_expert_model_parallel_rank():
    """Check the expert-model-parallel rank before and after the setter gets ``None``.

    Initialization uses ``expert_model_parallel_size=world_size``.
    """
    Utils.initialize_model_parallel(expert_model_parallel_size=world_size)

    initial_rank = ps.get_expert_model_parallel_rank()
    assert initial_rank == rank

    # Passing None to the setter must leave the reported rank unchanged.
    ps.set_expert_model_parallel_rank(None)
    rank_after_none = ps.get_expert_model_parallel_rank()
    assert rank_after_none == rank

    Utils.destroy_model_parallel()
def test_is_pipeline_first_stage(order):
    """Check ``is_pipeline_first_stage`` with and without ``ignore_virtual=False``.

    Initialization uses ``pipeline_model_parallel_size=world_size``; the
    result is compared against ``rank == 0``.

    Args:
        order: Forwarded as the ``order`` keyword of the initialization call;
            its source is not visible in this view.
    """
    Utils.initialize_model_parallel(order=order, pipeline_model_parallel_size=world_size)

    # First with the explicit keyword, then with the function's own default.
    for call_kwargs in ({"ignore_virtual": False}, {}):
        assert ps.is_pipeline_first_stage(**call_kwargs) == (rank == 0)

    Utils.destroy_model_parallel()
def test_is_pipeline_last_stage(order):
    """Check ``is_pipeline_last_stage`` with and without ``ignore_virtual=False``.

    Initialization uses ``pipeline_model_parallel_size=world_size``; the
    result is compared against ``rank == world_size - 1``.

    Args:
        order: Forwarded as the ``order`` keyword of the initialization call;
            its source is not visible in this view.
    """
    Utils.initialize_model_parallel(order=order, pipeline_model_parallel_size=world_size)

    # First with the explicit keyword, then with the function's own default.
    for call_kwargs in ({"ignore_virtual": False}, {}):
        assert ps.is_pipeline_last_stage(**call_kwargs) == (rank == world_size - 1)

    Utils.destroy_model_parallel()
def test_virtual_pipeline_model_parallel_rank(order):
    """Check that a virtual pipeline rank written via the setter is read back.

    Args:
        order: Forwarded as the ``order`` keyword of the initialization call;
            its source is not visible in this view.
    """
    Utils.initialize_model_parallel(order=order, pipeline_model_parallel_size=world_size)

    ps.set_virtual_pipeline_model_parallel_rank(rank)
    stored_rank = ps.get_virtual_pipeline_model_parallel_rank()
    assert stored_rank == rank

    Utils.destroy_model_parallel()
def test_get_tensor_model_parallel_src_rank(order):
    """Check the tensor-parallel source rank when the TP size equals ``world_size``.

    Args:
        order: Forwarded as the ``order`` keyword of the initialization call;
            its source is not visible in this view.
    """
    Utils.initialize_model_parallel(order=order, tensor_model_parallel_size=world_size)

    src_rank = ps.get_tensor_model_parallel_src_rank()
    # Expected value: rank rounded down to a multiple of world_size.
    assert src_rank == ((rank // world_size) * world_size)

    Utils.destroy_model_parallel()
def test_different_initialize_order_consistency(src_tp_pp, ep_size):
    """Check that ranks and groups agree between two ``order`` strings.

    State is recorded under ``order='tp-ep-dp-pp'``, parallelism is torn down
    and rebuilt under ``order='tp-pp-ep-dp'``, and every recorded value must
    equal the freshly queried one.

    Args:
        src_tp_pp: Iterable unpacked into the leading positional arguments of
            ``Utils.initialize_model_parallel``.
        ep_size: Passed as ``expert_model_parallel_size``.

    The argument values come from a fixture/parametrization not visible here.
    """
    ranks_of = torch.distributed.get_process_group_ranks

    # Queries listed in the order they must be evaluated; each is called lazily.
    probes = (
        lambda: ps.get_tensor_model_parallel_rank(),
        lambda: ps.get_data_parallel_rank(),
        lambda: ps.get_pipeline_model_parallel_rank(),
        lambda: ps.get_expert_model_parallel_rank(),
        lambda: ranks_of(ps.get_tensor_model_parallel_group()),
        lambda: ranks_of(ps.get_data_parallel_group(False)),
        lambda: ranks_of(ps.get_pipeline_model_parallel_group()),
        lambda: ranks_of(ps.get_expert_data_parallel_group()),
        lambda: ranks_of(ps.get_context_parallel_group()),
        lambda: ranks_of(ps.get_model_parallel_group()),
        lambda: ranks_of(ps.get_expert_tensor_and_model_parallel_group()),
        lambda: ranks_of(ps.get_tensor_and_data_parallel_group(False)),
    )

    # Record every value under the first ordering.
    Utils.initialize_model_parallel(
        *src_tp_pp, expert_model_parallel_size=ep_size, order='tp-ep-dp-pp'
    )
    recorded = [probe() for probe in probes]
    Utils.destroy_model_parallel()

    # Re-query one value at a time under the second ordering and compare.
    Utils.initialize_model_parallel(
        *src_tp_pp, expert_model_parallel_size=ep_size, order='tp-pp-ep-dp'
    )
    for before, probe in zip(recorded, probes):
        assert before == probe()
    Utils.destroy_model_parallel()
def test_different_initialize_order_unconsistency(src_tp_pp, ep_size):
    """Check which process groups change between two ``order`` strings.

    Groups are recorded under ``order='tp-ep-dp-pp'`` and compared with those
    built under ``order='tp-pp-ep-dp'``: the tensor and context groups must be
    equal, while the data, pipeline, amax-reduction and model-parallel groups
    must differ.

    Args:
        src_tp_pp: Iterable unpacked into the leading positional arguments of
            ``Utils.initialize_model_parallel``.
        ep_size: Passed as ``expert_model_parallel_size``.

    The argument values come from a fixture/parametrization not visible here.
    """
    ranks_of = torch.distributed.get_process_group_ranks

    # (lazy query, must_be_equal) pairs in evaluation order.
    checks = (
        (lambda: ranks_of(ps.get_tensor_model_parallel_group()), True),
        (lambda: ranks_of(ps.get_data_parallel_group(False)), False),
        (lambda: ranks_of(ps.get_pipeline_model_parallel_group()), False),
        (lambda: ranks_of(ps.get_context_parallel_group()), True),
        (lambda: ranks_of(ps.get_amax_reduction_group(False)), False),
        (lambda: ranks_of(ps.get_model_parallel_group()), False),
    )

    # Record every group under the first ordering.
    Utils.initialize_model_parallel(
        *src_tp_pp, expert_model_parallel_size=ep_size, order='tp-ep-dp-pp'
    )
    recorded = [probe() for probe, _ in checks]
    Utils.destroy_model_parallel()

    # Re-query one group at a time under the second ordering and compare.
    Utils.initialize_model_parallel(
        *src_tp_pp, expert_model_parallel_size=ep_size, order='tp-pp-ep-dp'
    )
    for before, (probe, must_be_equal) in zip(recorded, checks):
        if must_be_equal:
            assert before == probe()
        else:
            assert before != probe()
    Utils.destroy_model_parallel()
def test_rank_generator_for_tp_dp_pp(nodes, num_gpu, tp, pp, cp, ep):
    """Compare ``ps.RankGenerator`` output against hand-built reference rank groups.

    Args:
        nodes: Multiplied by ``num_gpu`` to obtain the world size.
        num_gpu: Multiplied by ``nodes`` to obtain the world size.
        tp: Tensor-model-parallel size.
        pp: Pipeline-model-parallel size.
        cp: Context-parallel size.
        ep: Expert-model-parallel size.

    The argument values come from a fixture/parametrization not visible here.
    """

    def golden_rank_result_from_past_code(
        world_size: int,
        tensor_model_parallel_size: int = 1,
        pipeline_model_parallel_size: int = 1,
        context_parallel_size: int = 1,
        expert_model_parallel_size: int = 1,
    ):
        """Build the reference rank groups with explicit index arithmetic.

        Returns:
            A 10-tuple of lists of rank lists, in the order ``(dp_groups,
            dp_groups_with_cp, cp_group, mp_group, tp_group, pp_group,
            tp_dp_group, tp_dp_cp_group, expert_tp_ep_group, expert_dp_group)``.
        """
        tp_size = tensor_model_parallel_size
        pp_size = pipeline_model_parallel_size
        cp_size = context_parallel_size
        ep_size = expert_model_parallel_size

        dp_size: int = world_size // (tp_size * pp_size * cp_size)
        tp_group_count: int = world_size // tp_size
        # Also the number of consecutive ranks that belong to one pipeline stage.
        pp_group_count: int = world_size // pp_size
        tp_cp_size = cp_size * tp_size

        # Data-parallel groups, without and with the context dimension folded in.
        dp_groups = []
        dp_groups_with_cp = []
        for stage in range(pp_size):
            stage_lo = stage * pp_group_count
            stage_hi = (stage + 1) * pp_group_count
            dp_groups.extend(
                list(range(stage_lo + offset, stage_hi, tp_cp_size))
                for offset in range(tp_cp_size)
            )
            dp_groups_with_cp.extend(
                list(range(stage_lo + offset, stage_hi, tp_size)) for offset in range(tp_size)
            )

        # Context-parallel groups: stride tp_size inside each tp*cp chunk.
        cp_group = []
        for stage in range(pp_size):
            for replica in range(dp_size):
                chunk_lo = stage * pp_group_count + replica * tp_size * cp_size
                chunk_hi = stage * pp_group_count + (replica + 1) * tp_size * cp_size
                cp_group.extend(
                    list(range(chunk_lo + offset, chunk_hi, tp_size))
                    for offset in range(tp_size)
                )

        # Model-parallel groups: the idx-th member of every dp-with-cp group.
        mp_group = [
            [members[idx] for members in dp_groups_with_cp] for idx in range(dp_size * cp_size)
        ]

        # Tensor-parallel groups are runs of tp_size consecutive ranks.
        tp_group = [
            list(range(idx * tp_size, (idx + 1) * tp_size)) for idx in range(tp_group_count)
        ]

        # Pipeline-parallel groups stride across the stages.
        pp_group = [
            list(range(idx, world_size, pp_group_count)) for idx in range(pp_group_count)
        ]

        # Tensor+data groups, with and without the context dimension.
        tp_dp_group = []
        tp_dp_cp_group = []
        tp_dp_cp_size: int = tp_size * dp_size * cp_size
        for block in range(world_size // tp_dp_cp_size):
            block_lo = block * tp_dp_cp_size
            tp_dp_cp_group.append(list(range(block_lo, block_lo + tp_dp_cp_size)))
            for cp_idx in range(cp_size):
                members = []
                for dp_idx in range(dp_size):
                    run_lo = block * tp_dp_cp_size + cp_idx * tp_size + dp_idx * tp_size * cp_size
                    members.extend(range(run_lo, run_lo + tp_size))
                tp_dp_group.append(members)

        # Expert groups are derived from a (pp, dp, ep, tp) grid of all ranks.
        expert_dp_size = world_size // (tp_size * pp_size * ep_size)
        rank_grid = torch.arange(world_size).reshape(
            (pp_size, expert_dp_size, ep_size, tp_size)
        )
        # (pp, dp, ep, tp) -> (pp*dp, ep*tp)
        tp_ep_rows = torch.reshape(rank_grid, (-1, ep_size * tp_size))
        expert_tp_ep_group = [tp_ep_rows[row].tolist() for row in range(tp_ep_rows.shape[0])]
        # (pp, dp, ep, tp) -> (pp*ep*tp, dp)
        expert_dp_rows = torch.permute(rank_grid, (0, 2, 3, 1)).reshape(-1, expert_dp_size)
        expert_dp_group = [
            expert_dp_rows[row].tolist() for row in range(world_size // expert_dp_size)
        ]

        return (
            dp_groups,
            dp_groups_with_cp,
            cp_group,
            mp_group,
            tp_group,
            pp_group,
            tp_dp_group,
            tp_dp_cp_group,
            expert_tp_ep_group,
            expert_dp_group,
        )

    world_size = nodes * num_gpu
    dp = world_size // (tp * pp * cp)
    expert_dp = world_size // (tp * ep * pp)

    # Guard the configuration before building any groups.
    assert dp % ep == 0, f"dp size ({dp}) is not divisible by ep {ep} ."
    assert (
        world_size % (tp * pp * cp) == 0
    ), f"world_size ({world_size}) is not divisible by tp {tp} x pp {pp} x cp {cp}."

    golden = golden_rank_result_from_past_code(
        world_size=world_size,
        tensor_model_parallel_size=tp,
        pipeline_model_parallel_size=pp,
        context_parallel_size=cp,
        expert_model_parallel_size=ep,
    )
    (
        dp_groups,
        dp_groups_with_cp,
        cp_group,
        mp_group,
        tp_group,
        pp_group,
        tp_dp_group,
        tp_dp_cp_group,
        expert_tp_ep_group,
        expert_dp_group,
    ) = golden

    # One generator for the dense layout, one for the expert layout.
    rank_generator = ps.RankGenerator(tp=tp, ep=1, dp=dp, pp=pp, cp=cp, order="tp-cp-dp-pp")
    expert_rank_generator = ps.RankGenerator(
        tp=tp, ep=ep, dp=expert_dp, pp=pp, cp=1, order="tp-ep-dp-pp"
    )

    # Dense-layout groups.
    assert dp_groups == rank_generator.get_ranks(
        "dp"
    ), f"{dp_groups} != {rank_generator.get_ranks('dp')}"
    assert dp_groups_with_cp == rank_generator.get_ranks(
        'dp-cp'
    ), f"{dp_groups_with_cp} != {rank_generator.get_ranks('dp-cp')}"
    assert cp_group == rank_generator.get_ranks(
        "cp"
    ), f"{cp_group} != {rank_generator.get_ranks('cp')}."
    assert mp_group == rank_generator.get_ranks(
        "tp-pp"
    ), f"{mp_group} != {rank_generator.get_ranks('tp-pp')}"
    assert tp_group == rank_generator.get_ranks(
        "tp"
    ), f"{tp_group} != {rank_generator.get_ranks('tp')}"
    assert pp_group == rank_generator.get_ranks(
        "pp"
    ), f"{pp_group} != {rank_generator.get_ranks('pp')}"
    assert tp_dp_group == rank_generator.get_ranks(
        "tp-dp"
    ), f"{tp_dp_group} != {rank_generator.get_ranks('tp-dp')}"
    assert tp_dp_cp_group == rank_generator.get_ranks(
        "tp-dp-cp"
    ), f"{tp_dp_cp_group} != {rank_generator.get_ranks('tp-dp-cp')}"

    # Expert-layout groups.
    assert expert_tp_ep_group == expert_rank_generator.get_ranks(
        "tp-ep"
    ), f"{expert_tp_ep_group} != {expert_rank_generator.get_ranks('tp-ep')}."
    assert expert_dp_group == expert_rank_generator.get_ranks(
        "dp"
    ), f"{expert_dp_group} != {expert_rank_generator.get_ranks('dp')}."