| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | from collections import defaultdict |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.distributed import DeviceMesh |
| | from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard |
| | from torch.distributed._composable.replicate import replicate |
| | from torch.distributed._tensor import Replicate, Shard |
| | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper |
| | from torch.distributed.tensor.parallel import ( |
| | ColwiseParallel, |
| | PrepareModuleInput, |
| | PrepareModuleOutput, |
| | RowwiseParallel, |
| | SequenceParallel, |
| | parallelize_module |
| | ) |
| |
|
| | from fla.modules.fused_linear_cross_entropy import LinearLossParallel |
| | from fla.modules.mlp import SwiGLULinearParallel |
| | from fla.modules.parallel import PrepareModuleWeight |
| | from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig |
| | from torchtitan.distributed.parallel_dims import ParallelDims |
| | from torchtitan.tools.logging import logger |
| |
|
| |
|
def parallelize_fla(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    The transforms are applied in a fixed order — TP first, then activation
    checkpointing (wraps blocks), then per-block torch.compile, and finally
    FSDP/HSDP or DDP as the outermost wrapper — mirroring how the wrappers
    must nest.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """
    if parallel_dims.tp_enabled:
        # Async TP is built on compiled collectives, so compile is mandatory.
        if (
            job_config.experimental.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")
        enable_float8_linear = "float8" in job_config.model.converters
        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8=enable_float8_linear,
            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # Compile each block before the data-parallel wrappers are applied.
    if job_config.training.compile:
        apply_compile(model)

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # FSDP (or HSDP), possibly combined with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            # HSDP: replicate across one mesh dim, shard across the other.
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        # Plain DDP path: only reachable when sharding/CP are off, and DDP
        # itself only supports a 1-D mesh.
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
        )
|
| |
|
class TPPlan:
    """
    Base tensor-parallel plan: maps module names to parallel styles.

    Subclasses implement `attn_plan` for their token-mixing layer; the base
    class provides the model-level plan (embeddings/norm/lm_head) and the
    MLP plan. Styles are chosen per-instance so that float8-aware variants
    from torchao can be swapped in when available.
    """

    def __init__(
        self,
        model=None,
        loss_parallel=False,
        enable_float8=False,
    ):
        self.model = model
        self.loss_parallel = loss_parallel
        self.enable_float8 = enable_float8
        # HF-style wrappers expose the inner model under `base_model_prefix`
        # (defaults to "model"); used to address submodules in `model_plan`.
        self.base_model_prefix = getattr(model, "base_model_prefix", "model")

        # torchao is an optional dependency; fall back to the plain parallel
        # styles when the float8-aware variants cannot be imported.
        try:
            from torchao.float8.float8_tensor_parallel import (
                Float8ColwiseParallel,
                Float8RowwiseParallel,
                PrepareFloat8ModuleInput
            )
        except ImportError:
            Float8ColwiseParallel = None
            Float8RowwiseParallel = None
            PrepareFloat8ModuleInput = None
        if self.enable_float8 and Float8ColwiseParallel is not None:
            self.rowwise_parallel = Float8RowwiseParallel
            self.colwise_parallel = Float8ColwiseParallel
            self.prepare_module_input = PrepareFloat8ModuleInput
            self.prepare_module_output = PrepareModuleOutput
        else:
            self.rowwise_parallel = RowwiseParallel
            self.colwise_parallel = ColwiseParallel
            self.prepare_module_input = PrepareModuleInput
            self.prepare_module_output = PrepareModuleOutput

    @property
    def model_plan(self):
        """Plan for the top-level modules: embeddings, final norm, lm_head."""
        plans = {
            # Replicated token ids in, sequence-sharded (Shard(1)) embeddings out.
            f"{self.base_model_prefix}.embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            f"{self.base_model_prefix}.norm": SequenceParallel(),
        }
        if self.loss_parallel:
            plans.update(
                {
                    # NOTE(review): inside this branch `self.loss_parallel` is
                    # always True, so the ternary below always yields Shard(-1)
                    # and `use_local_output` is always False.
                    "lm_head": ColwiseParallel(
                        input_layouts=Shard(1),
                        output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
                        use_local_output=not self.loss_parallel,
                    ),
                }
            )
        else:
            plans.update(
                {
                    # Without loss parallelism, keep lm_head weights replicated
                    # and shard the loss computation instead.
                    "lm_head": PrepareModuleWeight(layouts=Replicate()),
                    "criterion": LinearLossParallel(),
                }
            )
        return plans

    @property
    def layer_plan(self):
        """Per-block plan: sequence-parallel norms plus attn and MLP plans."""
        return {
            "attn_norm": SequenceParallel(),
            **self.attn_plan,
            "mlp_norm": SequenceParallel(),
            **self.mlp_plan,
        }

    @property
    def attn_plan(self):
        """Token-mixing plan; must be provided by a model-specific subclass."""
        raise NotImplementedError(
            f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
        )

    @property
    def mlp_plan(self):
        """SwiGLU MLP plan: gather input, colwise up/gate, rowwise down."""
        return {
            "mlp": self.prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "mlp.gate_proj": self.colwise_parallel(),
            "mlp.up_proj": self.colwise_parallel(),
            "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
            "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
        }
| |
|
| |
|
class TransformerTPPlan(TPPlan):
    """TP plan for softmax-attention token mixers (q/k/v/o projections)."""

    @property
    def attn_plan(self):
        # Gather the sequence-sharded hidden states before the projections,
        # then re-shard along the sequence dim on the way out of o_proj.
        plan = {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
        }
        for proj in ("q_proj", "k_proj", "v_proj"):
            plan[f"attn.{proj}"] = self.colwise_parallel()
        plan["attn.o_proj"] = self.rowwise_parallel(output_layouts=Shard(1))
        return plan
| |
|
| |
|
class GLATPPlan(TPPlan):
    """TP plan for GLA token mixers: q/k/v/g projections, low-rank gk, g_norm."""

    @property
    def attn_plan(self):
        # Gather sequence-sharded hidden states at the attn boundary.
        plan = {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
        }
        for proj in ("q_proj", "k_proj", "v_proj", "g_proj"):
            plan[f"attn.{proj}"] = self.colwise_parallel()
        # gk is a two-stage (low-rank) projection: replicated first linear,
        # column-sharded second linear.
        plan["attn.gk_proj.0"] = PrepareModuleWeight(layouts=Replicate())
        plan["attn.gk_proj.1"] = self.colwise_parallel()
        # g_norm operates on the head/feature dim, hence sequence_dim=-1.
        plan["attn.g_norm"] = SequenceParallel(sequence_dim=-1)
        plan["attn.o_proj"] = self.rowwise_parallel(output_layouts=Shard(1))
        return plan
| |
|
| |
|
# Maps `model.config.model_type` to the TPPlan subclass used by `apply_tp`.
TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
| |
|
| |
|
def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism.

    Args:
        model: The model to parallelize; `model.config.model_type` selects
            the TP plan from `TP_PLAN_MAP`.
        tp_mesh: Device mesh for the TP dimension.
        loss_parallel: Keep the lm_head output sharded for a parallel loss.
        enable_float8: Use torchao float8 parallel styles when available.
        enable_async_tp: Enable micro-pipelined (async) TP collectives;
            requires torch.compile (enforced by the caller).
    """
    tp_plan = TP_PLAN_MAP[model.config.model_type](
        model, loss_parallel=loss_parallel, enable_float8=enable_float8
    )
    # Top-level modules first: embeddings, final norm, lm_head/criterion.
    parallelize_module(model, tp_mesh, tp_plan.model_plan)

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for tensor parallelism")
    else:
        # The block index is not needed, so iterate the blocks directly
        # (the previous `for _, block in enumerate(blocks)` discarded it).
        for block in blocks:
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=tp_plan.layer_plan,
            )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )
