| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| import random |
|
|
| import numpy as np |
| import torch |
|
|
| from nemo.utils import AppState, logging |
|
|
| try: |
| from apex.transformer.log_util import set_logging_level |
|
|
| HAVE_APEX = True |
|
|
| except (ImportError, ModuleNotFoundError): |
|
|
| HAVE_APEX = False |
|
|
| try: |
| from megatron.core import tensor_parallel |
| from megatron.core.parallel_state import ( |
| RankGenerator, |
| get_pipeline_model_parallel_rank, |
| set_expert_model_parallel_rank, |
| set_expert_model_parallel_world_size, |
| set_pipeline_model_parallel_rank, |
| set_pipeline_model_parallel_world_size, |
| set_tensor_model_parallel_rank, |
| set_tensor_model_parallel_world_size, |
| ) |
|
|
| HAVE_MEGATRON_CORE = True |
|
|
| except (ImportError, ModuleNotFoundError): |
|
|
| HAVE_MEGATRON_CORE = False |
|
|
| try: |
| from megatron.core.num_microbatches_calculator import ( |
| ConstantNumMicroBatchesCalculator, |
| get_current_global_batch_size, |
| get_micro_batch_size, |
| get_num_microbatches, |
| init_num_microbatches_calculator, |
| ) |
|
|
| MCORE_MB_CALCULATOR = True |
|
|
| except (ImportError, ModuleNotFoundError): |
| logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") |
| from apex.transformer.microbatches import ConstantNumMicroBatches as ConstantNumMicroBatchesCalculator |
| from apex.transformer.pipeline_parallel.utils import ( |
| get_current_global_batch_size, |
| get_micro_batch_size, |
| get_num_microbatches, |
| ) |
| from apex.transformer.pipeline_parallel.utils import ( |
| setup_microbatch_calculator as init_num_microbatches_calculator, |
| ) |
|
|
| MCORE_MB_CALCULATOR = False |
|
|
|
|
def initialize_model_parallel_for_nemo(
    world_size,
    global_rank,
    local_rank,
    tensor_model_parallel_size=1,
    expert_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    virtual_pipeline_model_parallel_size=None,
    pipeline_model_parallel_split_rank=None,
    pipeline_model_parallel_comm_backend=None,
    context_parallel_size=1,
    encoder_tensor_model_parallel_size=0,
    encoder_pipeline_model_parallel_size=0,
    micro_batch_size=None,
    global_batch_size=None,
    rampup_batch_size=None,
    use_fp8=False,
    init_mpi_proc_group=False,
    seed=1234,
    apex_transformer_log_level=30,
    use_tp_pp_dp_mapping=False,
    use_te_rng_tracker=False,
    num_distributed_optimizer_instances=1,
):
    """
    Initialize the model parallel groups for NeMo.

    Records the parallel configuration on the ``AppState`` singleton, computes this
    rank's position in every parallel group with ``fake_initialize_model_parallel``
    (no distributed process groups are created), pushes the resulting ranks and world
    sizes into Megatron's parallel state, initializes the tensor-parallel RNG tracker,
    optionally seeds all RNGs and optionally sets up the global microbatch calculator.

    Args:
        world_size: total number of ranks.
        global_rank: global rank of this process.
        local_rank: local rank of this process; stored on ``AppState``.
        tensor_model_parallel_size: tensor model parallel size.
        expert_model_parallel_size: expert model parallel size.
        pipeline_model_parallel_size: pipeline model parallel size.
        virtual_pipeline_model_parallel_size: virtual pipeline size, or ``None``.
        pipeline_model_parallel_split_rank: deprecated; only ``None`` or ``0`` is
            accepted and both mean "no split".
        pipeline_model_parallel_comm_backend: stored on ``AppState`` only.
        context_parallel_size: context parallel size.
        encoder_tensor_model_parallel_size: must be 0 or equal to
            ``tensor_model_parallel_size``.
        encoder_pipeline_model_parallel_size: must currently be 0.
        micro_batch_size, global_batch_size, rampup_batch_size: when
            ``global_batch_size`` is truthy and ``micro_batch_size`` is not ``None``,
            the global microbatch calculator is created, or validated if a constant
            one already exists.
        use_fp8, init_mpi_proc_group, use_tp_pp_dp_mapping,
        num_distributed_optimizer_instances: stored on ``AppState``;
            ``use_tp_pp_dp_mapping`` also selects the rank ordering.
        seed: base random seed; ``None`` skips seeding, otherwise it must be positive.
        apex_transformer_log_level: passed to apex's ``set_logging_level`` when apex
            is importable.
        use_te_rng_tracker: forwarded to ``tensor_parallel.random.initialize_rng_tracker``.

    Raises:
        AssertionError: for deprecated/unsupported configurations, or when an already
            initialized constant microbatch calculator disagrees with the requested
            batch sizes.
        Exception: when a non-constant microbatch calculator is already initialized.
    """
    assert (
        pipeline_model_parallel_split_rank is None or pipeline_model_parallel_split_rank == 0
    ), "pipeline_model_parallel_split_rank is deprecated."
    assert encoder_pipeline_model_parallel_size == 0 and (
        encoder_tensor_model_parallel_size == 0 or encoder_tensor_model_parallel_size == tensor_model_parallel_size
    ), (
        "encoder_pipeline_model_parallel_size is temporarily "
        "unavailable. We are working on a refactoring to add it back."
    )
    # ``0`` is tolerated above for backward compatibility, but
    # ``fake_initialize_model_parallel`` only accepts ``None``. Normalize it so that
    # passing ``0`` does not trip that function's deprecation assertion.
    if pipeline_model_parallel_split_rank == 0:
        pipeline_model_parallel_split_rank = None

    # Record the requested configuration on the NeMo-global AppState.
    app_state = AppState()
    app_state.global_rank = global_rank
    app_state.world_size = world_size
    app_state.local_rank = local_rank
    app_state.use_tp_pp_dp_mapping = use_tp_pp_dp_mapping
    app_state.num_distributed_optimizer_instances = num_distributed_optimizer_instances
    app_state.expert_model_parallel_size = expert_model_parallel_size
    app_state.tensor_model_parallel_size = tensor_model_parallel_size
    app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
    app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
    app_state.context_parallel_size = context_parallel_size
    app_state.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size
    app_state.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
    app_state.pipeline_model_parallel_comm_backend = pipeline_model_parallel_comm_backend
    app_state.use_fp8 = use_fp8
    app_state.init_mpi_proc_group = init_mpi_proc_group
    (
        app_state.tensor_model_parallel_rank,
        app_state.pipeline_model_parallel_rank,
        app_state.expert_model_parallel_rank,
        app_state.model_parallel_size,
        app_state.data_parallel_size,
        app_state.pipeline_model_parallel_split_rank,
        app_state.virtual_pipeline_model_parallel_rank,
    ) = fake_initialize_model_parallel(
        world_size=world_size,
        rank=global_rank,
        tensor_model_parallel_size_=tensor_model_parallel_size,
        pipeline_model_parallel_size_=pipeline_model_parallel_size,
        virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
        pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
        context_parallel_size_=context_parallel_size,
        expert_model_parallel_size_=expert_model_parallel_size,
        encoder_tensor_model_parallel_size_=encoder_tensor_model_parallel_size,
        encoder_pipeline_model_parallel_size_=encoder_pipeline_model_parallel_size,
        use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
    )

    # Mirror the computed ranks / sizes into Megatron's parallel state.
    set_tensor_model_parallel_world_size(app_state.tensor_model_parallel_size)
    set_tensor_model_parallel_rank(app_state.tensor_model_parallel_rank)

    set_expert_model_parallel_world_size(app_state.expert_model_parallel_size)
    set_expert_model_parallel_rank(app_state.expert_model_parallel_rank)

    set_pipeline_model_parallel_world_size(
        app_state.pipeline_model_parallel_size + app_state.encoder_pipeline_model_parallel_size
    )
    set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank)

    tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker)
    if seed is not None:
        _set_random_seed(seed)

    if global_batch_size and micro_batch_size is not None:
        _setup_num_microbatches_calculator(
            global_rank=global_rank,
            global_batch_size=global_batch_size,
            micro_batch_size=micro_batch_size,
            data_parallel_size=app_state.data_parallel_size,
            rampup_batch_size=rampup_batch_size,
        )

    app_state._is_megatron_initialized = True

    if HAVE_APEX:
        set_logging_level(apex_transformer_log_level)


