BEMA for Reference Model
This feature implements the BEMA (Bias-Corrected Exponential Moving Average) algorithm to update the reference model during DPO training.
Usage
```python
from datasets import load_dataset
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

bema_callback = BEMACallback(update_ref_model=True)

trainer = DPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    train_dataset=dataset,
    callbacks=[bema_callback],
)
trainer.train()
```
DPOTrainer
train
Main training entry point.
Parameters:
resume_from_checkpoint (str or bool, optional) : If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
trial (optuna.Trial or dict[str, Any], optional) : The trial run or the hyperparameter dictionary for hyperparameter search.
ignore_keys_for_eval (list[str], optional) : A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.
Returns:
~trainer_utils.TrainOutput
Object containing the global step count, training loss, and metrics.
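For example, to resume an interrupted run from the most recent checkpoint in args.output_dir (assuming the trainer from the usage example above):

```python
# Loads model/optimizer/scheduler states from the last checkpoint and continues training.
trainer.train(resume_from_checkpoint=True)
```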
save_model

Will save the model, so you can reload it using `from_pretrained()`.

Will only save from the main process.
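A minimal save-and-reload sketch, assuming a causal LM and an illustrative output directory:

```python
from transformers import AutoModelForCausalLM

# Save to an explicit directory (path is illustrative).
trainer.save_model("./dpo-bema-model")

# Reload the saved model with from_pretrained().
model = AutoModelForCausalLM.from_pretrained("./dpo-bema-model")
```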
push_to_hub

Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.
Parameters:
commit_message (str, optional, defaults to "End of training") : Message to commit while pushing.
blocking (bool, optional, defaults to True) : Whether the function should return only when the git push has finished.
token (str, optional, defaults to None) : Token with write permission to overwrite Trainer's original args.
revision (str, optional) : The git revision to commit from. Defaults to the head of the "main" branch.
kwargs (dict[str, Any], optional) : Additional keyword arguments passed along to ~Trainer.create_model_card.
Returns:
The URL of the repository where the model was pushed if `blocking=True`, or a `Future` object tracking the
progress of the commit if `blocking=False`.
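For instance (the commit message is illustrative; the target repo comes from `args.hub_model_id`):

```python
# Blocks until the upload finishes and returns the repository URL.
url = trainer.push_to_hub(commit_message="Add DPO model trained with a BEMA reference model")
```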
BEMACallback
A TrainerCallback that implements BEMA (Bias-Corrected Exponential Moving Average) by Adam Block and Cyril Zhang. Code from https://github.com/abblock/bema, under the MIT license.
BEMA computes model weights that scale like:

$$ \theta^{\text{BEMA}}_t = \alpha_t \, (\theta_t - \theta_0) + \text{EMA}_t $$

where \( \theta_t \) is the current model weights, \( \theta_0 \) is a snapshot of the model weights at the first `update_after` step, \( \text{EMA}_t \) is the exponential moving average of the model weights, and \( \alpha_t \) is a scaling factor that decays with the number of steps \( t \) as

$$ \alpha_t = (\rho + t)^{-\kappa} $$

The EMA is computed as:

$$ \text{EMA}_t = (1 - \beta_t) \, \text{EMA}_{t-1} + \beta_t \, \theta_t $$

where \( \beta_t \) is a decay factor that decays with the number of steps as

$$ \beta_t = \eta \, (\rho + t)^{-\gamma} $$

Here \( \rho \) corresponds to `lag`, \( \kappa \) to `bias_power`, \( \gamma \) to `ema_power`, and \( \eta \) to `multiplier`.
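For intuition, here is a minimal sketch of a single BEMA update on flat weight tensors. It follows the formulas above; the function and the exact decay schedules are illustrative assumptions, not the callback's internal API:

```python
import torch


def bema_update(theta_t, theta_0, ema, t, ema_power=0.5, bias_power=0.2, lag=10, multiplier=1.0):
    # Decay factors shrinking polynomially in the step count t (assumed schedules).
    beta_t = multiplier * (lag + t) ** (-ema_power)  # EMA decay factor
    alpha_t = (lag + t) ** (-bias_power)             # BEMA scaling factor
    # Exponential moving average of the weights.
    ema = (1 - beta_t) * ema + beta_t * theta_t
    # Bias-corrected estimate: EMA plus a decaying multiple of the drift from the snapshot.
    bema = ema + alpha_t * (theta_t - theta_0)
    return bema, ema


theta_0 = torch.zeros(4)         # snapshot at the first update_after step
ema = theta_0.clone()
theta_t = torch.full((4,), 0.5)  # current weights after some training
bema, ema = bema_update(theta_t, theta_0, ema, t=400)
```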
Example:

```python
from trl import BEMACallback

trainer = Trainer(..., callbacks=[BEMACallback()])
```
Parameters:
update_freq (int, optional, defaults to 400) : Update the BEMA weights every `update_freq` steps.
ema_power (float, optional, defaults to 0.5) : Power for the EMA decay factor (\( \gamma \) in the formulas above). To disable EMA, set this to 0.0.
bias_power (float, optional, defaults to 0.2) : Power for the BEMA scaling factor (\( \kappa \) in the formulas above). To disable BEMA, set this to 0.0.
lag (int, optional, defaults to 10) : Initial offset in the weight decay schedule that controls early-stage smoothness by acting as a virtual starting age for the updates (\( \rho \) in the formulas above).
update_after (int, optional, defaults to 0) : Burn-in time before starting to update the BEMA weights.
multiplier (float, optional, defaults to 1.0) : Initial value for the EMA decay factor (\( \eta \) in the formulas above).
min_ema_multiplier (float, optional, defaults to 0.0) : Minimum value for the EMA decay factor.
device (str, optional, defaults to "cpu") : Device to use for the BEMA buffers, e.g. "cpu" or "cuda". Note that in most cases this device should be different from the device used for training, in order to avoid OOM.
update_ref_model (bool, optional, defaults to False) : Whether to update the reference model with BEMA weights. This creates a lagged, smoothed version of the main model as the reference model.
ref_model_update_freq (int, optional, defaults to 400) : Update the reference model with BEMA weights every `ref_model_update_freq` steps (see the configuration sketch after this list).
ref_model_update_after (int, optional, defaults to 0) : Number of steps to wait before starting to update the reference model.
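A configuration sketch for the reference-model use case (the values are illustrative, not tuned recommendations):

```python
from trl.experimental.bema_for_ref_model import BEMACallback

# Keep BEMA buffers on the CPU and refresh the DPO reference model
# with BEMA weights every 400 steps after a 100-step burn-in.
bema_callback = BEMACallback(
    update_freq=400,
    update_ref_model=True,
    ref_model_update_freq=400,
    ref_model_update_after=100,
    device="cpu",
)
```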