| |
|
| |
|
| | |
# Ops whose outputs are saved (not recomputed) under op-level selective
# activation checkpointing; see `_custom_policy` in `_apply_ac_to_block`,
# which additionally saves only every other `aten.mm`.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    # NOTE(review): aten.max is saved too — presumably so an amax needed for
    # low-precision scaling is not recomputed; confirm against the float8 path.
    torch.ops.aten.max.default,
}
| |
|
| |
|
| | def _apply_ac_to_block(module: nn.Module, ac_config): |
| | valid_ac_modes = ("full", "selective") |
| | if ac_config.mode not in valid_ac_modes: |
| | raise ValueError( |
| | f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}" |
| | ) |
| |
|
| | if ac_config.mode == "full": |
| | return ptd_checkpoint_wrapper(module, preserve_rng_state=False) |
| |
|
| | assert ac_config.mode == "selective", f"{ac_config.mode}" |
| | use_op_sac = ac_config.selective_ac_option == "op" |
| | use_layer_sac = ac_config.selective_ac_option.isdigit() |
| | if not use_op_sac and not use_layer_sac: |
| | raise ValueError( |
| | f"Invalid selective AC option: {ac_config.selective_ac_option}. " |
| | f"Valid options: 'op' or a positive int representing layer frequency" |
| | ) |
| | if use_op_sac: |
| | from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts |
| |
|
| | def _get_custom_policy(meta): |
| | def _custom_policy(ctx, func, *args, **kwargs): |
| | mode = "recompute" if ctx.is_recompute else "forward" |
| | mm_count_key = f"{mode}_mm_count" |
| | if func == torch.ops.aten.mm.default: |
| | meta[mm_count_key] += 1 |
| | |
| | to_save = func in _save_list and not ( |
| | func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 |
| | ) |
| | return ( |
| | CheckpointPolicy.MUST_SAVE |
| | if to_save |
| | else CheckpointPolicy.PREFER_RECOMPUTE |
| | ) |
| |
|
| | return _custom_policy |
| |
|
| | def selective_checkpointing_context_fn(): |
| | meta = defaultdict(int) |
| | return create_selective_checkpoint_contexts(_get_custom_policy(meta)) |
| |
|
| | return ptd_checkpoint_wrapper( |
| | module, |
| | context_fn=selective_checkpointing_context_fn, |
| | preserve_rng_state=False, |
| | ) |
| | elif use_layer_sac: |
| | |
| | ac_freq = int(ac_config.selective_ac_option) |
| | ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) |
| | ptd_checkpoint_wrapper._count += 1 |
| | if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: |
| | return ptd_checkpoint_wrapper(module, preserve_rng_state=False) |
| | else: |
| | return module |
| |
|
| |
|
def apply_ac(model: nn.Module, ac_config):
    """Apply activation checkpointing to every transformer block."""
    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for activation checkpointing")
        return

    # Re-register each child under its original name so the wrapped module
    # replaces the plain one in place.
    for name, child in blocks.named_children():
        blocks.register_module(name, _apply_ac_to_block(child, ac_config))

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
| |
|
| |
|
def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each block, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after
    applying DP).

    Besides the blocks, the embedding, final norm, and lm_head modules are
    compiled individually (with fullgraph=True) and re-registered in place.
    """
    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for torch.compile")
    else:
        for layer_id, block in blocks.named_children():
            block = torch.compile(block)
            blocks.register_module(layer_id, block)
        logger.info("Compiling each block with torch.compile")

    real_model = get_model(model)

    logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
    # embeddings and norm live on the inner model; lm_head on the outer
    # wrapper (see `get_components_name` docstring).
    embeddings_key = get_components_name(real_model, "tok_embeddings")
    if embeddings_key is not None:
        embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
        real_model.register_module(embeddings_key, embeddings)

    norm_key = get_components_name(real_model, "norm")
    if norm_key is not None:
        norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
        real_model.register_module(norm_key, norm)

    lm_head_key = get_components_name(model, "lm_head")
    if lm_head_key is not None:
        lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
        model.register_module(lm_head_key, lm_head)

    # NOTE: this function previously ended with `model = torch.compile(model)`,
    # which only rebound the local name — the compiled wrapper was discarded
    # and callers never saw it. Removed as dead code; the per-submodule
    # compiles above are the ones that take effect.
| |
|
| |
|
def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional):
            The policy to use for resharding after forward pass. Defaults to "default".
            Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.

    Raises:
        ValueError: If `reshard_after_forward_policy` is none of the above.
    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for FSDP")
    else:
        total_blocks = len(blocks)
        for layer_id, block in enumerate(blocks):
            if reshard_after_forward_policy == "always":
                reshard_after_forward = True
            elif reshard_after_forward_policy == "never":
                reshard_after_forward = False
            elif reshard_after_forward_policy == "default":
                if pp_enabled:
                    # NOTE(review): with PP the default keeps params gathered
                    # after forward — presumably to avoid re-gathering per
                    # microbatch; confirm against the PP schedule in use.
                    reshard_after_forward = False
                else:
                    # Keep the last block's params gathered after forward,
                    # since backward begins there immediately. `layer_id` is
                    # already an int from enumerate (redundant int() dropped).
                    reshard_after_forward = layer_id < total_blocks - 1
            else:
                raise ValueError(
                    f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
                )
            fully_shard(
                block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )

    # Root wrap: the remaining (non-block) params form the outermost group.
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
| |
|
| |
|
def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    """Apply DDP (via the composable `replicate` API) to the model."""
    if enable_compile:
        # Pick the dynamo DDP optimization strategy that matches the
        # compiled-autograd setting.
        ddp_mode = (
            "python_reducer_without_compiled_forward"
            if enable_compiled_autograd
            else "ddp_optimizer"
        )
        torch._dynamo.config.optimize_ddp = ddp_mode

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")
| |
|
| |
|
def get_model(model):
    """Return the inner base model (via `base_model_prefix`, default "model"),
    or None when the wrapper has no such attribute."""
    prefix = getattr(model, "base_model_prefix", "model")
    return getattr(model, prefix) if hasattr(model, prefix) else None
| |
|
| |
|
def get_blocks(model):
    """Return the stack of transformer blocks (`<base model>.layers`), or None."""
    base = get_model(model)
    if hasattr(base, "layers"):
        return base.layers
    logger.warning('no "layers" in model can be found')
    return None
| |
|
| |
|
def get_components_name(model, component_name):
    """
    Resolve the attribute name under which a component is registered.

    Catches tok_embeddings, norm layers, and lm_head layers; layer names
    inside the blocks are not handled here (see `get_blocks`). The assumed
    structure is:
        LlamaForCausalLM:
            Model:
                embed_tokens,
                layers,
                norm,
            lm_head
    ***
    so to search 'tok_embeddings' and 'norm' pass `get_model(model)`,
    and for 'lm_head' pass `model` itself.
    ***
    """
    # Candidate attribute names per component, checked in priority order.
    candidates = {
        "tok_embeddings": ("tok_embeddings", "embed_tokens", "embeddings"),
        "norm": ("norm", "norms", "layernorm"),
        "lm_head": ("lm_head",),
    }

    names = candidates.get(component_name)
    if names is None:
        # Unknown component: mirror the original implicit None return.
        return None

    for name in names:
        if hasattr(model, name):
            return name

    logger.warning(f"No {component_name} found in model")
    return None
| |
|