# Distributed Training with `optimum-neuron`

AWS Trainium instances provide powerful infrastructure for training large language models at scale. A `trn1.32xlarge` instance contains 16 Neuron devices with 32 cores total, offering 512GB of memory (16GB per core).

However, training large models presents a fundamental challenge: by default, each Neuron core operates as an independent data-parallel worker, requiring the entire model, gradients, and optimizer state (approximately 4× the model size) to fit within a single core's 16GB memory limit, with additional space needed for activations.
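To sanity-check the 4× rule of thumb, the sketch below estimates the training-state footprint of a model (weights, gradients, and two Adam moment buffers) and compares it with the 16GB available on one core. It assumes fp32 storage (4 bytes per value), an Adam-like optimizer, and decimal gigabytes. Activations are left out on purpose. `training_state_bytes` and `fits_on_one_core` are hypothetical helpers and are not part of `optimum-neuron`.

```python
# Rough estimate of per-core training-state memory when one NeuronCore holds a
# full model replica (the default data-parallel setup).
# Assumptions (for illustration only): fp32 storage and an Adam-like optimizer,
# so each parameter has 4 stored values: weight, gradient, momentum, variance.

BYTES_PER_VALUE = 4              # fp32
STATE_COPIES = 4                 # weights, gradients, Adam momentum, Adam variance
CORE_MEMORY_BYTES = 16 * 10**9   # 16GB per NeuronCore (decimal GB, assumption)


def training_state_bytes(num_params: int,
                         bytes_per_value: int = BYTES_PER_VALUE,
                         copies: int = STATE_COPIES) -> int:
    """Bytes needed for weights, gradients and optimizer state, excluding activations."""
    return num_params * bytes_per_value * copies


def fits_on_one_core(num_params: int, activation_bytes: int = 0) -> bool:
    """True if the training state plus an activation budget fits in one core's memory."""
    return training_state_bytes(num_params) + activation_bytes <= CORE_MEMORY_BYTES


if __name__ == "__main__":
    for n in (500_000_000, 1_000_000_000, 3_000_000_000):
        gb = training_state_bytes(n) / 10**9
        print(f"{n / 1e9:.1f}B params -> {gb:.0f} GB of training state, "
              f"fits on one core: {fits_on_one_core(n)}")
```

Under these assumptions, the training state of a 1B-parameter model already fills the full 16GB and leaves no room for activations. Other precision choices change `bytes_per_value` and therefore the result.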
For models that exceed these memory constraints, `optimum-neuron` provides sophisticated parallelism strategies that distribute computation and memory across multiple devices. These strategies make it possible to train models that would not fit on an individual core.

## Parallelism Strategies Overview

### 1. ZeRO-1 (Optimizer State Sharding)

[ZeRO-1](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/zero1_gpt2.html) is an optimizer-level optimization that reduces memory usage without changing your model architecture.

**How it works**: Shards the optimizer state (e.g., Adam's momentum and variance) across data-parallel ranks instead of replicating it on each device.

**Memory savings**: Reduces per-device optimizer-state memory to `1/data_parallel_size` of its unsharded size.

**When to use**: Beneficial whenever training uses more than one data-parallel rank, regardless of model size.
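The sketch below makes the sharding concrete. It splits a flattened optimizer-state buffer into contiguous, nearly equal shards, one per data-parallel rank, and reports the largest per-rank footprint. The contiguous-chunk layout, the 8 bytes of Adam state per parameter, and both helper names are assumptions for illustration. The Neuron ZeRO-1 optimizer may partition differently, for example by padding the buffer.

```python
import math
from typing import List, Tuple


def zero1_shard_ranges(num_elements: int, dp_size: int) -> List[Tuple[int, int]]:
    """Return the [start, end) slice of a flat optimizer-state buffer owned by each DP rank.

    Hypothetical layout: contiguous chunks of ceil(num_elements / dp_size) elements,
    with the last rank(s) receiving a shorter (possibly empty) chunk.
    """
    if dp_size < 1:
        raise ValueError("dp_size must be >= 1")
    chunk = math.ceil(num_elements / dp_size)
    ranges = []
    for rank in range(dp_size):
        start = min(rank * chunk, num_elements)
        end = min(start + chunk, num_elements)
        ranges.append((start, end))
    return ranges


def optimizer_bytes_per_rank(num_params: int, dp_size: int,
                             state_bytes_per_param: int = 8) -> int:
    """Largest per-rank optimizer-state footprint (rank 0 always holds the largest shard).

    Assumes Adam-like state: momentum + variance in fp32 = 8 bytes per parameter.
    """
    start, end = zero1_shard_ranges(num_params, dp_size)[0]
    return (end - start) * state_bytes_per_param


if __name__ == "__main__":
    print(zero1_shard_ranges(10, 4))
    for dp in (1, 2, 4, 8):
        print(dp, optimizer_bytes_per_rank(1_000_000_000, dp) / 1e9, "GB")
```

Going from one data-parallel rank to eight divides the per-rank optimizer state by eight. With a single data-parallel rank, nothing is sharded and ZeRO-1 saves no memory.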
### 2. Tensor Parallelism (Intra-layer Model Parallelism)

[Tensor Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/tensor_parallelism_overview.html) splits individual model layers across multiple devices.

**How it works**: Shards matrix multiplications (linear layers, attention) along rows or columns across devices. Each device computes part of each layer, so devices must exchange partial results (e.g., through all-reduce or all-gather collectives) inside every layer, in both the forward and the backward pass.

**Memory savings**: Reduces per-device model parameter memory to roughly `1/tensor_parallel_size`.

**When to use**: When your model is too large to fit on a single device, even after applying ZeRO-1.

**Typical deployment**: Usually applied within a single node (intra-node) due to high communication requirements.

**Trade-offs**: Increases communication overhead between devices, which can slow down training if overused.
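The sketch below uses plain Python lists and no Neuron dependency to show why sharding a linear layer gives exact results. When the weight matrix is split by columns, each rank computes one slice of the output, and the slices are then concatenated (an all-gather). When the weight matrix is split by rows, each rank computes a partial sum, and the partial sums are then added together (an all-reduce). The helper names are hypothetical, and the "ranks" are simulated one after another in a single process.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def split_columns(m: Matrix, parts: int) -> List[Matrix]:
    """Split m into `parts` column blocks (assumes the width is divisible by parts)."""
    width = len(m[0]) // parts
    return [[row[p * width:(p + 1) * width] for row in m] for p in range(parts)]


def split_rows(m: Matrix, parts: int) -> List[Matrix]:
    """Split m into `parts` row blocks (assumes the height is divisible by parts)."""
    height = len(m) // parts
    return [m[p * height:(p + 1) * height] for p in range(parts)]


def column_parallel_linear(x: Matrix, w: Matrix, tp_size: int) -> Matrix:
    # Each simulated rank holds one column block of w and computes its output slice.
    partial_outputs = [matmul(x, w_shard) for w_shard in split_columns(w, tp_size)]
    # "All-gather": concatenate the output slices along the feature dimension.
    return [sum((out[i] for out in partial_outputs), []) for i in range(len(x))]


def row_parallel_linear(x: Matrix, w: Matrix, tp_size: int) -> Matrix:
    # Each simulated rank holds one row block of w and the matching column block of x.
    x_shards = split_columns(x, tp_size)
    w_shards = split_rows(w, tp_size)
    partial_sums = [matmul(xs, ws) for xs, ws in zip(x_shards, w_shards)]
    # "All-reduce": element-wise sum of the partial results.
    return [[sum(p[i][j] for p in partial_sums) for j in range(len(w[0]))]
            for i in range(len(x))]


if __name__ == "__main__":
    x = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]             # 2 tokens, hidden=4
    w = [[float(r * 4 + c) for c in range(4)] for r in range(4)]  # 4x4 weight
    reference = matmul(x, w)
    assert column_parallel_linear(x, w, tp_size=2) == reference
    assert row_parallel_linear(x, w, tp_size=2) == reference
    print("column- and row-parallel results match the unsharded layer")
```

Megatron-style layouts pair a column-parallel layer with a row-parallel one within each attention or MLP block. With this pairing, a single all-reduce per block is enough in the forward pass.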
### 3. Sequence Parallelism (Activation Sharding)

[Sequence parallelism](https://arxiv.org/pdf/2205.05198.pdf) is an optimization that works alongside Tensor Parallelism to further reduce memory usage.

**How it works**: Shards activations along the sequence dimension in regions where tensors are not already sharded by tensor parallelism.

**Memory savings**: Divides the memory of these activations by `tensor_parallel_size`. Activation memory grows with sequence length, so the savings are largest for long sequences.

**When to use**: Enable it whenever you use tensor parallelism. It provides additional memory savings with minimal overhead.

**Requirement**: Only works in combination with tensor parallelism.
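Sequence parallelism works because the operations outside the tensor-parallel regions, such as LayerNorm and dropout, act on each token independently. The sketch below is pure Python, uses hypothetical helper names, and simulates the ranks in one process. It splits a `[sequence, hidden]` activation along the sequence axis and applies a LayerNorm (without learned scale and bias) to each shard separately. It then checks that gathering the shards reproduces the unsharded result, even though each rank only holds `1/tensor_parallel_size` of the tokens.

```python
import math
from typing import List

Activation = List[List[float]]  # [sequence_length][hidden_size]


def layer_norm(rows: Activation, eps: float = 1e-5) -> Activation:
    """Per-token LayerNorm without learnable scale/bias."""
    out = []
    for row in rows:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        out.append([(v - mean) / math.sqrt(var + eps) for v in row])
    return out


def shard_sequence(x: Activation, tp_size: int) -> List[Activation]:
    """Split along the sequence axis (assumes seq_len is divisible by tp_size)."""
    chunk = len(x) // tp_size
    return [x[r * chunk:(r + 1) * chunk] for r in range(tp_size)]


def sequence_parallel_layer_norm(x: Activation, tp_size: int) -> Activation:
    # Each simulated rank keeps only seq_len / tp_size tokens.
    shards = shard_sequence(x, tp_size)
    # LayerNorm is per-token, so every rank can normalize its shard locally.
    normed = [layer_norm(s) for s in shards]
    # "All-gather" along the sequence axis to compare with the unsharded result.
    return [row for shard in normed for row in shard]


if __name__ == "__main__":
    seq_len, hidden, tp = 8, 4, 4
    x = [[float((t * hidden + h) % 5) for h in range(hidden)] for t in range(seq_len)]
    assert sequence_parallel_layer_norm(x, tp) == layer_norm(x)
    print(f"tokens held per rank: {seq_len // tp} of {seq_len}")
```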
### 4. Pipeline Parallelism (Inter-layer Model Parallelism)

[Pipeline Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/pipeline_parallelism_overview.html) splits model layers across different devices.

**How it works**: Divides your model into stages, with each stage containing consecutive layers running on different devices. Uses microbatching to keep all devices busy.

**Memory savings**: Reduces per-device model parameter memory to roughly `1/pipeline_parallel_size`.

**When to use**: For very large models that don't fit even with tensor parallelism, or when you want to scale across many devices with less communication overhead than tensor parallelism.

**Typical deployment**: Usually applied across multiple nodes (inter-node) to scale to larger numbers of devices while minimizing high-bandwidth communication requirements.

**Trade-offs**: Introduces pipeline bubbles (idle time) and requires careful tuning of microbatch sizes.
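The sketch below illustrates two quantities that matter when setting up pipeline parallelism: how consecutive layers are assigned to stages, and how much idle time the pipeline bubble costs. It assumes an even, contiguous split of layers. It also uses the usual GPipe-style estimate `(p - 1) / (m + p - 1)` for the fraction of time lost to the bubble with `p` stages and `m` microbatches. The same estimate applies to non-interleaved 1F1B schedules. The schedule that `optimum-neuron` actually uses may differ, and both helper names are hypothetical.

```python
from typing import List, Tuple


def stage_layer_ranges(num_layers: int, pp_size: int) -> List[Tuple[int, int]]:
    """Assign consecutive layers to pipeline stages as [start, end) ranges.

    Hypothetical even split: the first `num_layers % pp_size` stages get one extra layer.
    """
    base, extra = divmod(num_layers, pp_size)
    ranges, start = [], 0
    for stage in range(pp_size):
        size = base + (1 if stage < extra else 0)
        ranges.append((start, start + size))
        start += size
    return ranges


def bubble_fraction(pp_size: int, num_microbatches: int) -> float:
    """Fraction of a training step each stage spends idle (GPipe-style estimate)."""
    return (pp_size - 1) / (num_microbatches + pp_size - 1)


if __name__ == "__main__":
    print(stage_layer_ranges(28, 4))
    for m in (1, 4, 8, 32):
        print(f"pp=4, microbatches={m}: bubble ~ {bubble_fraction(4, m):.0%}")
```

Using more microbatches shrinks the bubble, which is why the number of microbatches needs tuning alongside the batch size.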
The good news is that these techniques can be combined, and `optimum-neuron` makes it very easy!

All the training examples in the `optimum-neuron` repository use these parallelism features via the `NeuronTrainer`.

## How to enable ZeRO-1?

ZeRO-1 can be enabled either through the `NeuronTrainer` or directly with the `NeuronAccelerator`.

### Via the `NeuronTrainer`
```python
from optimum.neuron import NeuronTrainingArguments, NeuronTrainer

# `model`, `train_dataset` and `eval_dataset` are assumed to be defined beforehand.

# Enable ZeRO-1 in the training arguments
training_args = NeuronTrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,
    zero_1=True,  # Enable ZeRO-1
    bf16=True,
    # ... other training arguments
)

trainer = NeuronTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
```
Since the example scripts use the `NeuronTrainer`, you can enable ZeRO-1 when using them by adding the `--zero_1` flag to your command line.

For example:

```bash
torchrun --nproc_per_node=2 examples/training/qwen3/finetune_qwen3.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --per_device_train_batch_size 1 \
    --block_size 1024 \
    --bf16 \
    --zero_1 \
    --tensor_parallel_size 2 \
    --output_dir my_training/
```
### Via the `NeuronAccelerator`

When using the `NeuronAccelerator` directly, you need to create a `TrainingNeuronConfig` and enable ZeRO-1 separately:

```python
from torch.optim import AdamW
from optimum.neuron import NeuronAccelerator
from optimum.neuron.models.training.config import TrainingNeuronConfig

# Create the training configuration
trn_config = TrainingNeuronConfig()

# Create accelerator with ZeRO-1 enabled
accelerator = NeuronAccelerator(
    trn_config=trn_config,
    zero_1=True,  # Enable ZeRO-1
    mixed_precision="bf16",
)

model = ...  # Your model instance
optimizer = AdamW(model.parameters(), lr=5e-5)

# Prepare model and optimizer
model, optimizer = accelerator.prepare(model, optimizer)
```
## How to enable Tensor Parallelism?

Tensor Parallelism can be used with either the `NeuronTrainer` or the `NeuronAccelerator`.

**Important**: Tensor parallelism requires models that have a custom modeling implementation in `optimum.neuron.models.training`.

When using Tensor Parallelism, there are two important settings:

1. The `tensor_parallel_size`: ideally, it should be the smallest value for which the model fits in memory.
2. Whether sequence parallelism should be enabled: [Sequence parallelism](https://arxiv.org/pdf/2205.05198.pdf) shards the activations along the sequence axis outside of the tensor-parallel regions, which saves memory.

In distributed training, `torchrun` launches the training script as one worker process per Neuron core. Each worker loads only its shard of the model, and the parameters are distributed across the cores automatically. The `tensor_parallel_size` is the number of workers across which the model parameters are sharded.
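The sketch below shows one way a `torchrun --nproc_per_node=N` launch could be split into tensor-parallel groups. It assumes that tensor parallelism is the innermost dimension, so consecutive ranks share one model shard, as is common in Megatron-style frameworks. This rank ordering and the `tensor_parallel_groups` helper are assumptions for illustration, not documented `optimum-neuron` behavior.

```python
from typing import List


def tensor_parallel_groups(world_size: int, tp_size: int) -> List[List[int]]:
    """Group worker ranks into tensor-parallel groups of consecutive ranks.

    Assumption for illustration: TP is the innermost parallel dimension.
    """
    if world_size % tp_size != 0:
        raise ValueError(f"world_size={world_size} is not divisible by tp_size={tp_size}")
    return [list(range(start, start + tp_size)) for start in range(0, world_size, tp_size)]


if __name__ == "__main__":
    # torchrun --nproc_per_node=8 with tensor_parallel_size=8: one group, one model replica
    print(tensor_parallel_groups(8, 8))
    # 32 cores with tensor_parallel_size=8: four groups, i.e. four data-parallel replicas
    for group in tensor_parallel_groups(32, 8):
        print(group)
```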
### Via the `NeuronTrainer`

```python
from optimum.neuron import NeuronTrainingArguments, NeuronTrainer

# Configure tensor parallelism in training arguments
training_args = NeuronTrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,
    bf16=True,
    tensor_parallel_size=8,
    # ... other training arguments
)

trainer = NeuronTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
```
Since the example scripts use the `NeuronTrainer`, you can enable Tensor Parallelism when using them by specifying the `--tensor_parallel_size` argument.

For example:

```bash
torchrun --nproc_per_node=8 examples/training/qwen3/finetune_qwen3.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --per_device_train_batch_size 1 \
    --block_size 1024 \
    --bf16 \
    --tensor_parallel_size 8 \
    --output_dir my_training/
```
### Via the `NeuronAccelerator`

When using the `NeuronAccelerator` directly, you configure tensor parallelism through the `TrainingNeuronConfig`:

```python
from torch.optim import AdamW
from optimum.neuron import NeuronAccelerator
from optimum.neuron.models.training.config import TrainingNeuronConfig

# Configure tensor parallelism
trn_config = TrainingNeuronConfig(
    tensor_parallel_size=8,
    sequence_parallel_enabled=True,
    checkpoint_dir=None,  # Can be specified when resuming from checkpoint
)

accelerator = NeuronAccelerator(
    trn_config=trn_config,
    mixed_precision="bf16",
)

model = ...  # Your model instance
optimizer = AdamW(model.parameters(), lr=5e-5)

model, optimizer = accelerator.prepare(model, optimizer)
```
## How to enable Pipeline Parallelism?

Pipeline Parallelism allows you to split your model layers across multiple devices, enabling training of very large models that wouldn't fit on a single device, or even a single node.

**Important**: Pipeline parallelism requires models that have a custom modeling implementation in `optimum.neuron.models.training` and declare `SUPPORTS_PIPELINE_PARALLELISM = True`.

### Configuration Options

Pipeline parallelism has several configuration parameters:

- `pipeline_parallel_size`: Number of pipeline stages (devices to split layers across)
- `pipeline_parallel_num_microbatches`: Number of microbatches for pipeline scheduling
- `zero_1`: When pipeline parallelism is enabled, ZeRO-1 can be applied automatically to the pipeline-parallel optimizer
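With pipeline parallelism, each per-device batch is cut into `pipeline_parallel_num_microbatches` microbatches that flow through the stages one after another. The sketch below shows that split on a plain list of samples. It assumes the batch size must be divisible by the number of microbatches. `split_into_microbatches` is a hypothetical helper, not an `optimum-neuron` function.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_into_microbatches(batch: Sequence[T], num_microbatches: int) -> List[List[T]]:
    """Split one per-device batch into equally sized microbatches, preserving order."""
    if num_microbatches < 1 or len(batch) % num_microbatches != 0:
        raise ValueError(
            f"batch of size {len(batch)} cannot be split into "
            f"{num_microbatches} equal microbatches"
        )
    size = len(batch) // num_microbatches
    return [list(batch[i * size:(i + 1) * size]) for i in range(num_microbatches)]


if __name__ == "__main__":
    # per_device_train_batch_size=4 with pipeline_parallel_num_microbatches=4, as in the
    # NeuronTrainer example of this section: each microbatch holds a single sample.
    print(split_into_microbatches(["s0", "s1", "s2", "s3"], 4))
```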
### Via the `NeuronTrainer`

```python
from optimum.neuron import NeuronTrainingArguments, NeuronTrainer
from optimum.neuron.models.training import LlamaForCausalLM  # Custom model implementation

# Configure pipeline parallelism in training arguments
training_args = NeuronTrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,  # Will be split into microbatches
    bf16=True,
    tensor_parallel_size=2,
    pipeline_parallel_size=4,  # Split model across 4 pipeline stages
    pipeline_parallel_num_microbatches=4,  # Number of microbatches
    zero_1=True,  # Enable ZeRO-1 with pipeline parallelism
    # ... other training arguments
)

# Load model using custom implementation: must be done with the model class directly
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    trn_config=training_args.trn_config,  # Pass the auto-generated trn_config
)

trainer = NeuronTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
```
### Via the `NeuronAccelerator`

```python
from optimum.neuron import NeuronAccelerator
from optimum.neuron.models.training.config import TrainingNeuronConfig
from optimum.neuron.models.training import LlamaForCausalLM
from torch.optim import AdamW

# Configure combined parallelism strategies
trn_config = TrainingNeuronConfig(
    tensor_parallel_size=2,
    pipeline_parallel_size=4,
    pipeline_parallel_num_microbatches=4,
    sequence_parallel_enabled=True,
)

accelerator = NeuronAccelerator(
    trn_config=trn_config,
    zero_1=True,  # Can combine with ZeRO-1
    mixed_precision="bf16",
)

# Load model with custom implementation
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    trn_config=trn_config,
)

optimizer = AdamW(model.parameters(), lr=5e-5)
model, optimizer = accelerator.prepare(model, optimizer)
```

When using pipeline parallelism, the total number of processes must be a multiple of `tensor_parallel_size * pipeline_parallel_size`; the quotient is the number of data-parallel replicas. For example, with `tensor_parallel_size=2` and `pipeline_parallel_size=4`, each model replica needs 8 processes, so 8 processes give a single replica.
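This arithmetic can be checked before launching. The sketch below derives the data-parallel size from the total number of processes (`world_size`, i.e. `nproc_per_node` times the number of nodes). It rejects process counts that are not a multiple of `tensor_parallel_size * pipeline_parallel_size`. `data_parallel_size` is a hypothetical helper name used only for illustration.

```python
def data_parallel_size(world_size: int, tp_size: int, pp_size: int = 1) -> int:
    """Number of data-parallel replicas for a given process count.

    Every model replica needs tp_size * pp_size processes, so world_size must be a multiple.
    """
    model_parallel_size = tp_size * pp_size
    if world_size % model_parallel_size != 0:
        raise ValueError(
            f"world_size={world_size} must be a multiple of "
            f"tensor_parallel_size * pipeline_parallel_size = {model_parallel_size}"
        )
    return world_size // model_parallel_size


if __name__ == "__main__":
    print(data_parallel_size(8, tp_size=2, pp_size=4))   # a single replica on 8 cores
    print(data_parallel_size(32, tp_size=2, pp_size=4))  # four replicas on a trn1.32xlarge
```

When the data-parallel size is 1, ZeRO-1 has only one rank to shard the optimizer state over, so it brings no memory savings in that setup.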
## Combining Parallelism Strategies

You can combine multiple parallelism strategies for maximum memory efficiency and performance. Here's an example with all strategies combined:

### Via the `NeuronTrainer`

```python
from optimum.neuron import NeuronTrainingArguments, NeuronTrainer
from optimum.neuron.models.training import LlamaForCausalLM

# Example: Combine all parallelism strategies
training_args = NeuronTrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=32,
    bf16=True,
    gradient_checkpointing=True,

    # ZeRO-1
    zero_1=True,

    # Tensor parallelism
    tensor_parallel_size=4,
    disable_sequence_parallel=False,  # Enable sequence parallelism

    # Pipeline parallelism
    pipeline_parallel_size=2,
    pipeline_parallel_num_microbatches=8,

    # Additional optimizations
    fuse_qkv=True,  # Fuse QKV projections for efficiency
    kv_size_multiplier=None,  # Auto-calculate optimal KV multiplier
)

# Load model using custom implementation
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    trn_config=training_args.trn_config,
)

trainer = NeuronTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
```
Each model replica in this configuration uses 4 * 2 = 8 processes:

- Each tensor parallel group has 4 processes
- Each pipeline stage runs on one tensor parallel group

Running the training script on a `trn1.32xlarge` instance with 32 Neuron cores therefore results in the following configuration: `dp=4, tp=4, pp=2`. This means 4 data-parallel replicas, each split into 2 pipeline stages of 4 tensor-parallel cores.
## Checkpoint consolidation

Because distributed training saves sharded checkpoints across different workers, you need to consolidate them to create a standard model checkpoint that can be shared and used outside of the specific training configuration.

The Optimum CLI provides a way of doing that very easily via the `optimum-cli neuron consolidate` command:

```bash
optimum-cli neuron consolidate --help

usage: optimum-cli neuron consolidate [-h] [-f {pytorch,safetensors}] checkpoint_dir output_dir

positional arguments:
  checkpoint_dir        The path to the directory containing the checkpoints.
  output_dir            The path to the output directory containing the consolidated checkpoint.

optional arguments:
  -h, --help            show this help message and exit
  -f {pytorch,safetensors}, --format {pytorch,safetensors}
                        The format used to save the consolidated checkpoint.
```
All you need to do is specify the sharded checkpoints directory and the output directory that will contain the consolidated checkpoints, and the command takes care of the rest.

It is also possible to specify the output format of the consolidated checkpoints. By default, they are exported to the `safetensors` format, which is the recommended format.

Example:

Suppose training with distributed parallelism has just completed and the output directory is called `my_training`. The directory looks like the following:

```bash
my_training/
├── README.md
├── all_results.json
├── checkpoint-10
│   ├── config.json
│   ├── scheduler.pt
│   ├── special_tokens_map.json
│   ├── shards/
│   ├── tokenizer.json
│   ├── tokenizer.model
│   ├── tokenizer_config.json
│   ├── trainer_state.json
│   └── training_args.bin
├── config.json
├── special_tokens_map.json
├── shards/
│   ├── tp_rank_00_pp_rank_00
│   ├── tp_rank_01_pp_rank_00
│   ├── tp_rank_02_pp_rank_00
│   ├── tp_rank_03_pp_rank_00
│   ├── tp_rank_00_pp_rank_01
│   ├── tp_rank_01_pp_rank_01
│   ├── tp_rank_02_pp_rank_01
│   └── tp_rank_03_pp_rank_01
├── tokenizer.json
├── tokenizer.model
├── tokenizer_config.json
├── train_results.json
├── trainer_state.json
├── training_args.bin
└── trn_config.json
```
You can consolidate the sharded checkpoints in `my_training/shards`, which correspond to the sharded checkpoints saved at the end of training, by running the following command:

```bash
optimum-cli neuron consolidate my_training my_training_consolidated_checkpoint
```

The sharded checkpoints are saved under a directory called `shards`. The `optimum-cli neuron consolidate` command accepts as input either a directory that contains a `shards` directory or the `shards` directory itself.
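Before consolidating, it can be useful to check that a `shards` directory is complete. The sketch below infers the tensor- and pipeline-parallel sizes from shard names that follow the `tp_rank_XX_pp_rank_YY` pattern seen in the tree above, and it reports missing rank pairs. It is a hypothetical helper that only inspects directory names. It is not part of `optimum-cli`, and it cannot detect a missing shard whose rank is beyond the largest one present.

```python
import re
from pathlib import Path
from typing import Iterable, List, Set, Tuple

SHARD_PATTERN = re.compile(r"^tp_rank_(\d+)_pp_rank_(\d+)$")


def shard_layout(names: Iterable[str]) -> Tuple[int, int, List[Tuple[int, int]]]:
    """Return (tp_size, pp_size, missing (tp, pp) rank pairs) from shard directory names."""
    found: Set[Tuple[int, int]] = set()
    for name in names:
        match = SHARD_PATTERN.match(name)
        if match:
            found.add((int(match.group(1)), int(match.group(2))))
    if not found:
        raise ValueError("no tp_rank_XX_pp_rank_YY entries found")
    tp_size = max(tp for tp, _ in found) + 1
    pp_size = max(pp for _, pp in found) + 1
    missing = [
        (tp, pp)
        for pp in range(pp_size)
        for tp in range(tp_size)
        if (tp, pp) not in found
    ]
    return tp_size, pp_size, missing


def check_shards_dir(shards_dir: str) -> Tuple[int, int, List[Tuple[int, int]]]:
    """Apply shard_layout to the entries of an on-disk shards/ directory."""
    return shard_layout(p.name for p in Path(shards_dir).iterdir())


if __name__ == "__main__":
    # Names matching my_training/shards in the tree above.
    names = [f"tp_rank_{tp:02d}_pp_rank_{pp:02d}" for pp in range(2) for tp in range(4)]
    print(shard_layout(names))
```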
## Best Practices

### Choosing Parallelism Strategy

1. **Start with Tensor Parallelism**: Use the smallest `tensor_parallel_size` that fits your model in memory
2. **Add Pipeline Parallelism**: For very large models, combine with pipeline parallelism
3. **Enable Sequence Parallelism**: Always enable when using tensor parallelism for memory savings (set `disable_sequence_parallel=False`)
4. **Use ZeRO-1**: Combine with any parallelism strategy for optimizer memory savings
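The first rule can be turned into a quick back-of-the-envelope search. The sketch below picks the smallest `tensor_parallel_size`, optionally with a fixed `pipeline_parallel_size`, for which each core's share of the training state fits in 16GB while leaving headroom for activations. It applies the introduction's rule of thumb (training state ≈ 4× the model size) assuming fp32 storage and an even split of parameters across `tp * pp` cores. It ignores ZeRO-1 and sequence parallelism. The candidate sizes and the 25% activation headroom are arbitrary assumptions, so treat the result as a starting point for experiments, not a guarantee.

```python
from typing import Optional, Sequence

CORE_MEMORY_BYTES = 16 * 10**9   # 16GB per NeuronCore (decimal GB, assumption)
BYTES_PER_PARAM_STATE = 4 * 4    # weights + grads + 2 Adam moments, all fp32 (assumption)


def smallest_tp_size(num_params: int,
                     pp_size: int = 1,
                     activation_headroom: float = 0.25,
                     candidates: Sequence[int] = (1, 2, 4, 8, 16, 32)) -> Optional[int]:
    """Smallest tensor_parallel_size whose per-core training state leaves the headroom free.

    Returns None if no candidate fits, i.e. more pipeline stages or other savings are needed.
    """
    budget = CORE_MEMORY_BYTES * (1.0 - activation_headroom)
    total_state = num_params * BYTES_PER_PARAM_STATE
    for tp in candidates:
        if total_state / (tp * pp_size) <= budget:
            return tp
    return None


if __name__ == "__main__":
    for billions in (0.5, 3, 8, 70):
        n = int(billions * 1e9)
        print(f"{billions}B params: tp={smallest_tp_size(n)}, "
              f"with pp=4: tp={smallest_tp_size(n, pp_size=4)}")
```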
### Memory Optimization

- Enable `gradient_checkpointing` for large models
- Set appropriate `pipeline_parallel_num_microbatches` for pipeline parallelism

## Troubleshooting

### Common Issues

1. **Out of Memory**: Reduce batch size, increase parallelism, or enable gradient checkpointing
2. **Model Not Supported**: Ensure you're using a model from `optimum.neuron.models.training`
3. **Pipeline Parallelism Fails**: Check that the model supports pipeline parallelism
4. **Incorrect Process Count**: Ensure the total number of processes (`nproc_per_node` times the number of nodes) is a multiple of `tensor_parallel_size * pipeline_parallel_size`

### Debugging Tips

- Start with smaller models and parallelism sizes
- Check that all processes can communicate properly
- Verify checkpoint directories and permissions
- Monitor Neuron device utilization