| |
| |
| |
| |
| @@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads |
| from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel |
| from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel |
| from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig |
| + |
| +# Backward compatibility patch for FSDP module reorganization |
| +import sys |
| +import importlib.util |
| + |
| +spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') |
| +if spec: |
| + custom_fsdp = importlib.util.module_from_spec(spec) |
| + spec.loader.exec_module(custom_fsdp) |
| + sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp |
| + if hasattr(custom_fsdp, 'MegatronFSDP'): |
| + custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP |
| |
| |
| |
| |
| @@ -404,6 +404,7 @@ class TELinear(te.pytorch.Linear): |
| ) |
| |
| for param in self.parameters(): |
| + setattr(param, "parallel_mode", parallel_mode) |
| if is_expert: |
| # Reduce the gradient on the expert_data_parallel group for expert linear layers |
| setattr(param, "allreduce", not self.expert_parallel) |
| |
| |
| |
| |
| @@ -80,6 +80,8 @@ def get_gpt_layer_with_transformer_engine_spec( |
| use_te_op_fuser: Optional[bool] = False, |
| use_kitchen: bool = False, |
| use_te_activation_func: bool = False, |
| + post_self_attn_layernorm: bool = False, |
| + post_mlp_layernorm: bool = False, |
| ) -> ModuleSpec: |
| """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). |
| |
| @@ -182,9 +184,11 @@ def get_gpt_layer_with_transformer_engine_spec( |
| ), |
| ), |
| self_attn_bda=get_bias_dropout_add, |
| + post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, |
| pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, |
| mlp=mlp, |
| mlp_bda=get_bias_dropout_add, |
| + post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, |
| sharded_state_dict_keys_map={ |
| "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", |
| "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", |
| |
| |
| |
| |
| @@ -443,7 +443,7 @@ class GPTModel(LanguageModule): |
| if self.share_embeddings_and_output_weights: |
| output_weight = self.shared_embedding_or_output_weight() |
| |
| - if mtp_in_postprocess: |
| + if mtp_in_postprocess and labels is not None: |
| hidden_states = self.mtp( |
| input_ids=input_ids, |
| position_ids=position_ids, |
| |
| |
| |
| |
| @@ -9,6 +9,7 @@ from typing import Callable, List, Optional |
| |
| import numpy as np |
| import torch |
| +import torch.distributed as dist |
| |
| from .utils import GlobalMemoryBuffer, is_torch_min_version |
| |
| @@ -163,6 +164,213 @@ def get_nccl_options(pg_name, nccl_comm_cfgs): |
| return None |
| |
| |
| +old_new_group = None |
| + |
| + |
| +def monkey_patch_torch_dist(): |
| + print("Applying monkey patch to torch.distributed", flush=True) |
| + global old_new_group |
| + if old_new_group is not None: |
| + return |
| + |
| + old_new_group = dist.new_group |
| + |
| + def new_group(*args, **kwargs): |
| + group = old_new_group(*args, **kwargs) |
| + # skip none nccl group. |
| + if ( |
| + len(args) >= 3 and args[2] == "gloo" or |
| + "backend" in kwargs and kwargs["backend"] == "gloo" |
| + ): |
| + return group |
| + |
| + # Get ranks from arguments |
| + if len(args) >= 1 and args[0] is not None: |
| + ranks = args[0] |
| + elif "ranks" in kwargs and kwargs["ranks"] is not None: |
| + ranks = kwargs["ranks"] |
| + else: |
| + # If no ranks specified, use all ranks in world |
| + ranks = list(range(dist.get_world_size())) |
| + |
| + if len(ranks) == 1: |
| + return group |
| + |
| + group = ReloadableProcessGroup(group, ranks) |
| + return group |
| + |
| + dist.new_group = new_group |
| + |
| + def get_new_function(func): |
| + def new_function(*args, **kwargs): |
| + args = ( |
| + arg.group if isinstance(arg, ReloadableProcessGroup) else arg |
| + for arg in args |
| + ) |
| + kwargs = { |
| + k: (v.group if isinstance(v, ReloadableProcessGroup) else v) |
| + for k, v in kwargs.items() |
| + } |
| + return func(*args, **kwargs) |
| + return new_function |
| + |
| + dist.get_rank = get_new_function(dist.get_rank) |
| + dist.get_world_size = get_new_function(dist.get_world_size) |
| + dist.get_backend = get_new_function(dist.get_backend) |
| + dist.get_global_rank = get_new_function(dist.get_global_rank) |
| + dist.get_group_rank = get_new_function(dist.get_group_rank) |
| + dist.get_process_group_ranks = get_new_function(dist.get_process_group_ranks) |
| + |
| + dist.all_reduce = get_new_function(dist.all_reduce) |
| + dist.all_gather = get_new_function(dist.all_gather) |
| + dist.all_gather_into_tensor = get_new_function(dist.all_gather_into_tensor) |
| + dist.all_gather_object = get_new_function(dist.all_gather_object) |
| + dist.all_to_all = get_new_function(dist.all_to_all) |
| + dist.all_to_all_single = get_new_function(dist.all_to_all_single) |
| + dist.broadcast = get_new_function(dist.broadcast) |
| + dist.reduce = get_new_function(dist.reduce) |
| + dist.reduce_scatter = get_new_function(dist.reduce_scatter) |
| + dist.reduce_scatter_tensor = get_new_function(dist.reduce_scatter_tensor) |
| + dist.scatter = get_new_function(dist.scatter) |
| + dist.gather = get_new_function(dist.gather) |
| + dist.barrier = get_new_function(dist.barrier) |
| + dist.send = get_new_function(dist.send) |
| + dist.recv = get_new_function(dist.recv) |
| + dist._coalescing_manager = get_new_function(dist._coalescing_manager) |
| + |
| + # p2p |
| + old_isend = dist.isend |
| + old_irecv = dist.irecv |
| + |
| + dist.isend = get_new_function(dist.isend) |
| + dist.irecv = get_new_function(dist.irecv) |
| + |
| + def get_new_p2pop_function(func): |
| + def new_function(*args, **kwargs): |
| + def convert(arg): |
| + if isinstance(arg, ReloadableProcessGroup): |
| + return arg.group |
| + elif arg == dist.isend: |
| + arg = old_isend |
| + elif arg == dist.irecv: |
| + arg = old_irecv |
| + return arg |
| + |
| + args = (convert(arg) for arg in args) |
| + kwargs = { |
| + k: convert(v) |
| + for k, v in kwargs.items() |
| + } |
| + return func(*args, **kwargs) |
| + return new_function |
| + |
| + dist.P2POp.__new__ = get_new_p2pop_function(dist.P2POp.__new__) |
| + dist.P2POp.__init__ = get_new_p2pop_function(dist.P2POp.__init__) |
| + |
| + |
| + |
| +class ReloadableProcessGroup(torch.distributed.ProcessGroup): |
| + GROUPS = [] |
| + |
| + def __init__(self, group, ranks): |
| + super().__init__( |
| + rank=dist.get_rank(group), |
| + size=dist.get_world_size(group), |
| + ) |
| + #print(f"Creating ReloadableProcessGroup with ranks: {ranks}", flush=True) |
| + self.group = group |
| + self.group_info = { |
| + "ranks": ranks, |
| + } |
| + ReloadableProcessGroup.GROUPS.append(self) |
| + |
| + def __getattr__(self, name): |
| + return getattr(self.group, name) |
| + |
| + @staticmethod |
| + def destroy_process_groups(): |
| + for reloadable_group in ReloadableProcessGroup.GROUPS: |
| + if reloadable_group.group is None: |
| + continue |
| + #print(f"Destroying process group: {reloadable_group.group_info['ranks']}") |
| + dist.destroy_process_group(reloadable_group.group) |
| + del reloadable_group.group |
| + reloadable_group.group = None |
| + |
| + @staticmethod |
| + def reload_process_groups(): |
| + for reloadable_group in ReloadableProcessGroup.GROUPS: |
| + if reloadable_group.group is not None: |
| + continue |
| + #print(f"Reloading process group: {reloadable_group.group_info['ranks']}") |
| + group = old_new_group( |
| + ranks=reloadable_group.group_info["ranks"], |
| + backend="nccl" |
| + ) |
| + reloadable_group.group = group |
| + |
| + def rank(self) -> int: return self.group.rank() |
| + def size(self) -> int: return self.group.size() |
| + def name(self) -> str: return self.group.name() |
| + |
| + def shutdown(self) -> None: |
| + if self.group is not None: |
| + self.group.shutdown() |
| + |
| + def abort(self) -> None: |
| + if self.group is not None: |
| + self.group.abort() |
| + |
| + def _fwd(self, method, *args, **kwargs): |
| + inner = self.group |
| + if inner is None: |
| + raise RuntimeError("ReloadableProcessGroup: inner PG is None, call reload() first.") |
| + return getattr(inner, method)(*args, **kwargs) |
| + |
| + def barrier(self, *a, **kw): return self._fwd("barrier", *a, **kw) |
| + def broadcast(self, *a, **kw): return self._fwd("broadcast", *a, **kw) |
| + def allreduce(self, *a, **kw): return self._fwd("allreduce", *a, **kw) |
| + def allreduce_coalesced(self, *a, **kw): return self._fwd("allreduce_coalesced", *a, **kw) |
| + def reduce(self, *a, **kw): return self._fwd("reduce", *a, **kw) |
| + def allgather(self, *a, **kw): return self._fwd("allgather", *a, **kw) |
| + def _allgather_base(self, *a, **kw): return self._fwd("_allgather_base", *a, **kw) |
| + def allgather_coalesced(self, *a, **kw): return self._fwd("allgather_coalesced", *a, **kw) |
| + def allgather_into_tensor_coalesced(self, *a, **kw): return self._fwd("allgather_into_tensor_coalesced", *a, **kw) |
| + def gather(self, *a, **kw): return self._fwd("gather", *a, **kw) |
| + def scatter(self, *a, **kw): return self._fwd("scatter", *a, **kw) |
| + def reduce_scatter(self, *a, **kw): return self._fwd("reduce_scatter", *a, **kw) |
| + def _reduce_scatter_base(self, *a, **kw): return self._fwd("_reduce_scatter_base", *a, **kw) |
| + def reduce_scatter_tensor_coalesced(self, *a, **kw): return self._fwd("reduce_scatter_tensor_coalesced", *a, **kw) |
| + def alltoall_base(self, *a, **kw): return self._fwd("alltoall_base", *a, **kw) |
| + def alltoall(self, *a, **kw): return self._fwd("alltoall", *a, **kw) |
| + def send(self, *a, **kw): return self._fwd("send", *a, **kw) |
| + def recv(self, *a, **kw): return self._fwd("recv", *a, **kw) |
| + def recv_anysource(self, *a, **kw): return self._fwd("recv_anysource", *a, **kw) |
| + |
| + def _start_coalescing(self, *a, **kw): return self._fwd("_start_coalescing", *a, **kw) |
| + def _end_coalescing(self, *a, **kw): return self._fwd("_end_coalescing", *a, **kw) |
| + def _get_backend_name(self): return self._fwd("_get_backend_name") |
| + def _get_backend(self, *a, **kw): return self._fwd("_get_backend", *a, **kw) |
| + def _set_default_backend(self, *a, **kw): return self._fwd("_set_default_backend", *a, **kw) |
| + @property |
| + def bound_device_id(self): return self.group.bound_device_id |
| + @bound_device_id.setter |
| + def bound_device_id(self, dev): self.group.bound_device_id = dev |
| + |
| + |
| +def destroy_process_groups(): |
| + """Destroy all reloadable process groups.""" |
| + ReloadableProcessGroup.destroy_process_groups() |
| + |
| + |
| +def reload_process_groups(): |
| + """Reload all reloadable process groups.""" |
| + ReloadableProcessGroup.reload_process_groups() |
| + |
| + |
| +monkey_patch_torch_dist() |
| + |
| + |
| def create_group( |
| ranks=None, |
| timeout=None, |
| |
| |
| |
| |
| @@ -26,22 +26,22 @@ def _batched_p2p_ops( |
| ops = [] |
| if tensor_send_prev is not None: |
| send_prev_op = torch.distributed.P2POp( |
| - torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group |
| + torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, |
| ) |
| ops.append(send_prev_op) |
| if tensor_recv_prev is not None: |
| recv_prev_op = torch.distributed.P2POp( |
| - torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group |
| + torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, |
| ) |
| ops.append(recv_prev_op) |
| if tensor_send_next is not None: |
| send_next_op = torch.distributed.P2POp( |
| - torch.distributed.isend, tensor_send_next, next_pipeline_rank, group |
| + torch.distributed.isend, tensor_send_next, next_pipeline_rank, |
| ) |
| ops.append(send_next_op) |
| if tensor_recv_next is not None: |
| recv_next_op = torch.distributed.P2POp( |
| - torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group |
| + torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, |
| ) |
| ops.append(recv_next_op) |
| if len(ops) > 0: |
| |
| |
| |
| |
| @@ -173,6 +173,9 @@ class TransformerConfig(ModelParallelConfig): |
| qk_layernorm: bool = False |
| """Whether to apply `normalization` type of normalization to the query and key embeddings.""" |
| |
| + post_self_attn_layernorm: bool = False |
| + post_mlp_layernorm: bool = False |
| + |
| test_mode: bool = False |
| """Whether to run real-time tests.""" |
| |
| |
| |
| |
| |
| @@ -224,6 +224,7 @@ class TransformerLayerSubmodules: |
| input_layernorm: Union[ModuleSpec, type] = IdentityOp |
| self_attention: Union[ModuleSpec, type] = IdentityOp |
| self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp |
| + post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp |
| |
| pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp |
| cross_attention: Union[ModuleSpec, type] = IdentityOp |
| @@ -232,6 +233,7 @@ class TransformerLayerSubmodules: |
| pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp |
| mlp: Union[ModuleSpec, type] = IdentityOp |
| mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp |
| + post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp |
| |
| # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method |
| sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) |
| @@ -336,6 +338,14 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): |
| # [Module 3: BiasDropoutFusion] |
| self.self_attn_bda = build_module(submodules.self_attn_bda) |
| |
| + self.post_self_attn_layernorm = build_module( |
| + submodules.post_self_attn_layernorm, |
| + config=self.config, |
| + hidden_size=self.config.hidden_size, |
| + eps=self.config.layernorm_epsilon, |
| + ) |
| + |
| + |
| # [Module 4: Post SelfAttention] Optional Layernorm after self-attn |
| self.pre_cross_attn_layernorm = build_module( |
| submodules.pre_cross_attn_layernorm, |
| @@ -399,6 +409,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): |
| # [Module 9: BiasDropoutFusion] |
| self.mlp_bda = build_module(submodules.mlp_bda) |
| |
| + self.post_mlp_layernorm = build_module( |
| + submodules.post_mlp_layernorm, |
| + config=self.config, |
| + hidden_size=self.config.hidden_size, |
| + eps=self.config.layernorm_epsilon |
| + ) |
| + |
| self.recompute_input_layernorm = False |
| self.recompute_pre_mlp_layernorm = False |
| self.recompute_mlp = False |
| @@ -535,6 +552,11 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): |
| attention_output_with_bias[0] |
| ) |
| |
| + attention_output, attention_output_bias = attention_output_with_bias |
| + attention_output = self.post_self_attn_layernorm(attention_output) |
| + attention_output_with_bias = (attention_output, attention_output_bias) |
| + |
| + |
| # TODO: could we move `bias_dropout_add_exec_handler` itself |
| # inside the module provided in the `bias_dropout_add_spec` module? |
| nvtx_range_push(suffix="self_attn_bda") |
| @@ -635,6 +657,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): |
| else: |
| mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) |
| |
| + mlp_output, mlp_output_bias = mlp_output_with_bias |
| + mlp_output = self.post_mlp_layernorm(mlp_output) |
| + mlp_output_with_bias = (mlp_output, mlp_output_bias) |
| + |
| if self.recompute_pre_mlp_layernorm: |
| # discard the output of the pre-mlp layernorm and register the recompute |
| # as a gradient hook of mlp_output_with_bias[0] |
| |
| |
| |
| |
| @@ -1191,6 +1191,9 @@ def core_transformer_config_from_args(args, config_class=None): |
| if args.is_hybrid_model: |
| kw_args['is_hybrid_model'] = args.is_hybrid_model |
| |
| + kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm |
| + kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm |
| + |
| # handle quantization config |
| # NOTE: Kitchen arguments are only added to the namespace when |
| # Kitchen library is available. |
| @@ -1481,6 +1484,10 @@ def _add_network_size_args(parser): |
| action='store_true', |
| help='If set, use original BERT residula connection ' |
| 'ordering.') |
| + group.add_argument('--post-self-attn-layernorm', action='store_true', |
| + help='If set, use post self attention layernorm.') |
| + group.add_argument('--post-mlp-layernorm', action='store_true', |
| + help='If set, use post MLP layernorm.') |
| group.add_argument('--openai-gelu', action='store_true', |
| help='Use OpenAIs GeLU implementation. This option' |
| 'should not be used unless for backward compatibility' |
|
|