diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py
index fe26e8b4..4451f277 100644
--- a/megatron/core/distributed/__init__.py
+++ b/megatron/core/distributed/__init__.py
@@ -11,3 +11,19 @@ from .finalize_model_grads import finalize_model_grads
 from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
 from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
 from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig
+
+# Backward compatibility patch for FSDP module reorganization: expose the
+# relocated package under its legacy ``custom_fsdp`` import path.
+import importlib
+import importlib.util
+import sys
+
+_MEGATRON_FSDP_MODULE = 'megatron.core.distributed.fsdp.src.megatron_fsdp'
+
+if importlib.util.find_spec(_MEGATRON_FSDP_MODULE) is not None:
+    # import_module returns the canonical entry from sys.modules, so the alias
+    # shares its objects with the real package instead of re-executing it.
+    custom_fsdp = importlib.import_module(_MEGATRON_FSDP_MODULE)
+    sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp
+    if hasattr(custom_fsdp, 'MegatronFSDP'):
+        custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP
diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index 99c3edc0..26ea5cb4 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -404,6 +404,7 @@ class TELinear(te.pytorch.Linear):
         )
 
         for param in self.parameters():
+            setattr(param, "parallel_mode", parallel_mode)
             if is_expert:
                 # Reduce the gradient on the expert_data_parallel group for expert linear layers
                 setattr(param, "allreduce", not self.expert_parallel)
diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py
index 002edb92..f7273488 100755
--- a/megatron/core/models/gpt/gpt_layer_specs.py
+++ b/megatron/core/models/gpt/gpt_layer_specs.py
@@ -80,6 +80,8 @@ def get_gpt_layer_with_transformer_engine_spec(
     use_te_op_fuser: Optional[bool] = False,
     use_kitchen: bool = False,
     use_te_activation_func: bool = False,
+    post_self_attn_layernorm: bool = False,
+    post_mlp_layernorm: bool = False,
 ) -> ModuleSpec:
     """Use this spec to use lower-level Transformer Engine modules (required for fp8
     training).
@@ -182,9 +184,11 @@ def get_gpt_layer_with_transformer_engine_spec(
                     ),
                 ),
                 self_attn_bda=get_bias_dropout_add,
+                post_self_attn_layernorm=backend.layer_norm() if post_self_attn_layernorm else IdentityOp,
                 pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
                 mlp=mlp,
                 mlp_bda=get_bias_dropout_add,
+                post_mlp_layernorm=backend.layer_norm() if post_mlp_layernorm else IdentityOp,
                 sharded_state_dict_keys_map={
                     "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
                     "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
index df9adc3e..2f4f544a 100644
--- a/megatron/core/models/gpt/gpt_model.py
+++ b/megatron/core/models/gpt/gpt_model.py
@@ -443,7 +443,7 @@ class GPTModel(LanguageModule):
         if self.share_embeddings_and_output_weights:
             output_weight = self.shared_embedding_or_output_weight()
 
-        if mtp_in_postprocess:
+        if mtp_in_postprocess and labels is not None:
             hidden_states = self.mtp(
                 input_ids=input_ids,
                 position_ids=position_ids,
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index 57332ac3..f3abd642 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -9,6 +9,7 @@ from typing import Callable, List, Optional
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from .utils import GlobalMemoryBuffer, is_torch_min_version
 
@@ -163,6 +164,246 @@ def get_nccl_options(pg_name, nccl_comm_cfgs):
         return None
 
 
+old_new_group = None
+
+
+def monkey_patch_torch_dist():
+    """Patch ``torch.distributed`` so process groups can be destroyed and reloaded.
+
+    ``dist.new_group`` is replaced so that every non-gloo group with more than one
+    rank is returned wrapped in a :class:`ReloadableProcessGroup`. The query,
+    collective and point-to-point functions patched below swap any
+    ``ReloadableProcessGroup`` argument for its current inner group before
+    delegating to the original function.
+
+    The patch is applied at most once; later calls return immediately.
+    """
+    # Local import keeps this patch independent of the file's import block.
+    import functools
+
+    global old_new_group
+    if old_new_group is not None:
+        return
+    print("Applying monkey patch to torch.distributed", flush=True)
+
+    old_new_group = dist.new_group
+
+    def new_group(*args, **kwargs):
+        group = old_new_group(*args, **kwargs)
+
+        # Skip non-NCCL (gloo) groups; the backend is read from the third
+        # positional argument or the ``backend`` keyword.
+        backend = args[2] if len(args) >= 3 else kwargs.get("backend")
+        if backend == "gloo":
+            return group
+
+        # Get ranks from arguments
+        if args and args[0] is not None:
+            ranks = args[0]
+        elif kwargs.get("ranks") is not None:
+            ranks = kwargs["ranks"]
+        else:
+            # If no ranks specified, use all ranks in world
+            ranks = list(range(dist.get_world_size()))
+
+        # A single-rank group is returned unwrapped.
+        if len(ranks) == 1:
+            return group
+
+        return ReloadableProcessGroup(group, ranks)
+
+    dist.new_group = new_group
+
+    def unwrap(value):
+        """Return the inner group of a ReloadableProcessGroup, else ``value``."""
+        return value.group if isinstance(value, ReloadableProcessGroup) else value
+
+    def get_new_function(func):
+        @functools.wraps(func)
+        def new_function(*args, **kwargs):
+            args = tuple(unwrap(arg) for arg in args)
+            kwargs = {k: unwrap(v) for k, v in kwargs.items()}
+            return func(*args, **kwargs)
+
+        return new_function
+
+    for name in (
+        "get_rank",
+        "get_world_size",
+        "get_backend",
+        "get_global_rank",
+        "get_group_rank",
+        "get_process_group_ranks",
+        "all_reduce",
+        "all_gather",
+        "all_gather_into_tensor",
+        "all_gather_object",
+        "all_to_all",
+        "all_to_all_single",
+        "broadcast",
+        "reduce",
+        "reduce_scatter",
+        "reduce_scatter_tensor",
+        "scatter",
+        "gather",
+        "barrier",
+        "send",
+        "recv",
+        "_coalescing_manager",
+    ):
+        setattr(dist, name, get_new_function(getattr(dist, name)))
+
+    # p2p
+    old_isend = dist.isend
+    old_irecv = dist.irecv
+
+    dist.isend = get_new_function(dist.isend)
+    dist.irecv = get_new_function(dist.irecv)
+
+    def get_new_p2pop_function(func):
+        @functools.wraps(func)
+        def new_function(*args, **kwargs):
+            def convert(arg):
+                if isinstance(arg, ReloadableProcessGroup):
+                    return arg.group
+                # Hand P2POp the unpatched isend/irecv. Identity is used because
+                # ``==`` on a tensor argument would dispatch to Tensor.__eq__.
+                if arg is dist.isend:
+                    return old_isend
+                if arg is dist.irecv:
+                    return old_irecv
+                return arg
+
+            args = tuple(convert(arg) for arg in args)
+            kwargs = {k: convert(v) for k, v in kwargs.items()}
+            return func(*args, **kwargs)
+
+        return new_function
+
+    dist.P2POp.__new__ = get_new_p2pop_function(dist.P2POp.__new__)
+    dist.P2POp.__init__ = get_new_p2pop_function(dist.P2POp.__init__)
+
+
+class ReloadableProcessGroup(torch.distributed.ProcessGroup):
+    """Process-group wrapper whose inner group can be destroyed and re-created.
+
+    Every instance is appended to ``GROUPS`` so that ``destroy_process_groups``
+    and ``reload_process_groups`` can act on all wrappers. While a wrapper is
+    destroyed its ``group`` attribute is ``None``.
+    """
+
+    GROUPS = []
+
+    def __init__(self, group, ranks):
+        super().__init__(
+            rank=dist.get_rank(group),
+            size=dist.get_world_size(group),
+        )
+        self.group = group
+        self.group_info = {
+            "ranks": ranks,
+        }
+        ReloadableProcessGroup.GROUPS.append(self)
+
+    def __getattr__(self, name):
+        # Reached only when normal lookup fails. ``group`` is guarded so that a
+        # missing inner group cannot recurse through this method.
+        if name == "group":
+            raise AttributeError(name)
+        return getattr(self.group, name)
+
+    @staticmethod
+    def destroy_process_groups():
+        """Destroy the inner group of every wrapper that currently has one."""
+        for reloadable_group in ReloadableProcessGroup.GROUPS:
+            if reloadable_group.group is None:
+                continue
+            dist.destroy_process_group(reloadable_group.group)
+            reloadable_group.group = None
+
+    @staticmethod
+    def reload_process_groups():
+        """Re-create the inner group of every wrapper that was destroyed."""
+        for reloadable_group in ReloadableProcessGroup.GROUPS:
+            if reloadable_group.group is not None:
+                continue
+            # NOTE(review): only the ranks are replayed; the group is always
+            # re-created with backend="nccl" and default timeout/options.
+            reloadable_group.group = old_new_group(
+                ranks=reloadable_group.group_info["ranks"],
+                backend="nccl",
+            )
+
+    def _fwd(self, method, *args, **kwargs):
+        """Call ``method`` on the inner group, failing clearly if it is destroyed."""
+        inner = self.group
+        if inner is None:
+            raise RuntimeError(
+                "ReloadableProcessGroup: inner PG is None, "
+                "call reload_process_groups() first."
+            )
+        return getattr(inner, method)(*args, **kwargs)
+
+    def rank(self) -> int: return self._fwd("rank")
+    def size(self) -> int: return self._fwd("size")
+    def name(self) -> str: return self._fwd("name")
+
+    def shutdown(self) -> None:
+        if self.group is not None:
+            self.group.shutdown()
+
+    def abort(self) -> None:
+        if self.group is not None:
+            self.group.abort()
+
+    def barrier(self, *a, **kw): return self._fwd("barrier", *a, **kw)
+    def broadcast(self, *a, **kw): return self._fwd("broadcast", *a, **kw)
+    def allreduce(self, *a, **kw): return self._fwd("allreduce", *a, **kw)
+    def allreduce_coalesced(self, *a, **kw): return self._fwd("allreduce_coalesced", *a, **kw)
+    def reduce(self, *a, **kw): return self._fwd("reduce", *a, **kw)
+    def allgather(self, *a, **kw): return self._fwd("allgather", *a, **kw)
+    def _allgather_base(self, *a, **kw): return self._fwd("_allgather_base", *a, **kw)
+    def allgather_coalesced(self, *a, **kw): return self._fwd("allgather_coalesced", *a, **kw)
+    def allgather_into_tensor_coalesced(self, *a, **kw): return self._fwd("allgather_into_tensor_coalesced", *a, **kw)
+    def gather(self, *a, **kw): return self._fwd("gather", *a, **kw)
+    def scatter(self, *a, **kw): return self._fwd("scatter", *a, **kw)
+    def reduce_scatter(self, *a, **kw): return self._fwd("reduce_scatter", *a, **kw)
+    def _reduce_scatter_base(self, *a, **kw): return self._fwd("_reduce_scatter_base", *a, **kw)
+    def reduce_scatter_tensor_coalesced(self, *a, **kw): return self._fwd("reduce_scatter_tensor_coalesced", *a, **kw)
+    def alltoall_base(self, *a, **kw): return self._fwd("alltoall_base", *a, **kw)
+    def alltoall(self, *a, **kw): return self._fwd("alltoall", *a, **kw)
+    def send(self, *a, **kw): return self._fwd("send", *a, **kw)
+    def recv(self, *a, **kw): return self._fwd("recv", *a, **kw)
+    def recv_anysource(self, *a, **kw): return self._fwd("recv_anysource", *a, **kw)
+
+    def _start_coalescing(self, *a, **kw): return self._fwd("_start_coalescing", *a, **kw)
+    def _end_coalescing(self, *a, **kw): return self._fwd("_end_coalescing", *a, **kw)
+    def _get_backend_name(self): return self._fwd("_get_backend_name")
+    def _get_backend(self, *a, **kw): return self._fwd("_get_backend", *a, **kw)
+    def _set_default_backend(self, *a, **kw): return self._fwd("_set_default_backend", *a, **kw)
+
+    @property
+    def bound_device_id(self):
+        return self.group.bound_device_id
+
+    @bound_device_id.setter
+    def bound_device_id(self, dev):
+        self.group.bound_device_id = dev
+
+
+def destroy_process_groups():
+    """Destroy all reloadable process groups."""
+    ReloadableProcessGroup.destroy_process_groups()
+
+
+def reload_process_groups():
+    """Reload all reloadable process groups."""
+    ReloadableProcessGroup.reload_process_groups()
+
+
+monkey_patch_torch_dist()
+
+
 def create_group(
     ranks=None,
     timeout=None,
diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py
index 63ee9d1f..b90b744c 100644
--- a/megatron/core/pipeline_parallel/p2p_communication.py
+++ b/megatron/core/pipeline_parallel/p2p_communication.py
@@ -26,22 +26,22 @@ def _batched_p2p_ops(
     ops = []
     if tensor_send_prev is not None:
         send_prev_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group
+            torch.distributed.isend, tensor_send_prev, prev_pipeline_rank,
         )
         ops.append(send_prev_op)
     if tensor_recv_prev is not None:
         recv_prev_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group
+            torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank,
         )
         ops.append(recv_prev_op)
     if tensor_send_next is not None:
         send_next_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_next, next_pipeline_rank, group
+            torch.distributed.isend, tensor_send_next, next_pipeline_rank,
         )
         ops.append(send_next_op)
     if tensor_recv_next is not None:
         recv_next_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group
+            torch.distributed.irecv, tensor_recv_next, next_pipeline_rank,
         )
         ops.append(recv_next_op)
     if len(ops) > 0:
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index 6f557e1f..b295fd35 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -173,6 +173,12 @@ class TransformerConfig(ModelParallelConfig):
     qk_layernorm: bool = False
     """Whether to apply `normalization` type of normalization to the query and key embeddings."""
 
+    post_self_attn_layernorm: bool = False
+    """Whether to use a post self-attention layernorm."""
+
+    post_mlp_layernorm: bool = False
+    """Whether to use a post MLP layernorm."""
+
     test_mode: bool = False
     """Whether to run real-time tests."""
 
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
index 84f22bde..b4807d26 100644
--- a/megatron/core/transformer/transformer_layer.py
+++ b/megatron/core/transformer/transformer_layer.py
@@ -224,6 +224,7 @@ class TransformerLayerSubmodules:
     input_layernorm: Union[ModuleSpec, type] = IdentityOp
     self_attention: Union[ModuleSpec, type] = IdentityOp
     self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
+    post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
 
     pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
     cross_attention: Union[ModuleSpec, type] = IdentityOp
@@ -232,6 +233,7 @@ class TransformerLayerSubmodules:
     pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
     mlp: Union[ModuleSpec, type] = IdentityOp
     mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp
+    post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
 
     # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
     sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
@@ -336,6 +338,14 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
         # [Module 3: BiasDropoutFusion]
         self.self_attn_bda = build_module(submodules.self_attn_bda)
 
+        # Optional layernorm applied to the self-attention output in forward
+        self.post_self_attn_layernorm = build_module(
+            submodules.post_self_attn_layernorm,
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon,
+        )
+
         # [Module 4: Post SelfAttention] Optional Layernorm after self-attn
         self.pre_cross_attn_layernorm = build_module(
             submodules.pre_cross_attn_layernorm,
@@ -399,6 +409,14 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
         # [Module 9: BiasDropoutFusion]
         self.mlp_bda = build_module(submodules.mlp_bda)
 
+        # Optional layernorm applied to the MLP output in forward
+        self.post_mlp_layernorm = build_module(
+            submodules.post_mlp_layernorm,
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon,
+        )
+
         self.recompute_input_layernorm = False
         self.recompute_pre_mlp_layernorm = False
         self.recompute_mlp = False
@@ -535,6 +553,11 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
                 attention_output_with_bias[0]
             )
 
+        # Apply post_self_attn_layernorm to the output only; the bias passes through.
+        attention_output, attention_output_bias = attention_output_with_bias
+        attention_output = self.post_self_attn_layernorm(attention_output)
+        attention_output_with_bias = (attention_output, attention_output_bias)
+
         # TODO: could we move `bias_dropout_add_exec_handler` itself
         # inside the module provided in the `bias_dropout_add_spec` module?
         nvtx_range_push(suffix="self_attn_bda")
@@ -635,6 +658,11 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
         else:
             mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
 
+        # Apply post_mlp_layernorm to the output only; the bias passes through.
+        mlp_output, mlp_output_bias = mlp_output_with_bias
+        mlp_output = self.post_mlp_layernorm(mlp_output)
+        mlp_output_with_bias = (mlp_output, mlp_output_bias)
+
         if self.recompute_pre_mlp_layernorm:
             # discard the output of the pre-mlp layernorm and register the recompute
             # as a gradient hook of mlp_output_with_bias[0]
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 24ba8926..4f039fd4 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1191,6 +1191,9 @@ def core_transformer_config_from_args(args, config_class=None):
     if args.is_hybrid_model:
         kw_args['is_hybrid_model'] = args.is_hybrid_model
 
+    kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm
+    kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm
+
     # handle quantization config
     # NOTE: Kitchen arguments are only added to the namespace when
     # Kitchen library is available.
@@ -1481,6 +1484,10 @@ def _add_network_size_args(parser):
                        action='store_true',
                        help='If set, use original BERT residula connection '
                        'ordering.')
+    group.add_argument('--post-self-attn-layernorm', action='store_true',
+                       help='If set, use post self attention layernorm.')
+    group.add_argument('--post-mlp-layernorm', action='store_true',
+                       help='If set, use post MLP layernorm.')
     group.add_argument('--openai-gelu', action='store_true',
                        help='Use OpenAIs GeLU implementation. This option'
                        'should not be used unless for backward compatibility'