def _setup_num_microbatches_calculator(
    global_rank, global_batch_size, micro_batch_size, data_parallel_size, rampup_batch_size
):
    """Create the global microbatch calculator, or validate an existing constant one.

    Uses the Megatron-Core calculator when available (``MCORE_MB_CALCULATOR``) and the
    apex one otherwise.

    Raises:
        AssertionError: if an existing constant calculator disagrees with the requested
            global/micro batch sizes or the implied number of microbatches.
        Exception: if a non-constant calculator has already been initialized.
    """
    if MCORE_MB_CALCULATOR:
        from megatron.core.num_microbatches_calculator import _GLOBAL_NUM_MICROBATCHES_CALCULATOR

        extra_kwargs = {"decrease_batch_size_if_needed": False}
    else:
        from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR

        # NOTE(review): apex's ``setup_microbatch_calculator`` (aliased here as
        # ``init_num_microbatches_calculator``) does not take
        # ``decrease_batch_size_if_needed``; passing it raised a TypeError.
        extra_kwargs = {}

    if _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None:
        init_num_microbatches_calculator(
            rank=global_rank,
            global_batch_size=global_batch_size,
            micro_batch_size=micro_batch_size,
            data_parallel_size=data_parallel_size,
            rampup_batch_size=rampup_batch_size,
            **extra_kwargs,
        )
    elif isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
        # Re-initialization is allowed only if it would produce the same constant setup.
        assert get_current_global_batch_size() == global_batch_size
        assert get_micro_batch_size() == micro_batch_size
        assert get_num_microbatches() == global_batch_size // (micro_batch_size * data_parallel_size)
    else:
        raise Exception("Microbatch calculator already initialized.")
