| --- |
| title: Sequence Parallelism |
| description: Train with long sequences split across multiple GPUs. |
| --- |
| |
| |
|
|
| Sequence parallelism splits each sequence across multiple GPUs, allowing you to |
| train on sequences that are too long to fit in a single GPU's memory. Each GPU |
| processes a different portion of the sequence, and the results are aggregated |
| through a ring communication pattern. |
|
|
| |
|
|
| Use sequence parallelism when: |
|
|
| - You need to train with sequence lengths that don't fit into a single GPU's memory |
| - You have multiple GPUs available |
| - You're experiencing OOM (Out Of Memory) errors with long sequences |
|
|
| |
|
|
| To enable sequence parallelism, add the following to your configuration file: |
|
|
| ```yaml |
| |
| sequence_parallel_degree: 4 |
| ``` |
|
|
| The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example: |
|
|
| - With 8 GPUs, valid values would be 2, 4, or 8 |
| - With 4 GPUs, valid values would be 2 or 4 |
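
The divisor rule above can be checked with a few lines of Python. This is only an illustrative sketch: `valid_sequence_parallel_degrees` is a hypothetical helper, not part of Axolotl, and it skips a degree of 1 to match the examples listed above.

```python
def valid_sequence_parallel_degrees(num_gpus: int) -> list[int]:
    """Return every degree greater than 1 that evenly divides num_gpus.

    Hypothetical helper for illustration; not an Axolotl API.
    """
    return [d for d in range(2, num_gpus + 1) if num_gpus % d == 0]


print(valid_sequence_parallel_degrees(8))  # [2, 4, 8]
print(valid_sequence_parallel_degrees(4))  # [2, 4]
```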
|
|
| |
|
|
| When sequence parallelism is enabled: |
|
|
| 1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group |
| 2. The data collator handles the chunking of `input_ids`, `attention_mask`, `labels`, and `position_ids` |
| 3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences |
| 4. The trainer uses special ring communication patterns for attention operations |
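
To make the chunking step concrete, here is a minimal sketch using plain Python lists. It assumes each rank receives one contiguous, equal-sized slice and that the sequence length is divisible by the degree; the actual collator may pad or order chunks differently. `split_for_rank` is a hypothetical name, not an Axolotl function.

```python
def split_for_rank(
    batch: dict[str, list[int]], rank: int, degree: int
) -> dict[str, list[int]]:
    """Return the contiguous slice of every field that one rank would hold.

    Hypothetical helper: assumes contiguous, equal-sized chunks.
    """
    seq_len = len(batch["input_ids"])
    if seq_len % degree != 0:
        raise ValueError("sequence length must be divisible by the degree")
    chunk = seq_len // degree
    start, end = rank * chunk, (rank + 1) * chunk
    # Every field is sliced with the same bounds so tokens, masks,
    # labels, and positions stay aligned on each rank.
    return {name: values[start:end] for name, values in batch.items()}


batch = {
    "input_ids": [101, 102, 103, 104, 105, 106, 107, 108],
    "attention_mask": [1] * 8,
    "labels": [101, 102, 103, 104, 105, 106, 107, 108],
    "position_ids": list(range(8)),
}
shards = [split_for_rank(batch, rank, degree=2) for rank in range(2)]
print(shards[0]["input_ids"])     # [101, 102, 103, 104]
print(shards[1]["position_ids"])  # [4, 5, 6, 7]
```

Note that the second rank keeps positions 4 through 7 rather than restarting at 0, which is what lets attention treat the shards as parts of one sequence.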
|
|
| |
|
|
| To use sequence parallelism, you need: |
|
|
| - Multiple GPUs (at least 2) |
| - The `ring-flash-attn` package. Install with: |
|   - `pip install "axolotl[ring-flash-attn]"` (preferred) |
|   - `pip install "ring-flash-attn>=0.1.4"` |
|
|
| |
|
|
| - Flash attention must be enabled (`flash_attention: true` in your config YAML) |
| - Communication between GPUs may add a small performance overhead |
|
|
| |
|
|
| ```yaml |
| |
| base_model: meta-llama/Llama-3-8B-Instruct |
| sequence_len: 8192 |
| sequence_parallel_degree: 2 |
| flash_attention: true |
| ... |
| ``` |
|
|
| This will train the Llama 3 8B model with 8K context length, with each sequence split |
| into 2 subsequences of length 4096 across 2 GPUs. |
|
|
| |
|
|
| Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together: |
|
|
| 1. Samples are first packed together |
| 2. The packed sequences are then divided across GPUs in the sequence parallel group |
| 3. Position IDs are automatically adjusted to maintain proper relative positions |
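
The interaction between packing and splitting can be sketched as follows. This assumes position IDs restart at 0 for each packed sample and that shards are contiguous; both helper functions are hypothetical and only illustrate the idea.

```python
def packed_position_ids(sample_lengths: list[int]) -> list[int]:
    """Position IDs that restart at 0 for every sample in a packed sequence."""
    return [pos for length in sample_lengths for pos in range(length)]


def shard(values: list[int], degree: int) -> list[list[int]]:
    """Split a list into `degree` contiguous chunks of equal length."""
    chunk = len(values) // degree
    return [values[i * chunk:(i + 1) * chunk] for i in range(degree)]


# Two samples of length 3 and 5 packed into one sequence of length 8.
position_ids = packed_position_ids([3, 5])
print(position_ids)            # [0, 1, 2, 0, 1, 2, 3, 4]
print(shard(position_ids, 2))  # [[0, 1, 2, 0], [1, 2, 3, 4]]
```

The second sample straddles the shard boundary, yet each token keeps its position relative to the start of its own sample.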
|
|
| |
|
|
| When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because: |
|
|
| - Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence) |
| - The number of batches processed per step decreases |
|
|
| For example: |
| - With 8 GPUs and no sequence parallelism: 8 different batches processed per step |
| - With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs) |
| - If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4 |
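
The arithmetic in this example can be reproduced with a short calculation. The function below is a hypothetical helper, not an Axolotl API, and it leaves out any other factors that may affect batch size in a real run.

```python
def global_batch_size(
    num_gpus: int, micro_batch_size: int, sequence_parallel_degree: int = 1
) -> int:
    """Number of distinct samples processed per step across all GPUs."""
    if num_gpus % sequence_parallel_degree != 0:
        raise ValueError("sequence_parallel_degree must divide num_gpus")
    # GPUs in the same sequence parallel group share one batch,
    # so only the number of groups contributes distinct batches.
    groups = num_gpus // sequence_parallel_degree
    return groups * micro_batch_size


print(global_batch_size(8, micro_batch_size=2))                              # 16
print(global_batch_size(8, micro_batch_size=2, sequence_parallel_degree=4))  # 4
```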
|
|