# BEMA for Reference Model
This feature uses the Bias-Corrected Exponential Moving Average (BEMA) algorithm to update the reference model during DPO training.
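To make the idea concrete, here is a minimal sketch of a BEMA-style update on plain Python lists. It assumes the commonly stated formulation, in which an EMA of the weights is combined with a bias-correction term proportional to the distance from the initial weights. The function names (`bema_weights`, `bema_step`), the parameter names, and the default values are illustrative assumptions. This is not the code `BEMACallback` runs, and it does not show how the callback schedules its updates.

```python
def bema_weights(step, lag=10.0, multiplier=1.0, ema_power=0.5, bias_power=0.2):
    """Return (alpha_t, beta_t) for step t >= 1 (illustrative schedule).

    alpha_t scales the bias-correction term, beta_t is the EMA mixing weight.
    Both decay polynomially in the step count.
    """
    base = lag + multiplier * step
    alpha = base ** (-bias_power)
    beta = base ** (-ema_power)
    return alpha, beta


def bema_step(theta, theta0, ema, step, **schedule):
    """Apply one BEMA-style update to flat lists of parameters.

    theta  : current model parameters
    theta0 : parameters at the start of averaging
    ema    : previous EMA of the parameters
    Returns (new_ema, bema), where bema is the bias-corrected average.
    """
    alpha, beta = bema_weights(step, **schedule)
    # Standard EMA update with a step-dependent mixing weight.
    new_ema = [(1.0 - beta) * e + beta * p for e, p in zip(ema, theta)]
    # Bias correction: add back a fraction of the drift from the initial weights.
    bema = [alpha * (p - p0) + e for p, p0, e in zip(theta, theta0, new_ema)]
    return new_ema, bema


# Toy usage: a single parameter that moved from 0.0 to 1.0 by step 6.
new_ema, bema = bema_step(theta=[1.0], theta0=[0.0], ema=[0.0], step=6)
```

In this sketch the correction term shrinks as training progresses, so the corrected average follows the current weights more closely early on than a plain EMA would.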
## Usage
```python
from datasets import load_dataset
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer

# Load a small preference dataset for demonstration.
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

# update_ref_model=True lets the callback update the reference model during training.
bema_callback = BEMACallback(update_ref_model=True)

trainer = DPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    train_dataset=dataset,
    callbacks=[bema_callback],
)
trainer.train()
```
## DPOTrainer
[[autodoc]] experimental.bema_for_ref_model.DPOTrainer
    - train
    - save_model
    - push_to_hub
## BEMACallback
[[autodoc]] experimental.bema_for_ref_model.BEMACallback