|
|
|
|
def _set_random_seed(seed_):
    """Seed the Python, NumPy and PyTorch RNGs for reproducibility.

    The base seed is shifted by ``100 * pipeline_model_parallel_rank``, so ranks in
    different pipeline stages use different seeds. When at least one CUDA device is
    visible, the model-parallel CUDA RNG state is seeded with the shifted value too.

    Args:
        seed_: base seed; must be greater than zero.

    Raises:
        ValueError: if ``seed_`` is ``None`` or not greater than zero.
    """
    seed_is_valid = seed_ is not None and seed_ > 0
    if not seed_is_valid:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))

    # Offset per pipeline stage so stages do not share an RNG stream.
    stage_seed = seed_ + 100 * get_pipeline_model_parallel_rank()
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(stage_seed)
    if torch.cuda.device_count() > 0:
        tensor_parallel.model_parallel_cuda_manual_seed(stage_seed)
|
|
|
|
def set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options for one specific PyTorch build.

    Nothing happens unless ``torch.__version__ == "1.10.0a0+0aef44c"``. For that build:

    * the profiling executor and profiling mode are enabled;
    * the CPU and GPU fusers and the TensorExpr fuser are disabled;
    * nvFuser is enabled;
    * autodiff subgraph inlining is turned off.
    """
    is_target_build = torch.__version__ == "1.10.0a0+0aef44c"
    if not is_target_build:
        return

    jit_c = torch._C
    jit_c._jit_set_profiling_executor(True)
    jit_c._jit_set_profiling_mode(True)
    jit_c._jit_override_can_fuse_on_cpu(False)
    jit_c._jit_override_can_fuse_on_gpu(False)
    jit_c._jit_set_texpr_fuser_enabled(False)
    jit_c._jit_set_nvfuser_enabled(True)
    jit_c._debug_set_autodiff_subgraph_inlining(False)
|
|
|
|
def fake_initialize_model_parallel(
    world_size,
    rank,
    tensor_model_parallel_size_,
    pipeline_model_parallel_size_,
    pipeline_model_parallel_split_rank_=None,
    virtual_pipeline_model_parallel_size_=None,
    expert_model_parallel_size_=1,
    expert_tensor_parallel_size_=None,
    context_parallel_size_=1,
    encoder_tensor_model_parallel_size_=0,
    encoder_pipeline_model_parallel_size_=0,
    use_tp_pp_dp_mapping=False,
):
    """
    Fake initialize model data parallel groups so that we can instantiate model parallel
    models before DDP is initialized. This is needed because PTL execution flow is init
    model, init trainer -> call trainer.fit(model). DDP is initialized during .fit.
    This function is taken from megatron.core.parallel_state and modified so that the
    distributed groups are not created.
    We only need the tensor parallel and pipeline parallel ranks to instantiate the model.

    Arguments:
        world_size: total number of ranks.
        rank: global rank of the calling process; must satisfy ``0 <= rank < world_size``.
        tensor_model_parallel_size_: number of GPUs used to parallelize model tensor
            (clamped to ``world_size``).
        pipeline_model_parallel_size_: number of GPUs used to parallelize model pipeline
            (clamped to ``world_size``).
        pipeline_model_parallel_split_rank_: deprecated; must be ``None``.
        virtual_pipeline_model_parallel_size_: when not ``None``, the returned virtual
            pipeline rank is ``0``; otherwise it is ``None``.
        expert_model_parallel_size_: number of GPUs used for expert parallelism.
        expert_tensor_parallel_size_: tensor parallel size of the expert layers; defaults
            to the (clamped) tensor model parallel size.
        context_parallel_size_: number of GPUs used to parallelize tokens of each input
            (clamped to ``world_size``).
        encoder_tensor_model_parallel_size_: must be 0 or equal to
            ``tensor_model_parallel_size_``.
        encoder_pipeline_model_parallel_size_: must currently be 0.
        use_tp_pp_dp_mapping: if True, ranks are ordered 'tp-cp-ep-pp-dp',
            otherwise 'tp-cp-ep-dp-pp'.

    Returns:
        tuple: ``(tensor_model_parallel_rank, pipeline_model_parallel_rank,
        expert_model_parallel_rank, model_parallel_size, data_parallel_size,
        pipeline_model_parallel_split_rank_, virtual_pipeline_model_parallel_rank)``.

    Raises:
        AssertionError: for unsupported or inconsistent parallel configurations.
        ValueError: if ``rank`` is outside ``[0, world_size)``.
        RuntimeError: if the decoder world size is not divisible by the expert
            tensor * expert model * pipeline parallel size.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    assert encoder_pipeline_model_parallel_size_ == 0 and (
        encoder_tensor_model_parallel_size_ == 0 or encoder_tensor_model_parallel_size_ == tensor_model_parallel_size_
    ), (
        "encoder_pipeline_model_parallel_size is temporarily "
        "unavailable. We are working on a refactoring to add it back."
    )
    assert pipeline_model_parallel_split_rank_ is None, "pipeline_model_parallel_split_rank is deprecated."
    if not 0 <= rank < world_size:
        # Otherwise ``rank`` belongs to no group and the ``.index(rank)`` lookups below
        # fail with an unhelpful UnboundLocalError / AttributeError.
        raise ValueError(f'rank ({rank}) must satisfy 0 <= rank < world_size ({world_size}).')

    # Clamp the requested sizes to the number of available ranks.
    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
    model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
    context_parallel_size = min(context_parallel_size_, world_size)

    if encoder_pipeline_model_parallel_size_ is None:
        encoder_pipeline_model_parallel_size = 0
    else:
        encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size_

    if encoder_pipeline_model_parallel_size == 0:
        # With no encoder pipeline stages the encoder occupies no ranks
        # (encoder_model_size below is 0), so its TP size is irrelevant. Force it to 0
        # so that ``encoder_tensor_model_parallel_size_ == tensor_model_parallel_size_``,
        # which the assertion above explicitly allows, does not trip the
        # ``encoder_pipeline_model_parallel_size > 0`` check below.
        encoder_tensor_model_parallel_size = 0
    elif encoder_tensor_model_parallel_size_ == 0:
        # Encoder TP defaults to the decoder's TP size.
        encoder_tensor_model_parallel_size = tensor_model_parallel_size
    else:
        encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size_

    if encoder_tensor_model_parallel_size > 0:
        assert encoder_pipeline_model_parallel_size > 0
        assert (
            encoder_tensor_model_parallel_size <= tensor_model_parallel_size
        ), "We do not support encoders with more TP than the decoder."

    encoder_model_size = (
        encoder_tensor_model_parallel_size * encoder_pipeline_model_parallel_size * context_parallel_size
    )
    decoder_model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
    total_model_size = encoder_model_size + decoder_model_size

    assert world_size % total_model_size == 0, (
        f'world_size: {world_size} must be divisible by total world_size: '
        f'(decoder_)tensor_model_parallel_size {tensor_model_parallel_size} '
        f'* (decoder_)pipeline_model_parallel_size {pipeline_model_parallel_size} '
        f'* (decoder_)context_parallel_size {context_parallel_size} + '
        f'encoder_tensor_model_parallel_size {encoder_tensor_model_parallel_size} '
        f'* encoder_pipeline_model_parallel_size {encoder_pipeline_model_parallel_size} '
        f'* context_parallel_size {context_parallel_size}'
    )
    data_parallel_size = world_size // total_model_size

    encoder_world_size = encoder_model_size * data_parallel_size
    decoder_world_size = decoder_model_size * data_parallel_size
    assert encoder_world_size + decoder_world_size == world_size

    virtual_pipeline_model_parallel_rank = None
    if virtual_pipeline_model_parallel_size_ is not None:
        virtual_pipeline_model_parallel_rank = 0

    rank_order = 'tp-cp-ep-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp'

    if encoder_world_size > 0:
        encoder_rank_generator = RankGenerator(
            tp=encoder_tensor_model_parallel_size,
            ep=1,
            dp=data_parallel_size,
            pp=encoder_pipeline_model_parallel_size,
            cp=context_parallel_size,
            order=rank_order,
            rank_offset=0,
        )
    else:
        encoder_rank_generator = None

    # Decoder ranks are numbered after the encoder ranks.
    decoder_rank_generator = RankGenerator(
        tp=tensor_model_parallel_size,
        ep=1,
        dp=data_parallel_size,
        pp=pipeline_model_parallel_size,
        cp=context_parallel_size,
        order=rank_order,
        rank_offset=encoder_world_size,
    )

    # Expert layers default to the decoder TP size; their DP size takes what remains.
    if expert_tensor_parallel_size_ is None:
        expert_tensor_parallel_size = tensor_model_parallel_size
    else:
        expert_tensor_parallel_size = expert_tensor_parallel_size_
    expert_tensor_model_pipeline_parallel_size = (
        expert_tensor_parallel_size * expert_model_parallel_size_ * pipeline_model_parallel_size
    )
    if decoder_world_size % expert_tensor_model_pipeline_parallel_size != 0:
        raise RuntimeError(
            f"decoder world_size ({decoder_world_size}) is not divisible by "
            f"expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})"
        )
    expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size

    expert_decoder_rank_generator = RankGenerator(
        tp=expert_tensor_parallel_size,
        ep=expert_model_parallel_size_,
        dp=expert_data_parallel_size,
        pp=pipeline_model_parallel_size,
        cp=1,
        order=rank_order,
        rank_offset=encoder_world_size,
    )

    assert (
        not use_tp_pp_dp_mapping
        or pipeline_model_parallel_size == 1
        or expert_data_parallel_size == data_parallel_size
    ), "When not using pp-last rank ordering, the data parallel size of the attention and moe layers must be the same"

    decoder_pp_ranks = decoder_rank_generator.get_ranks("pp")
    expert_pp_ranks = expert_decoder_rank_generator.get_ranks("pp")
    assert decoder_pp_ranks == expert_pp_ranks, (
        "Pipeline parallel groups are expected to be the same for Non-Expert and Expert part, "
        f"but got {decoder_pp_ranks} and {expert_pp_ranks}"
    )

    def generator_wrapper(group_type, is_expert=False, **kwargs):
        """The `RankGenerator` class produces a hyper-rectangle for a given set of
        tensor, pipeline, data, expert, and context parallelism. If we have an encoder,
        in addition to the default decoder, we essentially instantiate two `RankGenerator`
        classes to construct the parallelism for each module separately, and we then have
        to stitch them together for the right groups. For now, this means pp and tp-pp."""
        from itertools import cycle

        generator = expert_decoder_rank_generator if is_expert else decoder_rank_generator
        d_ranks = generator.get_ranks(group_type, **kwargs)

        if encoder_rank_generator is None:
            yield from d_ranks
            return
        e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs)
        if group_type == 'pp':
            # Cycle the encoder groups so that every decoder pp group gets an encoder prefix.
            for x, y in zip(cycle(e_ranks), d_ranks):
                yield x + y
        elif group_type == 'tp-pp':
            # Model-parallel groups pair one-to-one between encoder and decoder.
            assert len(e_ranks) == len(d_ranks)
            for x, y in zip(e_ranks, d_ranks):
                yield x + y
        else:
            yield from e_ranks
            yield from d_ranks

    # Data parallel groups (with and without context parallelism folded in).
    data_parallel_group = None
    all_data_parallel_group_ranks_with_cp = []
    for ranks in generator_wrapper('dp'):
        if rank in ranks:
            data_parallel_group = list(ranks)
            logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}')

    for ranks_with_cp in generator_wrapper('dp-cp'):
        all_data_parallel_group_ranks_with_cp.append(ranks_with_cp)
        if rank in ranks_with_cp:
            data_parallel_group_with_cp = ranks_with_cp
            logging.info(
                f'Rank {rank} has combined group of data parallel and context parallel : {data_parallel_group_with_cp}'
            )

    data_parallel_rank = data_parallel_group.index(rank)
    logging.info(
        f'All data parallel group ranks with context parallel combined: {all_data_parallel_group_ranks_with_cp}'
    )
    logging.info(f'Ranks {rank} has data parallel rank: {data_parallel_rank}')

    # Context parallel groups.
    context_parallel_group = None
    all_context_parallel_group_ranks = []
    for ranks in generator_wrapper('cp'):
        all_context_parallel_group_ranks.append(ranks)
        if rank in ranks:
            context_parallel_group = ranks
            logging.info(f'Rank {rank} has context parallel group: {context_parallel_group}')

    context_parallel_rank = context_parallel_group.index(rank)
    logging.info(f'All context parallel group ranks: {all_context_parallel_group_ranks}')
    logging.info(f'Ranks {rank} has context parallel rank: {context_parallel_rank}')

    # Model parallel (tp-pp) groups; only logged.
    all_model_parallel_group_ranks = []
    for ranks in generator_wrapper('tp-pp'):
        all_model_parallel_group_ranks.append(ranks)
        if rank in ranks:
            logging.info(f'Rank {rank} has model parallel group: {list(ranks)}')
    logging.info(f'All model parallel group ranks: {all_model_parallel_group_ranks}')

    # Tensor model parallel groups.
    all_tensor_model_parallel_group_ranks = []
    tensor_model_parallel_group = None
    for ranks in generator_wrapper('tp'):
        all_tensor_model_parallel_group_ranks.append(ranks)
        if rank in ranks:
            tensor_model_parallel_group = ranks
            logging.info(f'Rank {rank} has tensor model parallel group: {tensor_model_parallel_group}')

    tensor_model_parallel_rank = tensor_model_parallel_group.index(rank)

    logging.info(f'All tensor model parallel group ranks: {all_tensor_model_parallel_group_ranks}')
    logging.info(f'Rank {rank} has tensor model parallel rank: {tensor_model_parallel_rank}')

    # Expert model parallel rank (0 unless expert parallelism is actually used).
    expert_model_parallel_rank = 0
    if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
        for ranks in generator_wrapper('ep', is_expert=True):
            if rank in ranks:
                expert_model_parallel_rank = list(ranks).index(rank)

    # Pipeline model parallel groups and the matching embedding groups
    # (first and last stage of each pipeline group).
    all_pipeline_model_parallel_group_ranks = []
    all_embedding_group_ranks = []
    pipeline_model_parallel_group = None
    embedding_group = None
    embedding_rank = None
    for ranks in generator_wrapper('pp'):
        all_pipeline_model_parallel_group_ranks.append(ranks)
        if rank in ranks:
            pipeline_model_parallel_group = ranks
            logging.info(f'Rank {rank} has pipeline model parallel group: {pipeline_model_parallel_group}')

        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            all_embedding_group_ranks.append(embedding_ranks)
        else:
            embedding_ranks = ranks
            all_embedding_group_ranks.append(list(embedding_ranks))
        if rank in embedding_ranks:
            embedding_group = list(embedding_ranks)
            logging.info(f'Rank {rank} has embedding group: {embedding_group}')

    pipeline_model_parallel_rank = pipeline_model_parallel_group.index(rank)
    if embedding_group is not None:
        embedding_rank = embedding_group.index(rank)

    logging.info(f'All pipeline model parallel group ranks: {all_pipeline_model_parallel_group_ranks}')
    logging.info(f'Rank {rank} has pipeline model parallel rank {pipeline_model_parallel_rank}')
    logging.info(f'All embedding group ranks: {all_embedding_group_ranks}')
    logging.info(f'Rank {rank} has embedding rank: {embedding_rank}')

    return (
        tensor_model_parallel_rank,
        pipeline_model_parallel_rank,
        expert_model_parallel_rank,
        model_parallel_size,
        data_parallel_size,
        pipeline_model_parallel_split_rank_,
        virtual_pipeline_model_parallel_rank,
    )
|
|