# BEMA for Reference Model

This feature implements the BEMA (Bias-Corrected Exponential Moving Average) algorithm to update the reference model during DPO training, instead of keeping it frozen.
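For intuition, here is a minimal pure-Python sketch of a bias-corrected EMA update on a flat list of weights. It assumes the update takes the form `ema_t = (1 - beta_t) * ema_{t-1} + beta_t * theta_t` followed by `bema_t = ema_t + alpha_t * (theta_t - theta_0)`, with polynomially decaying weights `beta_t = (lag + multiplier * t) ** -ema_power` and `alpha_t = (lag + multiplier * t) ** -bias_power`. The helper names `bema_coefficients` and `BEMAState`, and the default values shown, are hypothetical and do not describe the signature of `BEMACallback`. See the API reference below for its actual parameters.

```python
# Hypothetical sketch of a BEMA-style update; names and defaults are
# illustrative assumptions, not the BEMACallback API.


def bema_coefficients(t, lag=10.0, multiplier=1.0, ema_power=0.5, bias_power=0.2):
    """Return (beta_t, alpha_t) for averaging step t >= 0."""
    base = lag + multiplier * t
    beta = base ** (-ema_power)    # EMA mixing weight, decays with t
    alpha = base ** (-bias_power)  # bias-correction weight, decays with t
    return beta, alpha


class BEMAState:
    """Tracks an EMA of weights plus the snapshot theta_0 taken when averaging starts."""

    def __init__(self, theta0):
        self.theta0 = list(theta0)  # weights at the start of averaging
        self.ema = list(theta0)     # EMA initialized at theta_0
        self.t = 0                  # number of updates applied so far

    def update(self, theta):
        """Fold the current weights into the EMA and return the bias-corrected average."""
        beta, alpha = bema_coefficients(self.t)
        self.ema = [(1 - beta) * e + beta * p for e, p in zip(self.ema, theta)]
        self.t += 1
        # EMA plus a decaying correction along the direction theta_t - theta_0
        return [e + alpha * (p - p0) for e, p, p0 in zip(self.ema, theta, self.theta0)]


# Usage: feed successive weight snapshots; `smoothed` would play the role of the reference weights.
state = BEMAState([0.0, 0.0])
for weights in ([0.5, 1.0], [0.8, 1.2], [1.0, 1.5]):
    smoothed = state.update(weights)
```

Setting `alpha_t` to zero in this sketch recovers a plain EMA; the extra term is meant to offset the lag an EMA accumulates behind the current weights.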

## Usage

```python
from datasets import load_dataset
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer

# Tiny preference dataset in the standard (prompt/chosen/rejected) format
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

# update_ref_model=True lets the callback update the DPO reference model with BEMA weights
bema_callback = BEMACallback(update_ref_model=True)

trainer = DPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    train_dataset=dataset,
    callbacks=[bema_callback],
)

trainer.train()
```

## DPOTrainer

[[autodoc]] experimental.bema_for_ref_model.DPOTrainer
    - train
    - save_model
    - push_to_hub

## BEMACallback

[[autodoc]] experimental.bema_for_ref_model.BEMACallback