Callbacks
RichProgressCallback
class trl.RichProgressCallback
A TrainerCallback that displays the progress of training or evaluation using Rich.
LogCompletionsCallback
class trl.LogCompletionsCallback
A TrainerCallback that logs completions to Weights & Biases and/or Comet.
Usage:
from trl import DPOTrainer, LogCompletionsCallback

trainer = DPOTrainer(...)
completions_callback = LogCompletionsCallback(trainer=trainer)
trainer.add_callback(completions_callback)
Parameters:
trainer (Trainer) : Trainer to which the callback will be attached. The trainer's evaluation dataset must include a "prompt" column containing the prompts for generating completions.
generation_config (GenerationConfig, optional) : The generation config to use for generating completions.
num_prompts (int, optional) : The number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset.
freq (int, optional) : The frequency at which to log completions. If not provided, defaults to the trainer's eval_steps.
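The defaults described above can be illustrated with a small standalone sketch. It is not the callback's implementation: the helper names `resolve_freq`, `select_prompts`, and `should_log` are hypothetical, and taking the first `num_prompts` examples and logging on exact multiples of `freq` are assumptions made for illustration.

```python
from typing import Optional, Sequence


def resolve_freq(freq: Optional[int], eval_steps: int) -> int:
    """Fall back to the trainer's eval_steps when freq is not provided."""
    return eval_steps if freq is None else freq


def select_prompts(prompts: Sequence[str], num_prompts: Optional[int] = None) -> list:
    """Use every evaluation prompt by default, otherwise the first num_prompts (assumed)."""
    if num_prompts is None:
        return list(prompts)
    return list(prompts[:num_prompts])


def should_log(global_step: int, freq: int) -> bool:
    """Assumed rule: log completions on steps that are positive multiples of freq."""
    return global_step > 0 and global_step % freq == 0


# Example: no explicit freq, so logging follows eval_steps=100.
freq = resolve_freq(None, eval_steps=100)
logged_steps = [step for step in range(1, 301) if should_log(step, freq)]
```

Setting `num_prompts` to a small value keeps generation cost bounded when the evaluation dataset is large.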
BEMACallback
class trl.BEMACallback
A TrainerCallback that implements BEMA (Bias-Corrected Exponential Moving Average) by Adam Block and Cyril Zhang. Code from https://github.com/abblock/bema under MIT license.
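BEMA computes model weights that scale like:
$$
\theta_t' = \alpha_t \cdot (\theta_t - \theta_0) + \text{EMA}_t
$$
where \( \theta_t \) denotes the current model weights, \( \theta_0 \) is a snapshot of the model weights at the first update_after step, \( \text{EMA}_t \) is the exponential moving average of the model weights, and \( \alpha_t \) is a scaling factor that decays with the number of steps \( t \) as
$$
\alpha_t = (\rho + \gamma \cdot t)^{-\eta}
$$
The EMA is computed as:
$$
\text{EMA}_t = (1 - \beta_t) \cdot \text{EMA}_{t-1} + \beta_t \cdot \theta_t
$$
where \( \beta_t \) is a decay factor that decays with the number of steps \( t \) as
$$
\beta_t = (\rho + \gamma \cdot t)^{-\kappa}
$$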
Example:
from transformers import Trainer
from trl import BEMACallback

trainer = Trainer(..., callbacks=[BEMACallback()])
Parameters:
update_freq (int, optional, defaults to 400) : Update the BEMA weights every X steps. Denoted \( \phi \) in the paper.
ema_power (float, optional, defaults to 0.5) : Power for the EMA decay factor. Denoted \( \kappa \) in the paper. To disable EMA, set this to 0.0.
bias_power (float, optional, defaults to 0.2) : Power for the BEMA scaling factor. Denoted \( \eta \) in the paper. To disable BEMA, set this to 0.0.
lag (int, optional, defaults to 10) : Initial offset in the weight decay schedule that controls early-stage smoothness by acting as a virtual starting age for the updates. Denoted \( \rho \) in the paper.
update_after (int, optional, defaults to 0) : Burn-in time before starting to update the BEMA weights. Denoted \( \tau \) in the paper.
multiplier (float, optional, defaults to 1.0) : Initial value for the EMA decay factor. Denoted \( \gamma \) in the paper.
min_ema_multiplier (float, optional, defaults to 0.0) : Minimum value for the EMA decay factor.
device (str, optional, defaults to "cpu") : Device to use for the BEMA buffers, e.g. "cpu" or "cuda". Note that in most cases, this device SHOULD BE DIFFERENT from the device used for training in order to avoid OOM.
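To make the roles of lag, multiplier, ema_power, bias_power, and min_ema_multiplier concrete, here is a minimal scalar sketch of one BEMA update. It assumes the scaling factor is (lag + multiplier * t) ** (-bias_power), the EMA decay factor is (lag + multiplier * t) ** (-ema_power) clipped from below by min_ema_multiplier, and t counts steps since update_after. The function names `bema_factors` and `bema_step` are hypothetical, and a single float stands in for the model weights; the actual callback operates on parameter tensors.

```python
def bema_factors(t, lag=10, multiplier=1.0, ema_power=0.5,
                 bias_power=0.2, min_ema_multiplier=0.0):
    """Return (alpha_t, beta_t) for step t under the assumed power-law schedules."""
    base = lag + multiplier * t
    alpha = base ** (-bias_power)                        # BEMA scaling factor
    beta = max(base ** (-ema_power), min_ema_multiplier)  # EMA decay factor, floored
    return alpha, beta


def bema_step(theta_t, theta_0, ema_prev, t, **schedule):
    """One update on scalar weights: returns (new EMA, bias-corrected weights)."""
    alpha, beta = bema_factors(t, **schedule)
    ema = (1.0 - beta) * ema_prev + beta * theta_t   # EMA update
    bema = alpha * (theta_t - theta_0) + ema         # bias-corrected weights
    return ema, bema


# Toy trajectory: weights drift from 1.0 toward 2.0 over ten updates.
theta_0 = 1.0
ema = theta_0
history = []
for t in range(10):
    theta_t = 1.0 + 0.1 * (t + 1)
    ema, bema = bema_step(theta_t, theta_0, ema, t)
    history.append((theta_t, ema, bema))
```

Because both factors shrink as t grows, later updates move the EMA less and apply a smaller correction toward the current weights.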
WeaveCallback
class trl.WeaveCallback
A TrainerCallback that logs traces and evaluations to W&B Weave. The callback uses Weave's EvaluationLogger (https://weave-docs.wandb.ai/guides/evaluation/evaluation_logger/) to log traces and evaluations at each evaluation step.
Supports two modes based on the scorers parameter:
- Tracing Mode (when scorers=None): Logs predictions for data exploration and analysis
- Evaluation Mode (when scorers provided): Logs predictions with scoring and summary metrics
Both modes use Weave's EvaluationLogger for structured, consistent data logging.
The callback logs data during evaluation phases (on_evaluate) rather than at training steps, so completions are generated only when the trainer evaluates. If weave is not installed, it logs a warning and skips Weave-specific functionality. It also checks for an existing Weave client before initializing a new one.
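The pattern of warning and skipping when an optional package is missing can be sketched with the standard library alone. This is an illustration of the general technique, not the callback's own code: `is_package_available` and `log_if_available` are hypothetical helpers, and the availability check via `importlib.util.find_spec` is an assumption about how such a guard might be written.

```python
import importlib.util
import logging

logger = logging.getLogger(__name__)


def is_package_available(name: str) -> bool:
    """Return True if a top-level package can be imported, without importing it."""
    return importlib.util.find_spec(name) is not None


def log_if_available(package: str, records: list) -> bool:
    """Skip logging with a warning when the optional package is missing."""
    if not is_package_available(package):
        logger.warning(
            "%s is not installed; skipping logging of %d records.",
            package,
            len(records),
        )
        return False
    # A real callback would hand `records` to the package here.
    return True
```

Calling `log_if_available("weave", records)` then returns False in environments without weave instead of raising an ImportError.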
Usage:
from trl import DPOTrainer, WeaveCallback

# Tracing mode (just log predictions)
trainer = DPOTrainer(...)
weave_callback = WeaveCallback(trainer=trainer)  # project_name optional
trainer.add_callback(weave_callback)

# Or specify a project name
weave_callback = WeaveCallback(trainer=trainer, project_name="my-llm-training")
trainer.add_callback(weave_callback)

# Evaluation mode (log predictions + scores + summary)
def accuracy_scorer(prompt: str, completion: str) -> float:
    # Your scoring logic here (metadata available via eval_attributes)
    score = 1.0 if completion.strip() else 0.0  # placeholder: replace with a real metric
    return score

weave_callback = WeaveCallback(
    trainer=trainer,
    project_name="my-llm-training",  # optional; needed only if no Weave client is initialized
    scorers={"accuracy": accuracy_scorer},
)
trainer.add_callback(weave_callback)
on_train_begin(args, state, control, **kwargs) : Initialize Weave when training begins. Source: https://github.com/huggingface/trl/blob/vr_5607/trl/trainer/callbacks.py#L477
Parameters (WeaveCallback):
trainer (Trainer) : Trainer to which the callback will be attached. The trainer's evaluation dataset must include a "prompt" column containing the prompts for generating completions.
project_name (str, optional) : Name of the Weave project where data will be logged. If not provided, will try to use existing weave client or fall back to the active wandb run's project name. Raises an error if none of these are available.
scorers (dict[str, Callable], optional) : Dictionary mapping scorer names to scorer functions. If None, operates in tracing mode (predictions only). If provided, operates in evaluation mode (predictions + scores + summary). Scorer functions should have signature: scorer(prompt: str, completion: str) -> float | int
generation_config (GenerationConfig, optional) : Generation config to use for generating completions.
num_prompts (int or None, optional) : Number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset.
dataset_name (str, optional, defaults to "eval_dataset") : Name for the dataset metadata in Weave.
model_name (str, optional) : Name for the model metadata in Weave. If not provided, attempts to extract from model config.
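The scorer signature and the two modes can be illustrated without Weave. The sketch below is a standalone approximation, not the callback's logic: `non_empty`, `length_ratio`, and `run_scorers` are hypothetical names, and summarizing each scorer by its mean is an assumption about what a summary might contain.

```python
from statistics import mean
from typing import Callable, Optional


def non_empty(prompt: str, completion: str) -> int:
    """Hypothetical scorer: 1 if the completion has non-whitespace text, else 0."""
    return int(bool(completion.strip()))


def length_ratio(prompt: str, completion: str) -> float:
    """Hypothetical scorer: completion length relative to prompt length."""
    return len(completion) / max(len(prompt), 1)


def run_scorers(pairs, scorers: Optional[dict[str, Callable]] = None):
    """Return (mode, rows, summary); scores are added only when scorers are given."""
    mode = "tracing" if scorers is None else "evaluation"
    rows = []
    for prompt, completion in pairs:
        row = {"prompt": prompt, "completion": completion}
        if scorers is not None:
            row["scores"] = {name: fn(prompt, completion) for name, fn in scorers.items()}
        rows.append(row)
    summary = {}
    if scorers and rows:
        # Assumed summary: the mean of each scorer over all examples.
        summary = {name: mean(r["scores"][name] for r in rows) for name in scorers}
    return mode, rows, summary


pairs = [("2+2=", "4"), ("Hi", "")]
mode, rows, summary = run_scorers(pairs, {"non_empty": non_empty, "length_ratio": length_ratio})
```

Passing `scorers=None` to `run_scorers` yields rows with only prompts and completions, mirroring tracing mode.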