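# BEMA for Reference Model

This feature implements the BEMA (Bias-Corrected Exponential Moving Average) algorithm to update the reference model during DPO training. It is enabled by passing a `BEMACallback` created with `update_ref_model=True` to the `DPOTrainer`.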
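
To give a rough idea of what a bias-corrected EMA computes, here is a minimal pure-Python sketch of one averaging step on a flat list of parameters. It assumes the commonly described form of the update (an EMA term plus a decaying correction toward the current parameters); the function name, the `BemaState` container, and the default values are illustrative assumptions, not the actual `BEMACallback` implementation, whose schedule and defaults may differ.

```python
from dataclasses import dataclass


@dataclass
class BemaState:
    """Hypothetical running state for a toy BEMA average."""
    theta0: list  # parameters when averaging started
    ema: list     # running exponential moving average
    step: int = 0 # number of averaging updates performed


def bema_update(state, theta, ema_power=0.5, bias_power=0.2, lag=10.0, multiplier=1.0):
    """Apply one toy BEMA step and return the averaged parameters.

    Assumed schedule (illustrative only):
        beta_t  = (lag + multiplier * t) ** -ema_power   # EMA weight
        alpha_t = (lag + multiplier * t) ** -bias_power  # bias-correction weight
    """
    state.step += 1
    base = lag + multiplier * state.step
    beta = base ** (-ema_power)
    alpha = base ** (-bias_power)

    # Standard EMA of the current parameters.
    state.ema = [(1.0 - beta) * e + beta * p for e, p in zip(state.ema, theta)]

    # Bias correction: add a decaying multiple of the drift from the start point.
    return [alpha * (p - p0) + e for p, p0, e in zip(theta, state.theta0, state.ema)]


# Toy usage: a single scalar "parameter" that moved from 0.0 to 1.0.
state = BemaState(theta0=[0.0], ema=[0.0])
averaged = bema_update(state, [1.0])
```

In this sketch, the averaged weights are what would be copied into the reference model when reference updates are enabled; how and when the real callback performs that copy is not shown here.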
## Usage

```python
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer
from datasets import load_dataset

# Load a small preference dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

# update_ref_model=True enables BEMA updates of the reference model
bema_callback = BEMACallback(update_ref_model=True)

trainer = DPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    train_dataset=dataset,
    callbacks=[bema_callback],
)
trainer.train()
```
## DPOTrainer

[[autodoc]] experimental.bema_for_ref_model.DPOTrainer
    - train
    - save_model
    - push_to_hub
## BEMACallback

[[autodoc]] experimental.bema_for_ref_model.BEMACallback