# Reducing Memory Usage

Training workflows can often be optimized to **reduce memory consumption**, and TRL provides several built-in features to help achieve this.

Below, we outline these techniques and recommend experimenting with different combinations to figure out which configuration works best for your specific setup.

Each method includes examples for the supported trainers. If you're unsure whether a technique is compatible with your trainer, please take a look at the corresponding trainer documentation.

Some techniques, such as [gradient checkpointing](#gradient-checkpointing), are supported across all trainers. For additional strategies, see the [`transformers` performance guide](https://huggingface.co/docs/transformers/perf_train_gpu_one#gradient-checkpointing).

## Truncation

Sequence lengths in the dataset can vary widely. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short.

![Truncation prompt-completion](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png)

To reduce memory usage, it's important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
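
To make the cost concrete, the sketch below counts how many token positions a few batches occupy when each batch is padded to its longest sequence, with and without truncation. It is plain Python and does not call TRL; the helper `padded_token_count` and the sequence lengths are hypothetical.

```python
def padded_token_count(lengths, batch_size, max_length=None):
    """Return (real_tokens, padded_positions) for consecutive batches.

    Each batch is padded to its longest sequence. If max_length is given,
    sequences are truncated to it first. Hypothetical helper for illustration.
    """
    if max_length is not None:
        lengths = [min(n, max_length) for n in lengths]
    real_tokens = sum(lengths)
    padded_positions = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start : start + batch_size]
        # Every sequence in the batch occupies as many positions as the longest one
        padded_positions += max(batch) * len(batch)
    return real_tokens, padded_positions


# One long outlier forces every row of the batch to 4,000 positions
print(padded_token_count([100, 120, 90, 4000], batch_size=4))  # (4310, 16000)
print(padded_token_count([100, 120, 90, 4000], batch_size=4, max_length=512))  # (822, 2048)
```

In this made-up example, truncating to 512 tokens discards 3,488 tokens of the outlier but shrinks the padded batch from 16,000 to 2,048 positions.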

<hfoptions id="truncation">
<hfoption id="DPO">

DPO truncation is controlled via `max_length`, which truncates the combined prompt+completion sequence.

![DPO truncation](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_prompt_completion.png)

To set the truncation parameter, use the following code snippet:

```python
from trl import DPOConfig

training_args = DPOConfig(..., max_length=...)
```

> [!WARNING]
> The legacy `max_prompt_length` and `max_completion_length` parameters have been removed. Instead, filter out or pre-truncate overlong prompts and completions in your dataset before training.

</hfoption>
<hfoption id="SFT">

SFT truncation is applied to the input sequence via the `max_length` parameter.

![Truncation input ids](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_input_ids.png)

To set the truncation parameter, use the following code snippet:

```python
from trl import SFTConfig

training_args = SFTConfig(..., max_length=...)
```

</hfoption>
</hfoptions>

### How to choose the `max_length` value?

If `max_length` is too small, a significant portion of your tokens will be discarded and won't contribute to training. If it's too large, memory usage can spike, potentially leading to out-of-memory (OOM) errors. Without packing or padding-free batching, a large `max_length` may also make training inefficient, because many tokens will be padding.

To help you choose an appropriate value, we provide a utility to visualize the sequence length distribution in your dataset.

<iframe src="https://trl-lib-dataset-length-profiler.hf.space" frameborder="0" width="100%" height="1000"></iframe>
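
If you prefer to check lengths programmatically, the following minimal sketch picks the smallest observed length that covers a chosen fraction of examples and reports how many tokens truncation at that length would discard. The helper names and the example lengths are hypothetical; in practice, `lengths` would come from tokenizing your dataset, for example `len(tokenizer(text)["input_ids"])` per example with a Transformers tokenizer.

```python
import math


def suggest_max_length(lengths, coverage=0.95):
    """Smallest observed length such that at least `coverage` of sequences fit untruncated."""
    ordered = sorted(lengths)
    index = max(0, math.ceil(coverage * len(ordered)) - 1)
    return ordered[index]


def truncated_token_fraction(lengths, max_length):
    """Fraction of all tokens that truncation to max_length would discard."""
    total = sum(lengths)
    dropped = sum(max(0, n - max_length) for n in lengths)
    return dropped / total if total else 0.0


lengths = [50, 120, 300, 2000]  # hypothetical per-example token counts
candidate = suggest_max_length(lengths, coverage=0.75)
print(candidate)  # 300
print(f"{truncated_token_fraction(lengths, candidate):.1%}")  # 68.8%
```

The example shows why example coverage alone can be misleading: three out of four sequences fit, yet most tokens would be discarded because they sit in a single long outlier.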



## Packing

> [!TIP]
> This technique is available only for **SFT** training and setups that use **FlashAttention** (or its variants).

[Truncation](#truncation) has several drawbacks:

1. **Loss of information**: Important tokens at the end of sequences may be discarded.
2. **Choosing truncation length**: Too short loses data; too long reduces efficiency.

Packing mitigates these issues by grouping multiple sequences into the same training row, filling each row up to `max_length`.

![Packing](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/packing_3.png)

TRL implements packing using **Best-Fit Decreasing (BFD)** bin packing, which groups sequences efficiently while minimizing padding. When a sequence exceeds `max_length`, different strategies determine how the overflow tokens are handled.

TRL supports three strategies:

* `"bfd"` (default): Uses **Best-Fit Decreasing packing**. If a sequence exceeds `max_length`, the overflow tokens are discarded.

* `"bfd_split"`: Uses **Best-Fit Decreasing packing**, but long sequences are split into chunks ≤ `max_length` before packing. This preserves all tokens and follows the approach proposed in [Fewer Truncations Improve Language Modeling](https://huggingface.co/papers/2404.10830).

* `"wrapped"`: All tokens are concatenated into a stream and split into fixed-length blocks. This minimizes padding but may mix unrelated examples. This strategy corresponds to the *concatenate-then-split* preprocessing described in the literature (e.g., [Fewer Truncations Improve Language Modeling](https://huggingface.co/papers/2404.10830)). It has the downside of breaking sequence continuity for a large fraction of the dataset, which hurts performance, as discussed in the [Qwen3-Coder-Next Technical Report](https://huggingface.co/papers/2603.00729).

> [!NOTE]
> If all sequences are shorter than `max_length`, **`bfd` and `bfd_split` behave identically**, since no truncation or splitting is required.



```python
from trl import SFTConfig

training_args = SFTConfig(
    ...,
    packing=True,
    packing_strategy="bfd",  # or "bfd_split" / "wrapped"
    max_length=512,
)
```
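
The following plain-Python sketch illustrates the idea behind the three strategies. It is a simplified illustration, not TRL's implementation: the function names are hypothetical, BFD operates on sequence lengths only, and the naive best-fit search scans every open row, which is quadratic in the worst case.

```python
def pack_bfd(lengths, max_length, split=False):
    """Best-Fit Decreasing packing of sequence lengths into rows of max_length tokens.

    Returns a list of rows; each row is a list of (sequence_index, kept_length).
    split=False mimics "bfd" (overflow tokens dropped); split=True mimics
    "bfd_split" (overlong sequences cut into chunks of at most max_length).
    """
    pieces = []
    for idx, length in enumerate(lengths):
        if length <= max_length:
            pieces.append((idx, length))
        elif split:
            for start in range(0, length, max_length):
                pieces.append((idx, min(max_length, length - start)))
        else:
            pieces.append((idx, max_length))  # truncate: overflow tokens are discarded

    # "Decreasing": place the longest pieces first (sort is stable for equal lengths)
    pieces.sort(key=lambda piece: piece[1], reverse=True)

    rows, free = [], []  # free[i] is the unused capacity of rows[i]
    for piece in pieces:
        # "Best fit": the row with the least free space that still fits the piece
        best = None
        for i, capacity in enumerate(free):
            if piece[1] <= capacity and (best is None or capacity < free[best]):
                best = i
        if best is None:
            rows.append([piece])
            free.append(max_length - piece[1])
        else:
            rows[best].append(piece)
            free[best] -= piece[1]
    return rows


def pack_wrapped(sequences, max_length):
    """Concatenate all token IDs into one stream and cut it into fixed-length blocks."""
    stream = [token for seq in sequences for token in seq]
    return [stream[i : i + max_length] for i in range(0, len(stream), max_length)]


lengths = [5, 3, 8, 2, 7, 12]
print(pack_bfd(lengths, max_length=10))
# [[(5, 10)], [(2, 8), (3, 2)], [(4, 7), (1, 3)], [(0, 5)]]
print(pack_bfd(lengths, max_length=10, split=True))
# [[(5, 10)], [(2, 8), (3, 2)], [(4, 7), (1, 3)], [(0, 5), (5, 2)]]
print(pack_wrapped([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_length=4))
# [[1, 2, 3, 4], [5, 6, 7, 8], [9]]
```

With `split=True`, all 37 tokens are kept, while the default drops the last 2 tokens of the 12-token sequence. The wrapped output shows the sequence `[4, 5]` being cut across two blocks, which is the loss of continuity noted for `"wrapped"`.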


## PEFT for parameter-efficient fine-tuning

Parameter-Efficient Fine-Tuning (PEFT) methods like LoRA are among the most effective techniques for reducing memory usage during training. Instead of training all model parameters, PEFT methods train only a small number of adapter parameters, significantly reducing memory requirements and enabling fine-tuning of larger models on limited hardware.

For comprehensive details on using PEFT with TRL, including various adapter methods, quantization options, and advanced configurations, see [PEFT Integration](peft_integration).

To use PEFT for reducing memory usage:

```python
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# Default LoRA settings: only small low-rank adapter matrices are trained
peft_config = LoraConfig()

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    peft_config=peft_config,
)
trainer.train()
```

PEFT can be combined with other memory reduction techniques such as quantization (4-bit or 8-bit) for even greater memory savings. See [PEFT Integration](peft_integration) for quantization examples.
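
As a rough illustration of why adapters save memory, the sketch below counts the trainable parameters LoRA adds to a single square linear layer and the memory Adam-style optimizers would need for their two moment buffers. The 4096 × 4096 layer, the rank of 8, and fp32 optimizer states are assumptions chosen for the example; actual savings depend on which modules are adapted and on the rest of the setup (weights, gradients, activations).

```python
def lora_trainable_params(d_out, d_in, rank):
    """LoRA adds B (d_out x rank) and A (rank x d_in) next to a frozen d_out x d_in weight."""
    return rank * (d_out + d_in)


def adam_state_bytes(num_trainable, bytes_per_value=4):
    """Adam keeps two moment buffers per trainable parameter (fp32 assumed here)."""
    return 2 * num_trainable * bytes_per_value


full = 4096 * 4096  # every weight is trainable in full fine-tuning
lora = lora_trainable_params(4096, 4096, rank=8)
print(full, lora, f"{lora / full:.2%}")  # 16777216 65536 0.39%
print(adam_state_bytes(full) // 2**20, "MiB vs", adam_state_bytes(lora) // 2**10, "KiB")  # 128 MiB vs 512 KiB
```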

## Liger for reducing peak memory usage

[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. According to its authors, it can increase multi-GPU training throughput by 20% and reduce memory usage by 60%.

For more information, see [Liger Kernel Integration](liger_kernel_integration).
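
One of the main ways Liger reduces peak memory is its fused linear cross-entropy kernel, which computes the loss in chunks instead of materializing the full logits tensor at once. The back-of-the-envelope arithmetic below shows how large that tensor can get. It does not call Liger; the batch size, sequence length, vocabulary size (similar to Qwen2.5 models), bf16 logits, and the 1,024-token chunk are all assumptions for illustration, not Liger's actual settings.

```python
def logits_bytes(batch_size, seq_len, vocab_size, bytes_per_value=2):
    """Size of a (batch, seq_len, vocab) logits tensor; 2 bytes per value assumes bf16."""
    return batch_size * seq_len * vocab_size * bytes_per_value


full = logits_bytes(batch_size=8, seq_len=4096, vocab_size=151_936)
chunk = logits_bytes(batch_size=1, seq_len=1024, vocab_size=151_936)  # hypothetical chunk
print(f"{full / 2**30:.2f} GiB")  # 9.27 GiB
print(f"{chunk / 2**20:.2f} MiB")  # 296.75 MiB
```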

To use Liger for reducing peak memory usage, use the following code snippet:

<hfoptions id="liger">
<hfoption id="SFT">

```python
from trl import SFTConfig

training_args = SFTConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="DPO">

```python
from trl import DPOConfig

training_args = DPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="GRPO">

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="KTO">

```python
from trl.experimental.kto import KTOConfig

training_args = KTOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="GKD">

```python
from trl.experimental.gkd import GKDConfig

training_args = GKDConfig(..., use_liger_kernel=True)
```

</hfoption>
</hfoptions>

## Padding-free

Padding-free batching is an alternative approach for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact.

![Padding-free](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/padding-free.png)

> [!WARNING]
> It's highly recommended to use padding-free batching with **FlashAttention 2** or **FlashAttention 3**. Otherwise, you may encounter batch contamination issues.

<hfoptions id="padding-free">
<hfoption id="DPO">

```python
from trl import DPOConfig

training_args = DPOConfig(
    ...,
    padding_free=True,
    model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"},
)
```

</hfoption>
<hfoption id="SFT">

```python
from trl import SFTConfig

training_args = SFTConfig(
    ...,
    padding_free=True,
    model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"},
)
```

</hfoption>
</hfoptions>
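
The flattening step can be sketched in plain Python on token IDs. This is an illustration under assumptions, not TRL's collator, and `flatten_batch` is a hypothetical helper. Position IDs restart at 0 for each sequence, and the cumulative lengths (`cu_seqlens`, the form FlashAttention's variable-length kernels use) mark where each sequence begins; this boundary information is what lets a compatible attention implementation keep sequences from attending to each other. An implementation that ignores it may treat the row as one long sequence, which is the contamination the warning above refers to.

```python
def flatten_batch(batch):
    """Concatenate a batch of token-ID lists into one padding-free row.

    Returns the flat input IDs, per-sequence position IDs that restart at 0,
    and cumulative sequence lengths marking where each sequence begins.
    """
    input_ids, position_ids, cu_seqlens = [], [], [0]
    for seq in batch:
        input_ids.extend(seq)
        position_ids.extend(range(len(seq)))  # restart positions for each sequence
        cu_seqlens.append(cu_seqlens[-1] + len(seq))
    return input_ids, position_ids, cu_seqlens


batch = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]
input_ids, position_ids, cu_seqlens = flatten_batch(batch)
print(input_ids)     # [11, 12, 13, 21, 22, 31, 32, 33, 34]
print(position_ids)  # [0, 1, 2, 0, 1, 0, 1, 2, 3]
print(cu_seqlens)    # [0, 3, 5, 9]
```

Padding this batch to its longest sequence would need 3 × 4 = 12 positions; the flattened row uses 9.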

## Activation offloading

Activation offloading is a memory efficiency technique that reduces GPU VRAM usage by temporarily moving activation tensors to CPU RAM during the forward pass and bringing them back only when needed for the backward pass. This significantly reduces peak memory usage at the cost of slightly increased training time.

To enable activation offloading in your SFT training configuration:

```python
from trl import SFTConfig

training_args = SFTConfig(..., activation_offloading=True)
```

Under the hood, activation offloading uses PyTorch's [`saved_tensors_hooks`](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#hooks-for-autograd-saved-tensors) to intercept activations during the forward pass. It decides which tensors to offload based on their size and context, and it skips output tensors, for which offloading would be inefficient. To improve performance, it can also use CUDA streams to overlap computation with CPU-GPU transfers; this behavior is controlled by a flag that is enabled by default.

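The core mechanism can be sketched with `torch.autograd.graph.saved_tensors_hooks`. This minimal example assumes PyTorch is installed and is not TRL's implementation: it moves every saved tensor to CPU synchronously, with none of the size-based selection, output-tensor skipping, or CUDA-stream overlap described above. PyTorch also provides `torch.autograd.graph.save_on_cpu`, a ready-made context manager for the same idea.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

offloaded_sizes = []  # number of elements in each tensor that passed through the pack hook


def pack_to_cpu(tensor):
    # Runs in the forward pass for every tensor autograd saves for backward
    offloaded_sizes.append(tensor.numel())
    return tensor.device, tensor.to("cpu")


def unpack_to_device(packed):
    # Runs in the backward pass when the saved tensor is needed again
    device, cpu_tensor = packed
    return cpu_tensor.to(device)


device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(16, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
).to(device)
x = torch.randn(8, 16, device=device)

with saved_tensors_hooks(pack_to_cpu, unpack_to_device):
    loss = model(x).pow(2).mean()  # saved activations go through pack_to_cpu
loss.backward()  # unpack_to_device brings them back as the backward pass needs them
```

The hooks are attached when tensors are saved, so `backward()` can run outside the `with` block and still use `unpack_to_device`.
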
## Padding sequences to a multiple

> [!TIP]
> This technique is currently supported only for the **SFT** and **Reward** trainers.

When enabled, this option ensures that all sequences are **padded to a multiple** of the specified value.  
This can improve computational efficiency on some hardware by aligning sequence lengths to memory-friendly boundaries.

<hfoptions id="pad_to_multiple_of">
<hfoption id="SFT">

```python
from trl import SFTConfig

training_args = SFTConfig(..., pad_to_multiple_of=2048)
```

</hfoption>
<hfoption id="Reward">

```python
from trl import RewardConfig

training_args = RewardConfig(..., pad_to_multiple_of=2048)
```

</hfoption>
</hfoptions>
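
Assuming the batch is first padded to its longest sequence, as with dynamic padding, the padded length is that length rounded up to the next multiple. The plain-Python sketch below (the helper `padded_length` is hypothetical) shows the arithmetic and illustrates the trade-off: a large multiple aligns shapes but can add many padding tokens to short batches.

```python
def padded_length(length, multiple):
    """Round a sequence length up to the next multiple (hypothetical helper)."""
    return -(-length // multiple) * multiple  # ceiling division without floats


for multiple in (8, 2048):
    # Longest sequence in the batch is 300 tokens
    print(multiple, padded_length(300, multiple))
# 8 304
# 2048 2048
```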

## Disabling model gathering for generation in online methods

When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to OOM errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).

If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter:

<hfoptions id="ds3_gather_for_generation">
<hfoption id="GRPO">

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
<hfoption id="Online DPO">

```python
from trl.experimental.online_dpo import OnlineDPOConfig

training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
<hfoption id="PPO">

```python
from trl.experimental.ppo import PPOConfig

training_args = PPOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
<hfoption id="RLOO">

```python
from trl import RLOOConfig

training_args = RLOOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
</hfoptions>

This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds.
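
To see why gathering can fail for large models, the back-of-the-envelope sketch below compares the per-GPU weight shard under ZeRO-3 with the full copy needed once weights are gathered. The 70B-parameter model, 8 GPUs, and 2-byte (bf16/fp16) weights are assumptions for illustration; optimizer states, gradients, activations, and the generation cache are ignored.

```python
def zero3_weight_bytes(num_params, world_size, bytes_per_param=2):
    """Return (per-GPU shard, fully gathered copy) sizes in bytes for model weights.

    Assumes weights are evenly sharded across world_size GPUs.
    """
    gathered = num_params * bytes_per_param
    return gathered // world_size, gathered


shard, gathered = zero3_weight_bytes(num_params=70_000_000_000, world_size=8)
print(f"shard: {shard / 2**30:.1f} GiB, gathered: {gathered / 2**30:.1f} GiB")
# shard: 16.3 GiB, gathered: 130.4 GiB
```

In this example, the gathered weights alone would exceed the memory of a single 80 GB accelerator, even though each shard fits comfortably.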

## vLLM sleep mode

When using **vLLM** as the generation backend for online training methods, you can enable _sleep mode_ to offload vLLM parameters and cache to CPU RAM during the optimization step and reload them to GPU VRAM when needed for weight synchronization and generation.

<hfoptions id="vllm_sleep">
<hfoption id="GRPO">

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., vllm_enable_sleep_mode=True)
```

</hfoption>
<hfoption id="RLOO">

```python
from trl import RLOOConfig

training_args = RLOOConfig(..., vllm_enable_sleep_mode=True)
```

</hfoption>
</hfoptions>

Offloading the vLLM weights and cache helps keep GPU memory usage low, which can be particularly beneficial when training large models or using limited GPU resources. However, waking the vLLM engine from sleep mode introduces some host–device transfer latency, which may slightly impact training speed.

## Gradient checkpointing

Gradient checkpointing trades compute for memory by not storing all intermediate activations during the forward pass, recomputing them during the backward pass instead.

```python
from trl import SFTConfig

training_args = SFTConfig(..., gradient_checkpointing=True)
```

> [!NOTE]
> Gradient checkpointing is enabled by default in all trainers to optimize memory usage. You can disable it by setting `gradient_checkpointing=False` if needed.
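
To illustrate the mechanism, the sketch below applies PyTorch's `torch.utils.checkpoint.checkpoint` to each block of a toy residual network. It assumes PyTorch is installed, and the model is hypothetical. Transformers applies a similar idea to the layers of supported models when `gradient_checkpointing=True`, so you do not need code like this in practice.

```python
import torch
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim), torch.nn.GELU(), torch.nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        return x + self.ff(x)


class ToyModel(torch.nn.Module):
    def __init__(self, dim=32, depth=4, use_checkpointing=False):
        super().__init__()
        self.blocks = torch.nn.ModuleList(Block(dim) for _ in range(depth))
        self.use_checkpointing = use_checkpointing

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Keep only the block input; activations inside the block are
                # discarded after the forward pass and recomputed during backward
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


torch.manual_seed(0)
plain = ToyModel()
checkpointed = ToyModel(use_checkpointing=True)
checkpointed.load_state_dict(plain.state_dict())  # same weights in both models

x = torch.randn(4, 32)
plain(x).sum().backward()
checkpointed(x).sum().backward()  # same gradients; in large models, lower peak activation memory
```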



For more memory optimization techniques, see the [Transformers Performance Guide](https://huggingface.co/docs/transformers/perf_train_gpu_one#gradient-checkpointing).