# Callbacks
## SyncRefModelCallback[[trl.SyncRefModelCallback]]
<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">
<docstring><name>class trl.SyncRefModelCallback</name><anchor>trl.SyncRefModelCallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L104</source><parameters>[{"name": "ref_model", "val": ": transformers.modeling_utils.PreTrainedModel | torch.nn.modules.module.Module"}, {"name": "accelerator", "val": ": accelerate.accelerator.Accelerator | None"}]</parameters></docstring>
Callback to synchronize the model with a reference model.
</div>
## RichProgressCallback[[trl.RichProgressCallback]]
<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">
<docstring><name>class trl.RichProgressCallback</name><anchor>trl.RichProgressCallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L145</source><parameters>[]</parameters></docstring>
A `TrainerCallback` that displays the progress of training or evaluation using Rich.
</div>
## WinRateCallback[[trl.WinRateCallback]]
<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">
<docstring><name>class trl.WinRateCallback</name><anchor>trl.WinRateCallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L266</source><parameters>[{"name": "judge", "val": ": BasePairwiseJudge"}, {"name": "trainer", "val": ": Trainer"}, {"name": "generation_config", "val": ": transformers.generation.configuration_utils.GenerationConfig | None = None"}, {"name": "num_prompts", "val": ": int | None = None"}, {"name": "shuffle_order", "val": ": bool = True"}, {"name": "use_soft_judge", "val": ": bool = False"}]</parameters><paramsdesc>- **judge** ([BasePairwiseJudge](/docs/trl/pr_4331/en/judges#trl.BasePairwiseJudge)) --
The judge to use for comparing completions.
- **trainer** (`Trainer`) --
Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
column containing the prompts for generating completions. If the `Trainer` has a reference model (via the
`ref_model` attribute), it will use this reference model for generating the reference completions;
otherwise, it defaults to using the initial model.
- **generation_config** ([GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig), *optional*) --
The generation config to use for generating completions.
- **num_prompts** (`int`, *optional*) --
The number of prompts to generate completions for. If not provided, defaults to the number of examples in
the evaluation dataset.
- **shuffle_order** (`bool`, *optional*, defaults to `True`) --
Whether to shuffle the order of the completions before judging.
- **use_soft_judge** (`bool`, *optional*, defaults to `False`) --
Whether to use a soft judge that returns a win probability between 0 and 1 for the first completion vs the
second.</paramsdesc><paramgroups>0</paramgroups></docstring>
A [TrainerCallback](https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback) that computes the win rate of a model based on a reference.
It generates completions using prompts from the evaluation dataset and compares the trained model's outputs against
a reference. The reference is either the initial version of the model (before training) or the reference model, if
available in the trainer. During each evaluation step, a judge determines how often the trained model's completions
win against the reference. The win rate is then logged in the trainer's logs under the key
`"eval_win_rate"`.
<ExampleCodeBlock anchor="trl.WinRateCallback.example">
Usage:
```python
trainer = DPOTrainer(...)
judge = PairRMJudge()
win_rate_callback = WinRateCallback(judge=judge, trainer=trainer)
trainer.add_callback(win_rate_callback)
```
</ExampleCodeBlock>
</div>
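The win-rate computation described above can be illustrated with a minimal, self-contained sketch. It does not use TRL's code: `toy_pairwise_judge` is a hypothetical stand-in that simply prefers the longer completion, and the pair order `(reference, trained)` is an assumption made for this illustration. The soft variant assumes the judge returns the win probability of the first completion in each pair, as described for `use_soft_judge`.

```python
import random


def toy_pairwise_judge(prompts, completion_pairs, shuffle_order=True, seed=0):
    """Hypothetical judge: returns, per prompt, the index (0 or 1) of the longer completion."""
    rng = random.Random(seed)
    winners = []
    for _prompt, (first, second) in zip(prompts, completion_pairs):
        # Optionally present the pair in swapped order, mimicking `shuffle_order`.
        flipped = shuffle_order and rng.random() < 0.5
        shown = (second, first) if flipped else (first, second)
        shown_winner = 0 if len(shown[0]) > len(shown[1]) else 1
        # Map the winner back to its index in the original (unshuffled) pair.
        winners.append(1 - shown_winner if flipped else shown_winner)
    return winners


def win_rate(winner_indices, trained_index=1):
    """Fraction of prompts for which the trained model's completion was preferred."""
    return sum(idx == trained_index for idx in winner_indices) / len(winner_indices)


def soft_win_rate(first_win_probs):
    """Soft variant: the first completion is assumed to be the reference,
    so the trained model's win rate is one minus the mean probability."""
    return 1.0 - sum(first_win_probs) / len(first_win_probs)


prompts = ["p0", "p1", "p2", "p3"]
reference = ["ok", "fine", "a long reference answer", "hm"]
trained = ["a detailed answer", "a thorough reply", "short", "a careful response"]
pairs = list(zip(reference, trained))  # assumed order: (reference, trained)

winners = toy_pairwise_judge(prompts, pairs)
print(win_rate(winners))  # 0.75
```

Because the winner index is mapped back to the original pair order, the result does not depend on whether a given pair was shuffled before judging.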
## LogCompletionsCallback[[trl.LogCompletionsCallback]]
<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">
<docstring><name>class trl.LogCompletionsCallback</name><anchor>trl.LogCompletionsCallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L458</source><parameters>[{"name": "trainer", "val": ": Trainer"}, {"name": "generation_config", "val": ": transformers.generation.configuration_utils.GenerationConfig | None = None"}, {"name": "num_prompts", "val": ": int | None = None"}, {"name": "freq", "val": ": int | None = None"}]</parameters><paramsdesc>- **trainer** (`Trainer`) --
Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
column containing the prompts for generating completions.
- **generation_config** ([GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig), *optional*) --
The generation config to use for generating completions.
- **num_prompts** (`int`, *optional*) --
The number of prompts to generate completions for. If not provided, defaults to the number of examples in
the evaluation dataset.
- **freq** (`int`, *optional*) --
The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`.</paramsdesc><paramgroups>0</paramgroups></docstring>
A [TrainerCallback](https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback) that logs completions to Weights & Biases and/or Comet.
<ExampleCodeBlock anchor="trl.LogCompletionsCallback.example">
Usage:
```python
trainer = DPOTrainer(...)
completions_callback = LogCompletionsCallback(trainer=trainer)
trainer.add_callback(completions_callback)
```
</ExampleCodeBlock>
</div>
## MergeModelCallback[[trl.MergeModelCallback]]
<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">
<docstring><name>class trl.MergeModelCallback</name><anchor>trl.MergeModelCallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L779</source><parameters>[{"name": "merge_config", "val": ": MergeConfig | None = None"}, {"name": "merge_at_every_checkpoint", "val": ": bool = False"}, {"name": "push_to_hub", "val": ": bool = False"}]</parameters><paramsdesc>- **merge_config** (`MergeConfig`, *optional*) --
Configuration used for the merging process. If not provided, the default `MergeConfig` is used.
- **merge_at_every_checkpoint** (`bool`, *optional*, defaults to `False`) --
Whether to merge the model at every checkpoint.
- **push_to_hub** (`bool`, *optional*, defaults to `False`) --
Whether to push the merged model to the Hub after merging.</paramsdesc><paramgroups>0</paramgroups></docstring>
A [TrainerCallback](https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback) that merges the policy model (the model being trained) with another model based
on a merge configuration.
<ExampleCodeBlock anchor="trl.MergeModelCallback.example">
Example:
```python
from trl.mergekit_utils import MergeConfig
from trl import MergeModelCallback
config = MergeConfig()
merge_callback = MergeModelCallback(config)
trainer = DPOTrainer(..., callbacks=[merge_callback])
```
</ExampleCodeBlock>
</div>
## BEMACallback[[trl.BEMACallback]]
<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">
<docstring><name>class trl.BEMACallback</name><anchor>trl.BEMACallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L840</source><parameters>[{"name": "update_freq", "val": ": int = 400"}, {"name": "ema_power", "val": ": float = 0.5"}, {"name": "bias_power", "val": ": float = 0.2"}, {"name": "lag", "val": ": int = 10"}, {"name": "update_after", "val": ": int = 0"}, {"name": "multiplier", "val": ": float = 1.0"}, {"name": "min_ema_multiplier", "val": ": float = 0.0"}, {"name": "device", "val": ": str = 'cpu'"}]</parameters><paramsdesc>- **update_freq** (`int`, *optional*, defaults to `400`) --
Update the BEMA weights every `update_freq` steps. Denoted \\( \phi \\) in the paper.
- **ema_power** (`float`, *optional*, defaults to `0.5`) --
Power for the EMA decay factor. Denoted \\( \kappa \\) in the paper. To disable EMA, set this to `0.0`.
- **bias_power** (`float`, *optional*, defaults to `0.2`) --
Power for the BEMA scaling factor. Denoted \\( \eta \\) in the paper. To disable BEMA, set this to `0.0`.
- **lag** (`int`, *optional*, defaults to `10`) --
Initial offset in the weight decay schedule that controls early-stage smoothness by acting as a virtual
starting age for the updates. Denoted as \\( \rho \\) in the paper.
- **update_after** (`int`, *optional*, defaults to `0`) --
Burn-in time before starting to update the BEMA weights. Denoted \\( \tau \\) in the paper.
- **multiplier** (`float`, *optional*, defaults to `1.0`) --
Initial value for the EMA decay factor. Denoted as \\( \gamma \\) in the paper.
- **min_ema_multiplier** (`float`, *optional*, defaults to `0.0`) --
Minimum value for the EMA decay factor.
- **device** (`str`, *optional*, defaults to `"cpu"`) --
Device to use for the BEMA buffers, e.g. `"cpu"` or `"cuda"`. Note that in most cases, this device SHOULD
BE DIFFERENT from the device used for training in order to avoid OOM.</paramsdesc><paramgroups>0</paramgroups></docstring>
A [TrainerCallback](https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback) that implements [BEMA](https://huggingface.co/papers/2508.00180)
(Bias-Corrected Exponential Moving Average) by [Adam Block](https://huggingface.co/abblock) and [Cyril
Zhang](https://huggingface.co/cyrilzhang). Code from https://github.com/abblock/bema under MIT license.
BEMA computes model weights that scale like:
$$
\theta_t' = \alpha_t \cdot (\theta_t - \theta_0) + \text{EMA}_t
$$
where \\( \theta_t \\) are the current model weights, \\( \theta_0 \\) is a snapshot of the model weights at the
first `update_after` step, \\( \text{EMA}_t \\) is the exponential moving average of the model weights, and
\\( \alpha_t \\) is a scaling factor that decays with the number of steps \\( t \\) as
$$
\alpha_t = (\rho + \gamma \cdot t)^{-\eta}.
$$
The EMA is computed as:
$$
\text{EMA}_t = (1 - \beta_t) \cdot \text{EMA}_{t-1} + \beta_t \cdot \theta_t
$$
where \\( \beta_t \\) is a decay factor that decays with the number of steps \\( t \\) as
$$
\beta_t = (\rho + \gamma \cdot t)^{-\kappa}.
$$
<ExampleCodeBlock anchor="trl.BEMACallback.example">
Example:
```python
from trl import BEMACallback
trainer = Trainer(..., callbacks=[BEMACallback()])
```
</ExampleCodeBlock>
</div>
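To make the formulas above concrete, the following sketch applies them to a sequence of scalar "weights". It is an illustration of the equations only, not the callback's implementation: `bema_trajectory` is a hypothetical helper, the EMA is assumed to start at the snapshot \\( \theta_0 \\), and `update_freq`, `update_after`, `min_ema_multiplier`, and `device` are left out.

```python
def bema_trajectory(thetas, lag=10, multiplier=1.0, ema_power=0.5, bias_power=0.2):
    """Apply the BEMA formulas to scalar weights theta_0, theta_1, ...

    Returns the bias-corrected weights theta'_t for t = 1, ..., len(thetas) - 1.
    """
    theta_0 = thetas[0]
    ema = theta_0  # assumption: EMA_0 is initialised with the snapshot theta_0
    corrected = []
    for t, theta_t in enumerate(thetas[1:], start=1):
        base = lag + multiplier * t          # rho + gamma * t
        alpha_t = base ** (-bias_power)      # alpha_t = (rho + gamma * t)^(-eta)
        beta_t = base ** (-ema_power)        # beta_t  = (rho + gamma * t)^(-kappa)
        ema = (1 - beta_t) * ema + beta_t * theta_t
        corrected.append(alpha_t * (theta_t - theta_0) + ema)
    return corrected


# A toy trajectory drifting away from its starting point.
trajectory = [0.0, 1.0, 1.5, 1.75]
print(bema_trajectory(trajectory))
```

Since both \\( \alpha_t \\) and \\( \beta_t \\) shrink as \\( t \\) grows, later steps contribute smaller corrections and smaller EMA updates. If the weights never move away from \\( \theta_0 \\), the corrected weights stay at \\( \theta_0 \\) as well.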
## WeaveCallback[[trl.WeaveCallback]]
<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">
<docstring><name>class trl.WeaveCallback</name><anchor>trl.WeaveCallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L550</source><parameters>[{"name": "trainer", "val": ": Trainer"}, {"name": "project_name", "val": ": str | None = None"}, {"name": "scorers", "val": ": dict[str, callable] | None = None"}, {"name": "generation_config", "val": ": transformers.generation.configuration_utils.GenerationConfig | None = None"}, {"name": "num_prompts", "val": ": int | None = None"}, {"name": "dataset_name", "val": ": str = 'eval_dataset'"}, {"name": "model_name", "val": ": str | None = None"}]</parameters><paramsdesc>- **trainer** (`Trainer`) --
Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
column containing the prompts for generating completions.
- **project_name** (`str`, *optional*) --
Name of the Weave project where data will be logged. If not provided, the callback tries to use an existing
Weave client or falls back to the active wandb run's project name. Raises an error if none of these are available.
- **scorers** (`dict[str, Callable]`, *optional*) --
Dictionary mapping scorer names to scorer functions. If `None`, operates in tracing mode (predictions
only). If provided, operates in evaluation mode (predictions + scores + summary). Scorer functions should
have signature: `scorer(prompt: str, completion: str) -> float | int`
- **generation_config** ([GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig), *optional*) --
Generation config to use for generating completions.
- **num_prompts** (`int` or `None`, *optional*) --
Number of prompts to generate completions for. If not provided, defaults to the number of examples in the
evaluation dataset.
- **dataset_name** (`str`, *optional*, defaults to `"eval_dataset"`) --
Name for the dataset metadata in Weave.
- **model_name** (`str`, *optional*) --
Name for the model metadata in Weave. If not provided, attempts to extract from model config.</paramsdesc><paramgroups>0</paramgroups></docstring>
A [TrainerCallback](https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback) that logs traces and evaluations to W&B Weave. The callback uses
Weave's [EvaluationLogger](https://weave-docs.wandb.ai/guides/evaluation/evaluation_logger/) to log traces and
evaluations at each evaluation step.
Supports two modes based on the `scorers` parameter:
- **Tracing Mode** (when `scorers=None`): Logs predictions for data exploration and analysis.
- **Evaluation Mode** (when `scorers` is provided): Logs predictions with scoring and summary metrics.

Both modes use Weave's EvaluationLogger for structured, consistent data logging.
The callback logs data during evaluation phases (`on_evaluate`) rather than training steps, making it more
efficient and semantically correct. It gracefully handles missing weave installation by logging warnings and
skipping weave-specific functionality. It also checks for existing weave clients before initializing new ones.
<ExampleCodeBlock anchor="trl.WeaveCallback.example">
Usage:
```python
from trl import WeaveCallback

# Tracing mode (just log predictions)
trainer = DPOTrainer(...)
weave_callback = WeaveCallback(trainer=trainer)  # project_name is optional
trainer.add_callback(weave_callback)

# Or specify a project name
weave_callback = WeaveCallback(trainer=trainer, project_name="my-llm-training")
trainer.add_callback(weave_callback)


# Evaluation mode (log predictions + scores + summary)
def accuracy_scorer(prompt: str, completion: str) -> float:
    # Your scoring logic here (metadata available via eval_attributes)
    return score


weave_callback = WeaveCallback(
    trainer=trainer,
    project_name="my-llm-training",  # optional; needed only if the weave client is not initialized
    scorers={"accuracy": accuracy_scorer},
)
trainer.add_callback(weave_callback)
```
</ExampleCodeBlock>
<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">
<docstring><name>on_train_begin</name><anchor>trl.WeaveCallback.on_train_begin</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L681</source><parameters>[{"name": "args", "val": ""}, {"name": "state", "val": ""}, {"name": "control", "val": ""}, {"name": "**kwargs", "val": ""}]</parameters></docstring>
Initialize Weave when training begins.
</div></div>
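The `scorers` argument expects plain functions with the signature `scorer(prompt: str, completion: str) -> float | int`. The sketch below shows two such functions and a small loop that applies them to prompt/completion pairs and averages the results. The scorer names and the `score_all` helper are hypothetical and only illustrate the idea of per-example scores plus a summary; they are not Weave's or TRL's aggregation logic.

```python
from statistics import mean
from typing import Callable


def non_empty_scorer(prompt: str, completion: str) -> int:
    """1 if the completion contains any non-whitespace text, else 0."""
    return int(bool(completion.strip()))


def length_ratio_scorer(prompt: str, completion: str) -> float:
    """Completion length relative to prompt length (characters)."""
    return len(completion) / max(len(prompt), 1)


# Same shape as the `scorers` argument: scorer name -> scorer function.
scorers: dict[str, Callable[[str, str], float]] = {
    "non_empty": non_empty_scorer,
    "length_ratio": length_ratio_scorer,
}


def score_all(pairs, scorers):
    """Hypothetical helper: score every (prompt, completion) pair and average per scorer."""
    per_example = [
        {name: fn(prompt, completion) for name, fn in scorers.items()}
        for prompt, completion in pairs
    ]
    summary = {name: mean(row[name] for row in per_example) for name in scorers}
    return per_example, summary


pairs = [("Hi", "Hello there"), ("Q?", "   ")]
per_example, summary = score_all(pairs, scorers)
print(summary)
```

Keeping scorers as pure functions of `(prompt, completion)` makes them easy to test in isolation before passing them to the callback.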