import torch
import torch.nn as nn
from torch.distributed import DeviceMesh

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.models.llama3.parallelize_llama import apply_ac
from torchtitan.tools.logging import logger

from .simple_fsdp import data_parallel, MixedPrecisionPolicy


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model should preferably be on meta device. Otherwise,
    it must fit in GPU or CPU memory.
    """
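
    # Tensor parallelism (referenced in the docstring). A minimal hedged
    # sketch: it assumes apply_tp from
    # torchtitan.models.llama3.parallelize_llama (the same module apply_ac is
    # imported from) and the keyword arguments shown below; treat the exact
    # signature as an assumption, not the original implementation.
    if parallel_dims.tp_enabled:
        from torchtitan.models.llama3.parallelize_llama import apply_tp

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8=False,  # assumption: float8 TP not exercised here
            enable_async_tp=False,  # assumption: async TP off in this sketch
        )
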
    # Activation checkpointing: wrap transformer blocks per the configured mode.
    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # Data parallelism via SimpleFSDP: choose the mesh dims and sharding mode.
    if (
        parallel_dims.dp_replicate_enabled
        or parallel_dims.dp_shard_enabled
        or parallel_dims.cp_enabled
    ):
        if parallel_dims.dp_replicate_enabled:
            if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
                # replicate across dp_replicate, shard across dp_shard_cp (HSDP)
                dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
                dp_mode = "hybrid_shard"
            else:
                # DDP-style pure replication
                dp_mesh_dim_names = ("dp_replicate",)
                dp_mode = "replicate"
        else:
            # FSDP-style full sharding over the flattened dp_shard_cp dim
            dp_mesh_dim_names = ("dp_shard_cp",)
            dp_mode = "fully_shard"

        # Parameter and gradient-reduction dtypes for mixed-precision training.
        mp_policy = MixedPrecisionPolicy(
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
        )

        model = data_parallel(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            mode=dp_mode,
            ac_mode=job_config.activation_checkpoint.mode,
            mp_policy=mp_policy,
        )
        logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)

    if job_config.training.compile:
        # Disable inductor's peak-memory reordering before compiling the full graph.
        torch._inductor.config.reorder_for_peak_memory = False
        model = torch.compile(model, fullgraph=True)

    return model
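

# Usage: a minimal sketch, assuming torchtitan's ParallelDims.build_mesh API
# and a "cuda" device type; in a trainer, the meta-device model is parallelized
# before materialization, roughly:
#
#     world_mesh = parallel_dims.build_mesh(device_type="cuda")
#     model = parallelize_llama(model, world_mesh, parallel_dims, job_config)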