# BEMA for Reference Model
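This feature implements BEMA (Bias-Corrected Exponential Moving Average) to update the reference model during DPO training. Rather than keeping the reference model fixed at its initial weights, the callback maintains a bias-corrected moving average of the policy weights and uses it to update the reference model as training progresses.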
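The averaging step itself can be sketched in plain Python. This is a minimal sketch, assuming the bias-corrected EMA formulation: a standard EMA with decay β_t = (ρ + γt)^(-κ), plus a correction term α_t(θ_t - θ_0) with α_t = (ρ + γt)^(-η), where θ_0 is the starting weights. The class `BEMASketch`, the helper `power_schedule`, the parameter names (`lag` for ρ, `multiplier` for γ, `ema_power` for κ, `bias_power` for η), and their defaults are hypothetical illustrations, not the `BEMACallback` API. Weights are flat lists of floats rather than model tensors.

```python
from typing import List, Sequence


def power_schedule(step: int, lag: float, multiplier: float, power: float) -> float:
    """Hypothetical helper: (lag + multiplier * step) ** (-power), a decaying schedule."""
    return (lag + multiplier * step) ** (-power)


class BEMASketch:
    """Hypothetical bias-corrected EMA over a flat list of scalar weights (illustration only)."""

    def __init__(
        self,
        theta0: Sequence[float],
        lag: float = 10.0,
        multiplier: float = 1.0,
        ema_power: float = 0.5,
        bias_power: float = 0.2,
    ) -> None:
        self.theta0: List[float] = list(theta0)  # snapshot of the starting weights
        self.ema: List[float] = list(theta0)  # the EMA starts at theta0
        self.lag = lag
        self.multiplier = multiplier
        self.ema_power = ema_power
        self.bias_power = bias_power

    def update(self, theta: Sequence[float], step: int) -> List[float]:
        """Advance the EMA with the current weights and return the bias-corrected estimate."""
        beta = power_schedule(step, self.lag, self.multiplier, self.ema_power)
        alpha = power_schedule(step, self.lag, self.multiplier, self.bias_power)
        # Standard EMA step: ema <- (1 - beta) * ema + beta * theta
        self.ema = [(1.0 - beta) * e + beta * t for e, t in zip(self.ema, theta)]
        # Bias correction: add back a decaying fraction of the displacement from theta0
        return [alpha * (t - t0) + e for t, t0, e in zip(theta, self.theta0, self.ema)]


# Constant schedule (multiplier=0.0, lag=4.0) so that alpha = beta = 4 ** -0.5 = 0.5
bema = BEMASketch(theta0=[0.0], lag=4.0, multiplier=0.0, ema_power=0.5, bias_power=0.5)
print(bema.update([2.0], step=1))  # [2.0]
print(bema.ema)  # [1.0]
```

In this toy run the plain EMA still sits at 1.0 after the weight jumps to 2.0, while the correction term brings the BEMA estimate back to 2.0, which illustrates how the bias correction offsets the lag of a plain EMA.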
## Usage
```python
from datasets import load_dataset
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer

# Load a preference dataset in the standard format
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

# update_ref_model=True tells the callback to update the reference model with the BEMA weights
bema_callback = BEMACallback(update_ref_model=True)

trainer = DPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    train_dataset=dataset,
    callbacks=[bema_callback],
)
trainer.train()
```
## DPOTrainer
[[autodoc]] experimental.bema_for_ref_model.DPOTrainer
    - train
    - save_model
    - push_to_hub
## BEMACallback
[[autodoc]] experimental.bema_for_ref_model.BEMACallback