---
title: Sequence Parallelism
description: Train with long sequences split across multiple GPUs.
---

# Sequence Parallelism

Sequence parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU holds a different portion of each sequence, and attention over the full sequence is
computed by passing key/value chunks between the GPUs in a ring communication pattern.
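To make the ring pattern concrete, here is a single-process, pure-Python simulation of the general ring attention idea. It is a sketch under stated assumptions, not Axolotl's or `ring-flash-attn`'s code: `full_attention` and `ring_attention` are hypothetical names, causal masking is omitted, and the "GPUs" are just list indices. Each simulated rank keeps its query chunk, key/value chunks move one step around the ring per iteration, and partial results are merged with an online softmax so the output matches ordinary attention over the whole sequence.

```python
import math


def _score(qi, kj):
    # Scaled dot product between one query vector and one key vector.
    return sum(a * b for a, b in zip(qi, kj)) / math.sqrt(len(qi))


def full_attention(q, k, v):
    """Reference non-causal softmax attention over lists of vectors."""
    out = []
    for qi in q:
        scores = [_score(qi, kj) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total
                    for d in range(len(v[0]))])
    return out


def ring_attention(q, k, v, degree):
    """Simulate `degree` ranks: queries stay put while key/value chunks rotate."""
    n = len(q)
    if n % degree != 0:
        raise ValueError("sequence length must be divisible by degree")
    size = n // degree

    def chunk(xs):
        return [xs[r * size:(r + 1) * size] for r in range(degree)]

    q_chunks = chunk(q)
    kv_chunks = list(zip(chunk(k), chunk(v)))  # kv_chunks[r] is held by rank r
    dim_v = len(v[0])
    # Online-softmax state per rank and query: (running max, denominator, weighted sum).
    state = [[(-math.inf, 0.0, [0.0] * dim_v) for _ in range(size)]
             for _ in range(degree)]

    for _ in range(degree):
        for rank in range(degree):
            k_blk, v_blk = kv_chunks[rank]
            for i, qi in enumerate(q_chunks[rank]):
                m, s, acc = state[rank][i]
                for kj, vj in zip(k_blk, v_blk):
                    score = _score(qi, kj)
                    new_m = max(m, score)
                    scale = math.exp(m - new_m)  # rescale earlier terms to the new max
                    w = math.exp(score - new_m)
                    s = s * scale + w
                    acc = [a * scale + w * x for a, x in zip(acc, vj)]
                    m = new_m
                state[rank][i] = (m, s, acc)
        # Ring step: every rank passes its current KV chunk to rank + 1.
        kv_chunks = [kv_chunks[(r - 1) % degree] for r in range(degree)]

    # Normalize and concatenate the per-rank outputs back into sequence order.
    return [[a / s for a in acc] for rank_state in state for _, s, acc in rank_state]


q = [[math.sin(i + j) for j in range(4)] for i in range(8)]
k = [[math.cos(i * j + 1.0) for j in range(4)] for i in range(8)]
v = [[float(i - j) for j in range(2)] for i in range(8)]
ref = full_attention(q, k, v)
out = ring_attention(q, k, v, degree=4)
print(max(abs(a - b) for ra, rb in zip(ref, out) for a, b in zip(ra, rb)) < 1e-9)  # True
```

After `degree` ring steps every rank has seen every key/value chunk exactly once, which is why no single GPU ever needs the full sequence's keys and values at the same time.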

## When to Use Sequence Parallelism

Use sequence parallelism when:

- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
- You're hitting out-of-memory (OOM) errors with long sequences

## Configuration

To enable sequence parallelism, add the following to your configuration file:

```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4  # Split sequences across 4 GPUs
```

The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:

- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
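The divisor rule can be checked mechanically. `valid_sequence_parallel_degrees` below is a hypothetical helper for illustration, not an Axolotl function; it applies only the "divisor of the GPU count, greater than 1" rule stated above.

```python
def valid_sequence_parallel_degrees(num_gpus: int) -> list[int]:
    """List candidate sequence_parallel_degree values for a GPU count.

    Hypothetical helper: applies only the "divisor greater than 1" rule.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be positive")
    return [d for d in range(2, num_gpus + 1) if num_gpus % d == 0]


print(valid_sequence_parallel_degrees(8))  # [2, 4, 8]
print(valid_sequence_parallel_degrees(4))  # [2, 4]
```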

## Implementation Details

When sequence parallelism is enabled:

1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of `input_ids`, `attention_mask`, `labels`, and `position_ids`
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
4. The trainer computes attention with ring communication (provided by `ring-flash-attn`), so attention still spans the whole sequence rather than only each GPU's chunk
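The chunking in steps 1 to 3 can be pictured with plain Python lists. This is a minimal sketch under assumptions, not Axolotl's collator: `shard_sequence` is a hypothetical helper, it splits every per-token field into contiguous equal chunks, it assumes the sequence length divides evenly by the degree, and it assumes `position_ids` are computed for the full sequence before slicing. Some ring attention implementations use interleaved layouts instead of contiguous chunks to balance causal-attention work across GPUs.

```python
def shard_sequence(fields: dict[str, list[int]], rank: int, degree: int) -> dict[str, list[int]]:
    """Return the contiguous chunk of every per-token field owned by `rank`.

    Hypothetical helper. Assumes position_ids were computed for the full
    sequence before slicing, so each chunk keeps its original positions.
    """
    if not 0 <= rank < degree:
        raise ValueError("rank must be in [0, degree)")
    shards = {}
    for name, values in fields.items():
        if len(values) % degree != 0:
            raise ValueError(f"{name}: length {len(values)} is not divisible by {degree}")
        size = len(values) // degree
        shards[name] = values[rank * size:(rank + 1) * size]
    return shards


seq_len = 8
fields = {
    "input_ids": list(range(100, 100 + seq_len)),
    "attention_mask": [1] * seq_len,
    "labels": list(range(100, 100 + seq_len)),
    "position_ids": list(range(seq_len)),
}
for rank in range(2):
    print(rank, shard_sequence(fields, rank, degree=2)["position_ids"])
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6, 7]
```

Because the positions are sliced rather than regenerated per chunk, the second GPU's tokens keep positions 4 to 7 instead of restarting at 0.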

## Requirements

To use sequence parallelism, you need:

- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install it with either:
  - `pip install "axolotl[ring-flash-attn]"` (preferred)
  - `pip install "ring-flash-attn>=0.1.4"`

## Limitations

- Flash Attention must be enabled (`flash_attention: true` in your config YAML)
- Communication between the GPUs in each sequence parallel group may add a small performance overhead

## Example

```yaml
# Example config with sequence parallelism
base_model: meta-llama/Meta-Llama-3-8B-Instruct
sequence_len: 8192
sequence_parallel_degree: 2  # Split each sequence into 2 parts
flash_attention: true  # Required with sequence parallelism
...
```

This will train the Llama 3 8B model with 8K context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs.

## Sample Packing with Sequence Parallelism

Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:

1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions
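The order of operations matters because a GPU's chunk can begin in the middle of a packed sample. The sketch below assumes position IDs restart at 0 for each packed sample and that the packed sequence is split into contiguous equal chunks; `pack` and `split_for_rank` are hypothetical helpers, not Axolotl's API.

```python
def pack(samples: list[list[int]]) -> tuple[list[int], list[int]]:
    """Concatenate samples; position IDs restart at 0 at each sample boundary."""
    input_ids: list[int] = []
    position_ids: list[int] = []
    for sample in samples:
        input_ids.extend(sample)
        position_ids.extend(range(len(sample)))
    return input_ids, position_ids


def split_for_rank(values: list[int], rank: int, degree: int) -> list[int]:
    """Contiguous equal split of a packed sequence across `degree` GPUs."""
    if len(values) % degree != 0:
        raise ValueError("packed length must be divisible by degree")
    size = len(values) // degree
    return values[rank * size:(rank + 1) * size]


input_ids, position_ids = pack([[11, 12, 13], [21, 22], [31, 32, 33]])
print(position_ids)  # [0, 1, 2, 0, 1, 0, 1, 2]
for rank in range(2):
    print(rank, split_for_rank(input_ids, rank, 2), split_for_rank(position_ids, rank, 2))
# 0 [11, 12, 13, 21] [0, 1, 2, 0]
# 1 [22, 31, 32, 33] [1, 0, 1, 2]
```

In this sketch, rank 1's chunk starts with token 22, which has position 1 within its sample. Recomputing positions per chunk would wrongly reset it to 0, so the positions are taken from the packed sequence before splitting.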

## Effect on Batch Size

When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:

- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- Fewer distinct batches are processed per step: the number of GPUs divided by `sequence_parallel_degree`, rather than one per GPU

For example:

- With 8 GPUs and no sequence parallelism: 8 different batches are processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: only 2 different batches are processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size therefore decreases from 16 to 4
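The arithmetic above reduces to one formula: global batch size = (number of GPUs / `sequence_parallel_degree`) × `micro_batch_size`. The sketch below encodes it; `global_batch_size` is a hypothetical helper, and it ignores gradient accumulation, which would multiply the result by the number of accumulation steps.

```python
def global_batch_size(num_gpus: int, micro_batch_size: int, sequence_parallel_degree: int = 1) -> int:
    """Samples processed per step across all GPUs (no gradient accumulation)."""
    if num_gpus % sequence_parallel_degree != 0:
        raise ValueError("sequence_parallel_degree must divide num_gpus")
    # Each group of `sequence_parallel_degree` GPUs shares one micro-batch.
    data_parallel_groups = num_gpus // sequence_parallel_degree
    return data_parallel_groups * micro_batch_size


print(global_batch_size(8, 2))     # 16
print(global_batch_size(8, 2, 4))  # 4
```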