BEMA for Reference Model

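This feature uses the BEMA (Bias-Corrected Exponential Moving Average) algorithm to update the reference model during DPO training. Pass a `BEMACallback` with `update_ref_model=True` to the trainer to enable it.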
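To give some intuition for what such an update computes, here is a minimal sketch of one BEMA step on a flat list of parameters. It assumes the commonly stated formulation: an EMA of the weights plus a correction term proportional to the distance travelled from the initial weights. The function name `bema_step`, its argument names, and the default hyperparameter values are illustrative assumptions, not the TRL implementation.

```python
def bema_step(theta_t, theta_0, ema_prev, step,
              lag=10.0, multiplier=1.0, ema_power=0.5, bias_power=0.2):
    """One illustrative BEMA update (hypothetical helper, not the TRL API).

    theta_t  : current parameters
    theta_0  : parameters at the start of averaging
    ema_prev : previous EMA of the parameters
    step     : number of updates performed so far
    """
    base = lag + multiplier * step
    beta = base ** (-ema_power)    # EMA weight, decays as training proceeds
    alpha = base ** (-bias_power)  # bias-correction weight, decays more slowly by default

    # Standard EMA of the parameters.
    ema = [(1.0 - beta) * e + beta * p for e, p in zip(ema_prev, theta_t)]
    # Bias correction: add back a fraction of the drift from the initial weights.
    bema = [alpha * (p - p0) + e for p, p0, e in zip(theta_t, theta_0, ema)]
    return ema, bema


# Example: a single scalar parameter that moved from 0.0 to 2.0.
ema, bema = bema_step([2.0], [0.0], [0.0], step=6, ema_power=0.5, bias_power=0.5)
```

In this sketch the `bema` values are what would be copied into the reference model, while `ema` is carried forward to the next step.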

Usage

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer

# Preference dataset used for DPO training
pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

# Policy model and reference model start from the same checkpoint
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")

# Reuse the EOS token as the padding token
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer.pad_token = tokenizer.eos_token

# update_ref_model=True lets the callback update the reference model during training
bema_callback = BEMACallback(update_ref_model=True)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    train_dataset=pref_dataset,
    processing_class=tokenizer,
    callbacks=[bema_callback],
)

trainer.train()
