| # Callbacks | |
| ## SyncRefModelCallback[[trl.SyncRefModelCallback]] | |
| <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> | |
| <docstring><name>class trl.SyncRefModelCallback</name><anchor>trl.SyncRefModelCallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L104</source><parameters>[{"name": "ref_model", "val": ": transformers.modeling_utils.PreTrainedModel | torch.nn.modules.module.Module"}, {"name": "accelerator", "val": ": accelerate.accelerator.Accelerator | None"}]</parameters></docstring> | |
| Callback to synchronize the model with a reference model. | |
| </div> | |
| ## RichProgressCallback[[trl.RichProgressCallback]] | |
| <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> | |
| <docstring><name>class trl.RichProgressCallback</name><anchor>trl.RichProgressCallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L145</source><parameters>[]</parameters></docstring> | |
| A `TrainerCallback` that displays the progress of training or evaluation using Rich. | |
| </div> | |
| ## WinRateCallback[[trl.WinRateCallback]] | |
| <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> | |
| <docstring><name>class trl.WinRateCallback</name><anchor>trl.WinRateCallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L266</source><parameters>[{"name": "judge", "val": ": BasePairwiseJudge"}, {"name": "trainer", "val": ": Trainer"}, {"name": "generation_config", "val": ": transformers.generation.configuration_utils.GenerationConfig | None = None"}, {"name": "num_prompts", "val": ": int | None = None"}, {"name": "shuffle_order", "val": ": bool = True"}, {"name": "use_soft_judge", "val": ": bool = False"}]</parameters><paramsdesc>- **judge** ([BasePairwiseJudge](/docs/trl/pr_4331/en/judges#trl.BasePairwiseJudge)) -- | |
| The judge to use for comparing completions. | |
| - **trainer** (`Trainer`) -- | |
| Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"` | |
| column containing the prompts for generating completions. If the `Trainer` has a reference model (via the | |
| `ref_model` attribute), it will use this reference model for generating the reference completions; | |
| otherwise, it defaults to using the initial model. | |
| - **generation_config** ([GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig), *optional*) -- | |
| The generation config to use for generating completions. | |
| - **num_prompts** (`int`, *optional*) -- | |
| The number of prompts to generate completions for. If not provided, defaults to the number of examples in | |
| the evaluation dataset. | |
| - **shuffle_order** (`bool`, *optional*, defaults to `True`) -- | |
| Whether to shuffle the order of the completions before judging. | |
| - **use_soft_judge** (`bool`, *optional*, defaults to `False`) -- | |
| Whether to use a soft judge that returns a win probability between 0 and 1 for the first completion vs the | |
| second.</paramsdesc><paramgroups>0</paramgroups></docstring> | |
| A [TrainerCallback](https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback) that computes the win rate of a model based on a reference. | |
| It generates completions using prompts from the evaluation dataset and compares the trained model's outputs against | |
| a reference. The reference is either the initial version of the model (before training) or the reference model, if | |
| available in the trainer. During each evaluation step, a judge determines how often the trained model's completions | |
| win against the reference. The win rate is then logged in the trainer's logs under the key | |
| `"eval_win_rate"`. | |
| <ExampleCodeBlock anchor="trl.WinRateCallback.example"> | |
| Usage: | |
| ```python | |
| trainer = DPOTrainer(...) | |
| judge = PairRMJudge() | |
| win_rate_callback = WinRateCallback(judge=judge, trainer=trainer) | |
| trainer.add_callback(win_rate_callback) | |
| ``` | |
| </ExampleCodeBlock> | |
| </div> | |
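The win-rate computation described above can be sketched in plain Python. This is a minimal illustration, not the callback's actual code: the helper names `win_rate` and `soft_win_rate` are hypothetical, and the sketch assumes that each pair is ordered as (reference completion, trained-model completion), so that index `1` means the trained model won and a soft judge's score is the probability that the reference (first) completion wins.

```python
def win_rate(winner_indices: list[int], policy_index: int = 1) -> float:
    """Fraction of pairs in which the trained model's completion was preferred.

    Assumes the judge returned, for each pair, the index (0 or 1) of the
    winning completion, with the trained model at `policy_index`.
    """
    if not winner_indices:
        raise ValueError("winner_indices must not be empty")
    wins = sum(1 for idx in winner_indices if idx == policy_index)
    return wins / len(winner_indices)


def soft_win_rate(first_win_probs: list[float]) -> float:
    """Average probability that the second (trained-model) completion wins.

    Assumes each score is the probability that the first (reference)
    completion wins, so the trained model's probability is 1 - score.
    """
    if not first_win_probs:
        raise ValueError("first_win_probs must not be empty")
    return sum(1.0 - p for p in first_win_probs) / len(first_win_probs)


# Hypothetical judge outputs for four prompts.
print(win_rate([1, 0, 1, 1]))
print(soft_win_rate([0.2, 0.4]))
```

If the judge shuffles the pair order (`shuffle_order=True`), the indices are assumed to be mapped back to the original order before a helper like this is applied.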
| ## LogCompletionsCallback[[trl.LogCompletionsCallback]] | |
| <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> | |
| <docstring><name>class trl.LogCompletionsCallback</name><anchor>trl.LogCompletionsCallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L458</source><parameters>[{"name": "trainer", "val": ": Trainer"}, {"name": "generation_config", "val": ": transformers.generation.configuration_utils.GenerationConfig | None = None"}, {"name": "num_prompts", "val": ": int | None = None"}, {"name": "freq", "val": ": int | None = None"}]</parameters><paramsdesc>- **trainer** (`Trainer`) -- | |
| Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"` | |
| column containing the prompts for generating completions. | |
| - **generation_config** ([GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig), *optional*) -- | |
| The generation config to use for generating completions. | |
| - **num_prompts** (`int`, *optional*) -- | |
| The number of prompts to generate completions for. If not provided, defaults to the number of examples in | |
| the evaluation dataset. | |
| - **freq** (`int`, *optional*) -- | |
| The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`.</paramsdesc><paramgroups>0</paramgroups></docstring> | |
| A [TrainerCallback](https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback) that logs completions to Weights & Biases and/or Comet. | |
| <ExampleCodeBlock anchor="trl.LogCompletionsCallback.example"> | |
| Usage: | |
| ```python | |
| trainer = DPOTrainer(...) | |
| completions_callback = LogCompletionsCallback(trainer=trainer) | |
| trainer.add_callback(completions_callback) | |
| ``` | |
| </ExampleCodeBlock> | |
| </div> | |
| ## MergeModelCallback[[trl.MergeModelCallback]] | |
| <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> | |
| <docstring><name>class trl.MergeModelCallback</name><anchor>trl.MergeModelCallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L779</source><parameters>[{"name": "merge_config", "val": ": MergeConfig | None = None"}, {"name": "merge_at_every_checkpoint", "val": ": bool = False"}, {"name": "push_to_hub", "val": ": bool = False"}]</parameters><paramsdesc>- **merge_config** (`MergeConfig`, *optional*) -- | |
| Configuration used for the merging process. If not provided, the default `MergeConfig` is used. | |
| - **merge_at_every_checkpoint** (`bool`, *optional*, defaults to `False`) -- | |
| Whether to merge the model at every checkpoint. | |
| - **push_to_hub** (`bool`, *optional*, defaults to `False`) -- | |
| Whether to push the merged model to the Hub after merging.</paramsdesc><paramgroups>0</paramgroups></docstring> | |
| A [TrainerCallback](https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback) that merges the policy model (the model being trained) with another model based | |
| on a merge configuration. | |
| <ExampleCodeBlock anchor="trl.MergeModelCallback.example"> | |
| Example: | |
| ```python | |
| from trl.mergekit_utils import MergeConfig | |
| from trl import MergeModelCallback | |
| config = MergeConfig() | |
| merge_callback = MergeModelCallback(config) | |
| trainer = DPOTrainer(..., callbacks=[merge_callback]) | |
| ``` | |
| </ExampleCodeBlock> | |
| </div> | |
| ## BEMACallback[[trl.BEMACallback]] | |
| <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> | |
| <docstring><name>class trl.BEMACallback</name><anchor>trl.BEMACallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L840</source><parameters>[{"name": "update_freq", "val": ": int = 400"}, {"name": "ema_power", "val": ": float = 0.5"}, {"name": "bias_power", "val": ": float = 0.2"}, {"name": "lag", "val": ": int = 10"}, {"name": "update_after", "val": ": int = 0"}, {"name": "multiplier", "val": ": float = 1.0"}, {"name": "min_ema_multiplier", "val": ": float = 0.0"}, {"name": "device", "val": ": str = 'cpu'"}]</parameters><paramsdesc>- **update_freq** (`int`, *optional*, defaults to `400`) -- | |
| Update the BEMA weights every `update_freq` steps. Denoted \\( \phi \\) in the paper. | |
| - **ema_power** (`float`, *optional*, defaults to `0.5`) -- | |
| Power for the EMA decay factor. Denoted \\( \kappa \\) in the paper. To disable EMA, set this to `0.0`. | |
| - **bias_power** (`float`, *optional*, defaults to `0.2`) -- | |
| Power for the BEMA scaling factor. Denoted \\( \eta \\) in the paper. To disable BEMA, set this to `0.0`. | |
| - **lag** (`int`, *optional*, defaults to `10`) -- | |
| Initial offset in the weight decay schedule that controls early-stage smoothness by acting as a virtual | |
| starting age for the updates. Denoted as \\( \rho \\) in the paper. | |
| - **update_after** (`int`, *optional*, defaults to `0`) -- | |
| Burn-in time before starting to update the BEMA weights. Denoted \\( \tau \\) in the paper. | |
| - **multiplier** (`float`, *optional*, defaults to `1.0`) -- | |
| Initial value for the EMA decay factor. Denoted as \\( \gamma \\) in the paper. | |
| - **min_ema_multiplier** (`float`, *optional*, defaults to `0.0`) -- | |
| Minimum value for the EMA decay factor. | |
| - **device** (`str`, *optional*, defaults to `"cpu"`) -- | |
| Device to use for the BEMA buffers, e.g. `"cpu"` or `"cuda"`. In most cases, this device should differ from | |
| the device used for training to avoid out-of-memory (OOM) errors.</paramsdesc><paramgroups>0</paramgroups></docstring> | |
| A [TrainerCallback](https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback) that implements [BEMA](https://huggingface.co/papers/2508.00180) | |
| (Bias-Corrected Exponential Moving Average) by [Adam Block](https://huggingface.co/abblock) and [Cyril | |
| Zhang](https://huggingface.co/cyrilzhang). Code from https://github.com/abblock/bema under MIT license. | |
| BEMA computes model weights that scale like: | |
| $$ | |
| \theta_t' = \alpha_t \cdot (\theta_t - \theta_0) + \text{EMA}_t | |
| $$ | |
| where \\( \theta_t \\) is the current model weights, \\( \theta_0 \\) is a snapshot of the model weights at the | |
| first `update_after` step, \\( \text{EMA}_t \\) is the exponential moving average of the model weights, and | |
| \\( \alpha_t \\) is a scaling factor that decays with the number of steps \\( t \\) as | |
| $$ | |
| \alpha_t = (\rho + \gamma \cdot t)^{-\eta}. | |
| $$ | |
| The EMA is computed as: | |
| $$ | |
| \text{EMA}_t = (1 - \beta_t) \cdot \text{EMA}_{t-1} + \beta_t \cdot \theta_t | |
| $$ | |
| where \\( \beta_t \\) is a decay factor that decays with the number of steps \\( t \\) as | |
| $$ | |
| \beta_t = (\rho + \gamma \cdot t)^{-\kappa}. | |
| $$ | |
| <ExampleCodeBlock anchor="trl.BEMACallback.example"> | |
| Example: | |
| ```python | |
| from trl import BEMACallback | |
| trainer = Trainer(..., callbacks=[BEMACallback()]) | |
| ``` | |
| </ExampleCodeBlock> | |
| </div> | |
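To make the BEMA formulas above concrete, here is a minimal scalar sketch of a single update. It is not the callback's implementation: `bema_step` is a hypothetical helper, the weights are plain floats rather than tensors, the EMA is assumed to start at the snapshot value, and `update_freq`, `update_after`, and `min_ema_multiplier` are ignored. The default arguments mirror the documented defaults (`lag` for \\( \rho \\), `multiplier` for \\( \gamma \\), `bias_power` for \\( \eta \\), `ema_power` for \\( \kappa \\)).

```python
def bema_step(theta_t, theta_0, ema_prev, t, rho=10, gamma=1.0, eta=0.2, kappa=0.5):
    """Apply one BEMA update to a scalar weight, following the formulas above."""
    base = rho + gamma * t
    alpha_t = base ** (-eta)    # bias-correction scaling factor
    beta_t = base ** (-kappa)   # EMA decay factor
    ema_t = (1.0 - beta_t) * ema_prev + beta_t * theta_t
    theta_bema = alpha_t * (theta_t - theta_0) + ema_t
    return theta_bema, ema_t


# Toy trajectory of a single weight; the starting EMA value is an assumption.
theta_0 = 1.0
ema = theta_0
for t, theta_t in enumerate([1.5, 1.8, 2.0], start=1):
    theta_bema, ema = bema_step(theta_t, theta_0, ema, t)
    print(f"t={t} theta={theta_t:.2f} ema={ema:.4f} bema={theta_bema:.4f}")
```

Because both \\( \alpha_t \\) and \\( \beta_t \\) shrink as \\( t \\) grows, later steps move the EMA less and apply a smaller correction toward the current weights.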
| ## WeaveCallback[[trl.WeaveCallback]] | |
| <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> | |
| <docstring><name>class trl.WeaveCallback</name><anchor>trl.WeaveCallback</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L550</source><parameters>[{"name": "trainer", "val": ": Trainer"}, {"name": "project_name", "val": ": str | None = None"}, {"name": "scorers", "val": ": dict[str, callable] | None = None"}, {"name": "generation_config", "val": ": transformers.generation.configuration_utils.GenerationConfig | None = None"}, {"name": "num_prompts", "val": ": int | None = None"}, {"name": "dataset_name", "val": ": str = 'eval_dataset'"}, {"name": "model_name", "val": ": str | None = None"}]</parameters><paramsdesc>- **trainer** (`Trainer`) -- | |
| Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"` | |
| column containing the prompts for generating completions. | |
| - **project_name** (`str`, *optional*) -- | |
| Name of the Weave project where data will be logged. If not provided, the callback tries to use an existing | |
| Weave client, then falls back to the active wandb run's project name. Raises an error if neither is available. | |
| - **scorers** (`dict[str, Callable]`, *optional*) -- | |
| Dictionary mapping scorer names to scorer functions. If `None`, operates in tracing mode (predictions | |
| only). If provided, operates in evaluation mode (predictions + scores + summary). Scorer functions should | |
| have signature: `scorer(prompt: str, completion: str) -> float | int` | |
| - **generation_config** ([GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig), *optional*) -- | |
| Generation config to use for generating completions. | |
| - **num_prompts** (`int` or `None`, *optional*) -- | |
| Number of prompts to generate completions for. If not provided, defaults to the number of examples in the | |
| evaluation dataset. | |
| - **dataset_name** (`str`, *optional*, defaults to `"eval_dataset"`) -- | |
| Name for the dataset metadata in Weave. | |
| - **model_name** (`str`, *optional*) -- | |
| Name for the model metadata in Weave. If not provided, attempts to extract from model config.</paramsdesc><paramgroups>0</paramgroups></docstring> | |
| A [TrainerCallback](https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback) that logs traces and evaluations to W&B Weave. The callback uses | |
| Weave's [EvaluationLogger](https://weave-docs.wandb.ai/guides/evaluation/evaluation_logger/) to log traces and evaluations at each evaluation | |
| step. | |
| Supports two modes based on the `scorers` parameter: | |
| - **Tracing Mode** (when `scorers=None`): Logs predictions for data exploration and analysis. | |
| - **Evaluation Mode** (when `scorers` is provided): Logs predictions with scoring and summary metrics. | |
| Both modes use Weave's EvaluationLogger for structured, consistent data logging. | |
| The callback logs data during evaluation phases (`on_evaluate`) rather than training steps, making it more | |
| efficient and semantically correct. It gracefully handles missing weave installation by logging warnings and | |
| skipping weave-specific functionality. It also checks for existing weave clients before initializing new ones. | |
| <ExampleCodeBlock anchor="trl.WeaveCallback.example"> | |
| Usage: | |
| ```python | |
| from trl import DPOTrainer, WeaveCallback | |
| # Tracing mode (just log predictions) | |
| trainer = DPOTrainer(...) | |
| weave_callback = WeaveCallback(trainer=trainer)  # project_name optional | |
| trainer.add_callback(weave_callback) | |
| # Or specify a project name | |
| weave_callback = WeaveCallback(trainer=trainer, project_name="my-llm-training") | |
| trainer.add_callback(weave_callback) | |
| # Evaluation mode (log predictions + scores + summary) | |
| def accuracy_scorer(prompt: str, completion: str) -> float: | |
|     # Placeholder: replace with your scoring logic (metadata available via eval_attributes) | |
|     return float(bool(completion.strip())) | |
| weave_callback = WeaveCallback( | |
|     trainer=trainer, | |
|     project_name="my-llm-training",  # optional; needed only if no weave client is initialized | |
|     scorers={"accuracy": accuracy_scorer}, | |
| ) | |
| trainer.add_callback(weave_callback) | |
| ``` | |
| </ExampleCodeBlock> | |
| <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> | |
| <docstring><name>on_train_begin</name><anchor>trl.WeaveCallback.on_train_begin</anchor><source>https://github.com/huggingface/trl/blob/vr_4331/trl/trainer/callbacks.py#L681</source><parameters>[{"name": "args", "val": ""}, {"name": "state", "val": ""}, {"name": "control", "val": ""}, {"name": "**kwargs", "val": ""}]</parameters></docstring> | |
| Initialize Weave when training begins. | |
| </div></div> | |
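The `scorers` argument expects a dictionary of plain functions with the signature `scorer(prompt: str, completion: str) -> float | int`. The sketch below shows what such a dictionary might look like; the scorer names, the `MAX_WORDS` budget, and the scoring rules are hypothetical and chosen only for illustration.

```python
MAX_WORDS = 50  # hypothetical length budget used by the brevity scorer


def non_empty(prompt: str, completion: str) -> int:
    """Return 1 if the completion contains any non-whitespace text, else 0."""
    return int(bool(completion.strip()))


def brevity(prompt: str, completion: str) -> float:
    """Return 1.0 within the word budget, decaying as the completion gets longer."""
    n_words = len(completion.split())
    if n_words <= MAX_WORDS:
        return 1.0
    return MAX_WORDS / n_words


# Dictionary in the shape expected by the `scorers` parameter.
scorers = {"non_empty": non_empty, "brevity": brevity}

# Applying every scorer to one (prompt, completion) pair.
prompt = "What is the capital of France?"
completion = "Paris is the capital of France."
scores = {name: fn(prompt, completion) for name, fn in scorers.items()}
print(scores)
```

Keeping each scorer a pure function of the prompt and completion makes it easy to test in isolation before passing the dictionary to the callback.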