JustinTX's picture
Add files using upload-large-folder tool
d7b3a74 verified
diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py
index fe26e8b4..4451f277 100644
--- a/megatron/core/distributed/__init__.py
+++ b/megatron/core/distributed/__init__.py
@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads
from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig
+
+# Backward-compatibility shim for the FSDP module reorganization: register the
+# module found at ``megatron.core.distributed.fsdp.src.megatron_fsdp`` under the
+# import path ``megatron.core.distributed.custom_fsdp``.
+import sys
+import importlib.util
+
+try:
+    spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp')
+except ModuleNotFoundError:
+    # find_spec() imports the parent packages of a dotted name and raises when
+    # one of them is missing; treat that the same as "module not available" so
+    # importing this package does not fail because of the optional shim.
+    spec = None
+if spec is not None and spec.loader is not None:
+    # NOTE(review): this executes a fresh module object built from the spec
+    # rather than reusing an instance that may already live in sys.modules --
+    # confirm that a second copy of the module is intended.
+    custom_fsdp = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(custom_fsdp)
+    sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp
+    if hasattr(custom_fsdp, 'MegatronFSDP'):
+        # Also expose the class as ``FullyShardedDataParallel`` (presumably the
+        # name used before the reorganization -- verify against old callers).
+        custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP
diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index 99c3edc0..26ea5cb4 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -404,6 +404,7 @@ class TELinear(te.pytorch.Linear):
)
for param in self.parameters():
+ setattr(param, "parallel_mode", parallel_mode)
if is_expert:
# Reduce the gradient on the expert_data_parallel group for expert linear layers
setattr(param, "allreduce", not self.expert_parallel)
diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py
index 002edb92..f7273488 100755
--- a/megatron/core/models/gpt/gpt_layer_specs.py
+++ b/megatron/core/models/gpt/gpt_layer_specs.py
@@ -80,6 +80,8 @@ def get_gpt_layer_with_transformer_engine_spec(
use_te_op_fuser: Optional[bool] = False,
use_kitchen: bool = False,
use_te_activation_func: bool = False,
+ post_self_attn_layernorm: bool = False,
+ post_mlp_layernorm: bool = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
@@ -182,9 +184,11 @@ def get_gpt_layer_with_transformer_engine_spec(
),
),
self_attn_bda=get_bias_dropout_add,
+ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp,
pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
+ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp,
sharded_state_dict_keys_map={
"mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
"mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
index df9adc3e..2f4f544a 100644
--- a/megatron/core/models/gpt/gpt_model.py
+++ b/megatron/core/models/gpt/gpt_model.py
@@ -443,7 +443,7 @@ class GPTModel(LanguageModule):
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
- if mtp_in_postprocess:
+ if mtp_in_postprocess and labels is not None:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index 57332ac3..f3abd642 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -9,6 +9,7 @@ from typing import Callable, List, Optional
import numpy as np
import torch
+import torch.distributed as dist
from .utils import GlobalMemoryBuffer, is_torch_min_version
@@ -163,6 +164,213 @@ def get_nccl_options(pg_name, nccl_comm_cfgs):
return None
+# Original ``torch.distributed.new_group``. It is captured by the first call to
+# ``monkey_patch_torch_dist`` and doubles as the "already patched" marker.
+old_new_group = None
+
+
+def monkey_patch_torch_dist():
+    """Patch ``torch.distributed`` so new groups are ``ReloadableProcessGroup`` wrappers.
+
+    The patch does three things:
+
+    * replaces ``dist.new_group`` with a version that wraps the created group
+      in a ``ReloadableProcessGroup``, except for groups created with an
+      explicit ``"gloo"`` backend and groups with a single rank;
+    * wraps a fixed list of ``dist`` functions so that any
+      ``ReloadableProcessGroup`` passed as an argument is replaced by its
+      inner group before the original function runs;
+    * wraps ``dist.P2POp.__new__`` / ``__init__`` the same way, additionally
+      mapping the patched ``dist.isend`` / ``dist.irecv`` back to the
+      original functions.
+
+    The function is idempotent: a second call returns without doing anything.
+
+    NOTE(review): only attribute access through the ``torch.distributed``
+    module sees the wrappers; names bound earlier with
+    ``from torch.distributed import ...`` keep the unpatched functions.
+    """
+    global old_new_group
+    if old_new_group is not None:
+        # Already patched; wrapping again would stack wrappers on wrappers.
+        return
+
+    print("Applying monkey patch to torch.distributed", flush=True)
+    old_new_group = dist.new_group
+
+    def new_group(*args, **kwargs):
+        group = old_new_group(*args, **kwargs)
+
+        # Skip non-NCCL groups. Only an explicit "gloo" backend (third
+        # positional argument or ``backend=`` keyword) is recognised.
+        # NOTE(review): a default (None) or composite backend string is
+        # treated as NCCL here -- confirm that matches every caller.
+        backend = args[2] if len(args) >= 3 else kwargs.get("backend")
+        if backend == "gloo":
+            return group
+
+        # Ranks come from the first positional argument or ``ranks=``; when
+        # neither is given the group spans the whole world.
+        ranks = args[0] if args else kwargs.get("ranks")
+        if ranks is None:
+            ranks = list(range(dist.get_world_size()))
+
+        # Single-rank groups are returned unwrapped.
+        if len(ranks) == 1:
+            return group
+
+        return ReloadableProcessGroup(group, ranks)
+
+    dist.new_group = new_group
+
+    def unwrap(value):
+        # Replace a wrapper by the real process group it currently holds.
+        return value.group if isinstance(value, ReloadableProcessGroup) else value
+
+    def get_new_function(func):
+        def new_function(*args, **kwargs):
+            return func(
+                *(unwrap(arg) for arg in args),
+                **{key: unwrap(value) for key, value in kwargs.items()},
+            )
+        return new_function
+
+    for name in (
+        "get_rank",
+        "get_world_size",
+        "get_backend",
+        "get_global_rank",
+        "get_group_rank",
+        "get_process_group_ranks",
+        "all_reduce",
+        "all_gather",
+        "all_gather_into_tensor",
+        "all_gather_object",
+        "all_to_all",
+        "all_to_all_single",
+        "broadcast",
+        "reduce",
+        "reduce_scatter",
+        "reduce_scatter_tensor",
+        "scatter",
+        "gather",
+        "barrier",
+        "send",
+        "recv",
+        "_coalescing_manager",
+    ):
+        setattr(dist, name, get_new_function(getattr(dist, name)))
+
+    # p2p: keep the originals so P2POp can be handed the unpatched functions.
+    old_isend = dist.isend
+    old_irecv = dist.irecv
+
+    dist.isend = get_new_function(dist.isend)
+    dist.irecv = get_new_function(dist.irecv)
+
+    def get_new_p2pop_function(func):
+        def new_function(*args, **kwargs):
+            def convert(arg):
+                if isinstance(arg, ReloadableProcessGroup):
+                    return arg.group
+                # Identity, not ``==``: arguments include tensors, and the
+                # question is only "is this the patched function object?".
+                # Presumably torch validates the op against the original
+                # isend/irecv, hence the mapping back -- TODO confirm.
+                if arg is dist.isend:
+                    return old_isend
+                if arg is dist.irecv:
+                    return old_irecv
+                return arg
+
+            return func(
+                *(convert(arg) for arg in args),
+                **{key: convert(value) for key, value in kwargs.items()},
+            )
+        return new_function
+
+    dist.P2POp.__new__ = get_new_p2pop_function(dist.P2POp.__new__)
+    dist.P2POp.__init__ = get_new_p2pop_function(dist.P2POp.__init__)
+
+
+
+class ReloadableProcessGroup(torch.distributed.ProcessGroup):
+    """``ProcessGroup`` proxy whose inner group can be destroyed and re-created.
+
+    The wrapper holds the real process group in ``self.group`` and forwards
+    collective / p2p calls to it. ``destroy_process_groups`` destroys the
+    inner group of every wrapper and sets it to ``None``;
+    ``reload_process_groups`` creates a new inner group from the recorded
+    ranks. While the inner group is ``None`` the forwarded calls raise
+    ``RuntimeError``.
+
+    Attributes:
+        group: the wrapped process group, or ``None`` after destruction.
+        group_info: dict holding the ``"ranks"`` the group was created with.
+    """
+
+    # Registry of every wrapper ever created, in creation order. Entries are
+    # never removed in this class, so wrappers stay alive for the process
+    # lifetime.
+    GROUPS = []
+
+    def __init__(self, group, ranks):
+        super().__init__(
+            rank=dist.get_rank(group),
+            size=dist.get_world_size(group),
+        )
+        self.group = group
+        self.group_info = {
+            "ranks": ranks,
+        }
+        ReloadableProcessGroup.GROUPS.append(self)
+
+    def __getattr__(self, name):
+        # Only reached when normal attribute lookup fails. If "group" itself
+        # is the missing attribute (e.g. before __init__ has assigned it),
+        # reading self.group below would re-enter __getattr__ without bound.
+        if name == "group":
+            raise AttributeError(name)
+        inner = self.group
+        if inner is None:
+            raise AttributeError(
+                f"ReloadableProcessGroup has no attribute {name!r}: inner PG is None"
+            )
+        return getattr(inner, name)
+
+    @staticmethod
+    def destroy_process_groups():
+        """Destroy the inner group of every registered wrapper that still has one."""
+        for reloadable_group in ReloadableProcessGroup.GROUPS:
+            if reloadable_group.group is None:
+                continue
+            dist.destroy_process_group(reloadable_group.group)
+            reloadable_group.group = None
+
+    @staticmethod
+    def reload_process_groups():
+        """Create a new inner group for every registered wrapper that has none.
+
+        NOTE(review): groups are always re-created with ``backend="nccl"`` and
+        only the recorded ranks; no other creation option (timeout, backend
+        options) is recorded in ``group_info`` -- confirm that is acceptable.
+        """
+        for reloadable_group in ReloadableProcessGroup.GROUPS:
+            if reloadable_group.group is not None:
+                continue
+            group = old_new_group(
+                ranks=reloadable_group.group_info["ranks"],
+                backend="nccl"
+            )
+            reloadable_group.group = group
+
+    def rank(self) -> int: return self._fwd("rank")
+    def size(self) -> int: return self._fwd("size")
+    def name(self) -> str: return self._fwd("name")
+
+    def shutdown(self) -> None:
+        # No-op when the inner group has already been destroyed.
+        if self.group is not None:
+            self.group.shutdown()
+
+    def abort(self) -> None:
+        # No-op when the inner group has already been destroyed.
+        if self.group is not None:
+            self.group.abort()
+
+    def _fwd(self, method, *args, **kwargs):
+        """Call ``method`` on the inner group; raise if it has been destroyed."""
+        inner = self.group
+        if inner is None:
+            raise RuntimeError(
+                "ReloadableProcessGroup: inner PG is None, "
+                "call reload_process_groups() first."
+            )
+        return getattr(inner, method)(*args, **kwargs)
+
+    def barrier(self, *a, **kw): return self._fwd("barrier", *a, **kw)
+    def broadcast(self, *a, **kw): return self._fwd("broadcast", *a, **kw)
+    def allreduce(self, *a, **kw): return self._fwd("allreduce", *a, **kw)
+    def allreduce_coalesced(self, *a, **kw): return self._fwd("allreduce_coalesced", *a, **kw)
+    def reduce(self, *a, **kw): return self._fwd("reduce", *a, **kw)
+    def allgather(self, *a, **kw): return self._fwd("allgather", *a, **kw)
+    def _allgather_base(self, *a, **kw): return self._fwd("_allgather_base", *a, **kw)
+    def allgather_coalesced(self, *a, **kw): return self._fwd("allgather_coalesced", *a, **kw)
+    def allgather_into_tensor_coalesced(self, *a, **kw): return self._fwd("allgather_into_tensor_coalesced", *a, **kw)
+    def gather(self, *a, **kw): return self._fwd("gather", *a, **kw)
+    def scatter(self, *a, **kw): return self._fwd("scatter", *a, **kw)
+    def reduce_scatter(self, *a, **kw): return self._fwd("reduce_scatter", *a, **kw)
+    def _reduce_scatter_base(self, *a, **kw): return self._fwd("_reduce_scatter_base", *a, **kw)
+    def reduce_scatter_tensor_coalesced(self, *a, **kw): return self._fwd("reduce_scatter_tensor_coalesced", *a, **kw)
+    def alltoall_base(self, *a, **kw): return self._fwd("alltoall_base", *a, **kw)
+    def alltoall(self, *a, **kw): return self._fwd("alltoall", *a, **kw)
+    def send(self, *a, **kw): return self._fwd("send", *a, **kw)
+    def recv(self, *a, **kw): return self._fwd("recv", *a, **kw)
+    def recv_anysource(self, *a, **kw): return self._fwd("recv_anysource", *a, **kw)
+
+    def _start_coalescing(self, *a, **kw): return self._fwd("_start_coalescing", *a, **kw)
+    def _end_coalescing(self, *a, **kw): return self._fwd("_end_coalescing", *a, **kw)
+    def _get_backend_name(self): return self._fwd("_get_backend_name")
+    def _get_backend(self, *a, **kw): return self._fwd("_get_backend", *a, **kw)
+    def _set_default_backend(self, *a, **kw): return self._fwd("_set_default_backend", *a, **kw)
+    @property
+    def bound_device_id(self): return self.group.bound_device_id
+    @bound_device_id.setter
+    def bound_device_id(self, dev): self.group.bound_device_id = dev
+
+
+def destroy_process_groups():
+    """Tear down the inner group of every registered ``ReloadableProcessGroup``.
+
+    Module-level convenience entry point; delegates to the static method of
+    the same name on ``ReloadableProcessGroup``.
+    """
+    teardown = ReloadableProcessGroup.destroy_process_groups
+    teardown()
+
+
+def reload_process_groups():
+    """Re-create the inner group of every registered ``ReloadableProcessGroup``.
+
+    Module-level convenience entry point; delegates to the static method of
+    the same name on ``ReloadableProcessGroup``.
+    """
+    rebuild = ReloadableProcessGroup.reload_process_groups
+    rebuild()
+
+
+# Install the torch.distributed patch as an import-time side effect of this module.
+monkey_patch_torch_dist()
+
+
def create_group(
ranks=None,
timeout=None,
diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py
index 63ee9d1f..b90b744c 100644
--- a/megatron/core/pipeline_parallel/p2p_communication.py
+++ b/megatron/core/pipeline_parallel/p2p_communication.py
@@ -26,22 +26,22 @@ def _batched_p2p_ops(
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group
+ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank,
)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group
+ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank,
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group
+ torch.distributed.isend, tensor_send_next, next_pipeline_rank,
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group
+ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank,
)
ops.append(recv_next_op)
if len(ops) > 0:
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index 6f557e1f..b295fd35 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -173,6 +173,9 @@ class TransformerConfig(ModelParallelConfig):
qk_layernorm: bool = False
"""Whether to apply `normalization` type of normalization to the query and key embeddings."""
+ post_self_attn_layernorm: bool = False
+ post_mlp_layernorm: bool = False
+
test_mode: bool = False
"""Whether to run real-time tests."""
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
index 84f22bde..b4807d26 100644
--- a/megatron/core/transformer/transformer_layer.py
+++ b/megatron/core/transformer/transformer_layer.py
@@ -224,6 +224,7 @@ class TransformerLayerSubmodules:
input_layernorm: Union[ModuleSpec, type] = IdentityOp
self_attention: Union[ModuleSpec, type] = IdentityOp
self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
+ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
cross_attention: Union[ModuleSpec, type] = IdentityOp
@@ -232,6 +233,7 @@ class TransformerLayerSubmodules:
pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
mlp: Union[ModuleSpec, type] = IdentityOp
mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp
+ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
# Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
@@ -336,6 +338,14 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
# [Module 3: BiasDropoutFusion]
self.self_attn_bda = build_module(submodules.self_attn_bda)
+ self.post_self_attn_layernorm = build_module(
+ submodules.post_self_attn_layernorm,
+ config=self.config,
+ hidden_size=self.config.hidden_size,
+ eps=self.config.layernorm_epsilon,
+ )
+
+
# [Module 4: Post SelfAttention] Optional Layernorm after self-attn
self.pre_cross_attn_layernorm = build_module(
submodules.pre_cross_attn_layernorm,
@@ -399,6 +409,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
# [Module 9: BiasDropoutFusion]
self.mlp_bda = build_module(submodules.mlp_bda)
+ self.post_mlp_layernorm = build_module(
+ submodules.post_mlp_layernorm,
+ config=self.config,
+ hidden_size=self.config.hidden_size,
+ eps=self.config.layernorm_epsilon
+ )
+
self.recompute_input_layernorm = False
self.recompute_pre_mlp_layernorm = False
self.recompute_mlp = False
@@ -535,6 +552,11 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
attention_output_with_bias[0]
)
+ attention_output, attention_output_bias = attention_output_with_bias
+ attention_output = self.post_self_attn_layernorm(attention_output)
+ attention_output_with_bias = (attention_output, attention_output_bias)
+
+
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
nvtx_range_push(suffix="self_attn_bda")
@@ -635,6 +657,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
else:
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
+ mlp_output, mlp_output_bias = mlp_output_with_bias
+ mlp_output = self.post_mlp_layernorm(mlp_output)
+ mlp_output_with_bias = (mlp_output, mlp_output_bias)
+
if self.recompute_pre_mlp_layernorm:
# discard the output of the pre-mlp layernorm and register the recompute
# as a gradient hook of mlp_output_with_bias[0]
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 24ba8926..4f039fd4 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1191,6 +1191,9 @@ def core_transformer_config_from_args(args, config_class=None):
if args.is_hybrid_model:
kw_args['is_hybrid_model'] = args.is_hybrid_model
+ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm
+ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm
+
# handle quantization config
# NOTE: Kitchen arguments are only added to the namespace when
# Kitchen library is available.
@@ -1481,6 +1484,10 @@ def _add_network_size_args(parser):
action='store_true',
help='If set, use original BERT residula connection '
'ordering.')
+ group.add_argument('--post-self-attn-layernorm', action='store_true',
+ help='If set, use post self attention layernorm.')
+ group.add_argument('--post-mlp-layernorm', action='store_true',
+ help='If set, use post MLP layernorm.')
group.add_argument('--openai-gelu', action='store_true',
help='Use OpenAIs GeLU implementation. This option'
'should not be used unless for backward compatibility'