trl-mcsd / trl /trainer /callbacks.py
ihbkaiser's picture
Implement MCSD for experimental SDPO
1fa3c6c verified
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections.abc import Callable

import pandas as pd
import torch
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import gather_object, is_wandb_available
from transformers import (
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer_utils import has_length
from transformers.utils import is_rich_available

from ..data_utils import maybe_apply_chat_template
from ..import_utils import is_weave_available
from ..models.utils import unwrap_model_for_generation
from .utils import log_table_to_comet_experiment


if is_rich_available():
    from rich.columns import Columns
    from rich.console import Console, Group
    from rich.live import Live
    from rich.panel import Panel
    from rich.progress import Progress
    from rich.table import Table

if is_wandb_available():
    import wandb

if is_weave_available():
    import weave
    from weave import EvaluationLogger
    from weave.trace.context import weave_client_context

# Module-level logger; the callbacks in this module emit their warnings/info messages under this module's name.
logger = logging.getLogger(__name__)
def _generate_completions(
    prompts: list[str],
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    accelerator: Accelerator,
    generation_config: GenerationConfig | None,
    batch_size: int = 1,
) -> list[str]:
    """
    Generates completions for a list of pre-formatted prompts from the given model.

    Args:
        prompts (list[str]): A list of input prompts for which completions are to be generated.
        model (PreTrainedModel): The model to be used for generation. It may be wrapped by the accelerator (e.g. DDP
            or DeepSpeed); it is unwrapped via `unwrap_model_for_generation` before generating.
        tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for encoding and decoding. The prompt is
            stripped by slicing off the padded prompt length, so callers should configure left padding.
        accelerator (Accelerator): The accelerator to be used for model execution.
        generation_config (GenerationConfig | None): Configuration for text generation, passed as-is to
            `generate()`.
        batch_size (int, optional): The number of prompts to process in each batch. Default is 1.

    Returns:
        list[str]: A list of generated text completions, in the same order as `prompts`.
    """
    completions = []
    # TODO: Override model.generation_config with generation_kwargs
    # Note: the unwrap context is always entered, even for an empty `prompts` list, because under ZeRO-3 entering it
    # may involve a collective parameter gather that every rank has to join.
    with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        # Put inputs on the device of the model that actually runs `generate`. Wrappers such as
        # DistributedDataParallel do not expose a `.device` attribute, so it is read from the unwrapped model.
        device = unwrapped_model.device
        for idx in range(0, len(prompts), batch_size):
            batch = prompts[idx : idx + batch_size]
            tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)
            generations = unwrapped_model.generate(
                **tokenized_batch,
                generation_config=generation_config,
            )
            # strict=True: one generation per prompt is expected (raises if e.g. num_return_sequences > 1).
            for prompt, generation in zip(tokenized_batch.input_ids, generations, strict=True):
                # The generated sequence starts with the (padded) prompt tokens; keep only the new tokens.
                completion = tokenizer.decode(generation[len(prompt) :], skip_special_tokens=True)
                completions.append(completion)
    return completions
class SyncRefModelCallback(TrainerCallback):
    """
    Callback that periodically blends the training model's weights into a reference model.

    Every `args.ref_model_sync_steps` global steps, each reference parameter is updated in place as
    `ref = (1 - alpha) * ref + alpha * model`, where `alpha = args.ref_model_mixup_alpha`.
    """

    def __init__(
        self,
        ref_model: PreTrainedModel | torch.nn.Module,
        accelerator: Accelerator | None,
    ):
        self.accelerator = accelerator
        self.ref_model = ref_model

    @staticmethod
    def _sync_target_model(model, target_model, alpha):
        # Soft update, parameter by parameter: target <- (1 - alpha) * target + alpha * source.
        keep_ratio = 1.0 - alpha
        param_pairs = zip(target_model.parameters(), model.parameters(), strict=True)
        for dst, src in param_pairs:
            dst.data.mul_(keep_ratio).add_(src.data, alpha=alpha)

    @staticmethod
    def sync_target_model(model, target_model, alpha):
        plugin = AcceleratorState().deepspeed_plugin
        is_zero3 = plugin is not None and plugin.zero_stage == 3
        if not is_zero3:
            SyncRefModelCallback._sync_target_model(model, target_model, alpha)
            return

        import deepspeed

        # ZeRO-3 partitions parameters across ranks: gather them, let rank 0 perform the update, and rely on
        # `modifier_rank=0` to propagate rank 0's values when the context exits.
        all_params = list(model.parameters()) + list(target_model.parameters())
        with deepspeed.zero.GatheredParameters(all_params, modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                SyncRefModelCallback._sync_target_model(model, target_model, alpha)

    def on_step_end(self, args, state, control, **kwargs):
        model: PreTrainedModel = kwargs["model"]
        # Nothing to do without a reference model, or when this step is not a sync step.
        if self.ref_model is None or state.global_step % args.ref_model_sync_steps != 0:
            return
        if self.accelerator:
            model = self.accelerator.unwrap_model(model)
        self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha)
class RichProgressCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that displays the progress of training or evaluation using Rich.

    The Rich display is created in `on_train_begin` and torn down in `on_train_end`, and only the world process zero
    renders anything. Evaluation/prediction progress reported outside of a `train()` call (e.g. a standalone
    `trainer.evaluate()`) is ignored, because no display exists at that point.
    """

    def __init__(self):
        if not is_rich_available():
            raise ImportError("RichProgressCallback requires the `rich` extra. To install, run `pip install rich`.")
        self._reset_state()

    def _reset_state(self):
        """Drop every Rich object and counter so the callback starts (or restarts) from a clean state."""
        self.training_bar = None
        self.evaluation_bar = None
        self.training_task = None
        self.evaluation_task = None
        self.rich_group = None
        self.rich_console = None
        self.training_status = None
        self.current_step = None

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        self.training_bar = Progress()
        self.evaluation_bar = Progress()
        self.rich_console = Console()
        # Status line rendered below the bars; `on_log` replaces its content with a metrics panel.
        self.training_status = self.rich_console.status("Nothing to log yet ...")
        self.rich_group = Live(Panel(Group(self.training_bar, self.evaluation_bar, self.training_status)))
        self.rich_group.start()
        self.training_task = self.training_bar.add_task("[blue]Training ", total=state.max_steps)
        self.current_step = 0

    def on_step_end(self, args, state, control, **kwargs):
        if not state.is_world_process_zero or self.training_bar is None:
            return
        # Advance by however many steps passed since the last update, keeping the bar in sync with `global_step`.
        self.training_bar.update(self.training_task, advance=state.global_step - self.current_step, update=True)
        self.current_step = state.global_step

    def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
        # `evaluation_bar` only exists between `on_train_begin` and `on_train_end`; without this guard a standalone
        # evaluation/prediction run would fail on `None.add_task`.
        if not state.is_world_process_zero or self.evaluation_bar is None:
            return
        if not has_length(eval_dataloader):
            return
        if self.evaluation_task is None:
            self.evaluation_task = self.evaluation_bar.add_task("[blue]Evaluation", total=len(eval_dataloader))
        self.evaluation_bar.update(self.evaluation_task, advance=1, update=True)

    def _remove_evaluation_task(self, state):
        """Remove the finished evaluation task (main process only) so the next evaluation starts a fresh bar."""
        if not state.is_world_process_zero:
            return
        if self.evaluation_task is not None:
            self.evaluation_bar.remove_task(self.evaluation_task)
            self.evaluation_task = None

    def on_evaluate(self, args, state, control, **kwargs):
        self._remove_evaluation_task(state)

    def on_predict(self, args, state, control, **kwargs):
        self._remove_evaluation_task(state)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero or self.training_bar is None or logs is None:
            return

        # Group keys by top-level prefix: "eval/loss" -> group "eval", metric "loss". Keys without a "/" go into
        # the `None` group and keep their full name.
        grouped_logs = {}
        for key, value in logs.items():
            prefix, sep, rest = key.partition("/")
            group = prefix if sep else None
            subkey = rest if sep else key
            grouped_logs.setdefault(group, {})[subkey] = value

        # Create a table per group
        tables = []
        for group_name, metrics in grouped_logs.items():
            table = Table(
                title=f"[bold blue]{group_name}[/]" if group_name else None, header_style="bold magenta", box=None
            )
            table.add_column("Metric", justify="left", no_wrap=True)
            table.add_column("Value", justify="right")

            for metric, val in metrics.items():
                # `bool` is a subclass of `int`; show it as True/False rather than 1.000/0.000.
                is_number = isinstance(val, (float, int)) and not isinstance(val, bool)
                formatted = f"{val:.3f}" if is_number else str(val)
                table.add_row(metric, formatted)

            tables.append(Panel(table, border_style="cyan", padding=(0, 1)))

        # Arrange tables in columns using Columns
        column_layout = Columns(tables, equal=False, expand=True)
        self.training_status.update(
            Panel(column_layout, title=f"[bold green]Step {state.global_step}[/bold green]", border_style="green")
        )

    def on_train_end(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        if self.rich_group is not None:
            self.rich_group.stop()
        self._reset_state()
class LogCompletionsCallback(TrainerCallback):
    r"""
    A [`~transformers.TrainerCallback`] that logs completions to Weights & Biases and/or Comet.

    Usage:
    ```python
    trainer = DPOTrainer(...)
    completions_callback = LogCompletionsCallback(trainer=trainer)
    trainer.add_callback(completions_callback)
    ```

    Args:
        trainer (`Trainer`):
            Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
            column containing the prompts for generating completions.
        generation_config ([`~transformers.GenerationConfig`], *optional*):
            The generation config to use for generating completions.
        num_prompts (`int`, *optional*):
            The number of prompts to generate completions for. If not provided, defaults to the number of examples in
            the evaluation dataset.
        freq (`int`, *optional*):
            The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`. A
            `ValueError` is raised at the first step end if neither is set.
    """

    def __init__(
        self,
        trainer: Trainer,
        generation_config: GenerationConfig | None = None,
        num_prompts: int | None = None,
        freq: int | None = None,
    ):
        self.trainer = trainer
        self.generation_config = generation_config
        self.freq = freq
        self.table = []  # accumulated (step, prompt, completion) rows, re-logged in full every time
        self._last_logged_step = -1

        if self.trainer.eval_dataset is None:
            raise ValueError("Trainer must have an evaluation dataset to use the LogCompletionsCallback.")
        else:
            self.eval_dataset = self.trainer.eval_dataset

        if num_prompts is not None:
            self.eval_dataset = self.eval_dataset.select(range(num_prompts))

    def on_step_end(self, args, state, control, **kwargs):
        # Only log once per step (this method may be called multiple times)
        if state.global_step == self._last_logged_step:
            return

        # Only log every `freq` steps (if no `freq` is provided, log every `eval_steps` steps)
        freq = self.freq or state.eval_steps
        if not freq:
            raise ValueError(
                "LogCompletionsCallback needs a logging frequency: pass `freq` explicitly or use a step-based "
                "evaluation strategy so that `eval_steps` is set."
            )
        if state.global_step % freq != 0:
            return

        tokenizer = kwargs["processing_class"]
        accelerator = self.trainer.accelerator
        model = self.trainer.model_wrapped

        # Generation needs left padding, but the tokenizer is the trainer's `processing_class` and may be used
        # elsewhere during training (e.g. by a data collator), so the original padding side is restored afterwards.
        original_padding_side = getattr(tokenizer, "padding_side", None)
        tokenizer.padding_side = "left"
        try:
            with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts:
                prompts = [maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] for prompt in prompts]
                completions = _generate_completions(
                    prompts,
                    model=model,
                    tokenizer=tokenizer,
                    accelerator=accelerator,
                    generation_config=self.generation_config,
                    batch_size=args.per_device_eval_batch_size,
                )
                # Collect every process's shard so the main process can log the full table.
                completions = gather_object(completions)
                prompts = gather_object(prompts)
        finally:
            if original_padding_side is not None:
                tokenizer.padding_side = original_padding_side

        # Build the data to log
        if accelerator.is_main_process:
            global_step = [str(state.global_step)] * len(prompts)
            data = list(zip(global_step, prompts, completions, strict=True))
            self.table.extend(data)
            table = pd.DataFrame(columns=["step", "prompt", "completion"], data=self.table)

            if "wandb" in args.report_to:
                wandb.log({"completions": table})

            if "comet_ml" in args.report_to:
                log_table_to_comet_experiment(
                    name="completions.csv",
                    table=table,
                )

        # Save the last logged step, so we don't log the same completions multiple times
        self._last_logged_step = state.global_step
class WeaveCallback(TrainerCallback):
    r"""
    A [`~transformers.TrainerCallback`] that logs traces and evaluations to W&B Weave. The callback uses
    https://weave-docs.wandb.ai/guides/evaluation/evaluation_logger/ to log traces and evaluations at each evaluation
    step.

    Supports two modes based on the `scorers` parameter:
    - **Tracing Mode** (when scorers=None): Logs predictions for data exploration and analysis
    - **Evaluation Mode** (when scorers provided): Logs predictions with scoring and summary metrics

    Both modes use Weave's EvaluationLogger for structured, consistent data logging.

    The callback logs data during evaluation phases (`on_evaluate`) rather than training steps, making it more
    efficient and semantically correct. It gracefully handles missing weave installation by logging warnings and
    skipping weave-specific functionality. It also checks for existing weave clients before initializing new ones.
    Completions are generated on every process, but only the main process connects to Weave and logs them.

    Usage:
    ```python
    # Tracing mode (just log predictions)
    trainer = DPOTrainer(...)
    weave_callback = WeaveCallback(trainer=trainer)  # project_name optional
    trainer.add_callback(weave_callback)

    # Or specify a project name
    weave_callback = WeaveCallback(trainer=trainer, project_name="my-llm-training")
    trainer.add_callback(weave_callback)


    # Evaluation mode (log predictions + scores + summary)
    def accuracy_scorer(prompt: str, completion: str) -> float:
        # Your scoring logic here (metadata available via eval_attributes)
        return score


    weave_callback = WeaveCallback(
        trainer=trainer,
        project_name="my-llm-training",  # optional and needed only if weave client is not initialized
        scorers={"accuracy": accuracy_scorer},
    )
    trainer.add_callback(weave_callback)
    ```

    Args:
        trainer (`Trainer`):
            Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
            column containing the prompts for generating completions.
        project_name (`str`, *optional*):
            Name of the Weave project where data will be logged. If not provided, will try to use existing weave client
            or fall back to the active wandb run's project name. Raises an error if none of these are available.
        scorers (`dict[str, Callable]`, *optional*):
            Dictionary mapping scorer names to scorer functions. If `None`, operates in tracing mode (predictions
            only). If provided, operates in evaluation mode (predictions + scores + summary). Scorer functions should
            have signature: `scorer(prompt: str, completion: str) -> float | int`
        generation_config ([`~transformers.GenerationConfig`], *optional*):
            Generation config to use for generating completions.
        num_prompts (`int` or `None`, *optional*):
            Number of prompts to generate completions for. If not provided, defaults to the number of examples in the
            evaluation dataset.
        dataset_name (`str`, *optional*, defaults to `"eval_dataset"`):
            Name for the dataset metadata in Weave.
        model_name (`str`, *optional*):
            Name for the model metadata in Weave. If not provided, attempts to extract from model config.
    """

    def __init__(
        self,
        trainer: Trainer,
        project_name: str | None = None,
        scorers: dict[str, Callable[[str, str], float | int]] | None = None,
        generation_config: GenerationConfig | None = None,
        num_prompts: int | None = None,
        dataset_name: str = "eval_dataset",
        model_name: str | None = None,
    ):
        self.trainer = trainer
        self.project_name = project_name
        self.scorers = scorers or {}
        self.generation_config = generation_config
        self.dataset_name = dataset_name
        self.model_name = model_name
        self._last_logged_step = -1
        self._weave_initialized = False
        self._eval_logger = None

        if self.trainer.eval_dataset is None:
            raise ValueError("Trainer must have an evaluation dataset to use the WeaveCallback.")
        else:
            self.eval_dataset = self.trainer.eval_dataset

        if num_prompts is not None:
            self.eval_dataset = self.eval_dataset.select(range(num_prompts))

    def _initialize_weave(self):
        """
        Initialize Weave and EvaluationLogger if not already initialized.

        Only the main process connects to Weave: it is the only process that logs, and the other ranks typically have
        no active wandb run to infer the project name from. The other ranks are still marked as initialized so they
        keep taking part in the distributed generation/gather performed in `on_evaluate`.

        Raises:
            ValueError: On the main process, if no Weave client exists, no `project_name` was given and there is no
                active wandb run to derive one from.
        """
        if self._weave_initialized:
            return
        if not is_weave_available():
            logger.warning("Weave is not available. Please install weave to enable logging: `pip install weave`")
            return

        if self.trainer.accelerator.is_main_process:
            if wc := weave_client_context.get_weave_client():
                self._weave_client = wc
            else:
                if self.project_name is None and is_wandb_available() and wandb.run is not None:
                    self.project_name = wandb.run.entity + "/" + wandb.run.project
                    logger.info(f"Using project name from active wandb run: {self.project_name}")

                if self.project_name is None:
                    raise ValueError(
                        "No existing Weave client found and no project_name provided. "
                        "Please either initialize weave with `weave.init('project-name')`, "
                        "provide a project_name to the `WeaveCallback`, "
                        "or ensure an active wandb run exists."
                    )

                self._weave_client = weave.init(self.project_name)
                logger.info(f"Initialized Weave with project: {self.project_name}")

        if self.model_name is None:
            # Read the config from the unwrapped model: wrappers such as DistributedDataParallel do not expose
            # `.config`, and a DeepSpeed engine's `.config` is not the model config.
            model_config = getattr(self.trainer.model, "config", None)
            self.model_name = getattr(model_config, "_name_or_path", "unknown_model")

        self._EvaluationLogger = EvaluationLogger
        self._weave_initialized = True

    @property
    def is_evaluation_mode(self) -> bool:
        """True if scorers are provided (evaluation mode), False for tracing mode."""
        return bool(self.scorers)

    def on_train_begin(self, args, state, control, **kwargs):
        """Initialize Weave when training begins."""
        self._initialize_weave()

    def on_evaluate(self, args, state, control, **kwargs):
        # Log at most once per global step.
        if state.global_step == self._last_logged_step:
            return

        self._initialize_weave()

        # Availability is the same on every rank, so all ranks skip together and no collective is left hanging.
        if not self._weave_initialized:
            logger.debug("Weave not initialized, skipping logging")
            return

        tokenizer = kwargs["processing_class"]
        accelerator = self.trainer.accelerator
        model = self.trainer.model_wrapped

        # Generation needs left padding, but the tokenizer is the trainer's `processing_class` and may be used
        # elsewhere during training (e.g. by a data collator), so the original padding side is restored afterwards.
        original_padding_side = getattr(tokenizer, "padding_side", None)
        tokenizer.padding_side = "left"
        try:
            with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts:
                prompts = [maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] for prompt in prompts]

                completions = _generate_completions(
                    prompts=prompts,
                    model=model,
                    tokenizer=tokenizer,
                    accelerator=accelerator,
                    generation_config=self.generation_config,
                    batch_size=args.per_device_eval_batch_size,
                )

                all_prompts = gather_object(prompts)
                all_completions = gather_object(completions)

                if accelerator.is_main_process:
                    self._log_predictions(state.global_step, all_prompts, all_completions)
        finally:
            if original_padding_side is not None:
                tokenizer.padding_side = original_padding_side

        self._last_logged_step = state.global_step

    def _log_predictions(self, global_step: int, all_prompts: list[str], all_completions: list[str]) -> None:
        """
        Log one Weave evaluation (main process only) containing every prompt/completion pair.

        Failures while logging individual predictions, scores, the summary, or finishing the logger are reported as
        warnings instead of being raised, so a Weave hiccup does not abort training.
        """
        eval_attributes = {
            "training_step": global_step,
            "model_name": self.model_name,
            "generation_config": (self.generation_config.to_dict() if self.generation_config else None),
        }

        eval_logger = self._EvaluationLogger(
            model=self.model_name,
            dataset=self.dataset_name,
            eval_attributes=eval_attributes,
        )

        successful_predictions = 0
        total_score_values = {}  # scorer name -> list of scores, for summary statistics

        for prompt, completion in zip(all_prompts, all_completions, strict=True):
            try:
                pred_logger = eval_logger.log_prediction(inputs={"prompt": prompt}, output=completion)

                if self.is_evaluation_mode:
                    for scorer_name, scorer_func in self.scorers.items():
                        try:
                            score = scorer_func(prompt, completion)
                            pred_logger.log_score(scorer=scorer_name, score=score)
                            total_score_values.setdefault(scorer_name, []).append(score)
                        except Exception as scorer_e:
                            logger.warning(f"Failed to apply scorer '{scorer_name}': {scorer_e}")

                pred_logger.finish()
                successful_predictions += 1

            except Exception as pred_e:
                logger.warning(f"Failed to log prediction for prompt: {pred_e}")
                # Continue with other predictions even if one fails

        if self.is_evaluation_mode and total_score_values:
            try:
                summary_stats = {
                    "total_predictions": len(all_prompts),
                    "successful_predictions": successful_predictions,
                }
                for scorer_name, scores in total_score_values.items():
                    if scores:  # Only if we have valid scores
                        summary_stats[f"avg_{scorer_name}"] = sum(scores) / len(scores)

                eval_logger.log_summary(summary_stats)
            except Exception as summary_e:
                logger.warning(f"Failed to log summary: {summary_e}")
        else:
            try:
                eval_logger.finish()
            except Exception as finish_e:
                logger.warning(f"Failed to finish evaluation logger: {finish_e}")
class BEMACallback(TrainerCallback):
    # docstyle-ignore
    r"""
    A [`~transformers.TrainerCallback`] that implements [BEMA](https://huggingface.co/papers/2508.00180)
    (Bias-Corrected Exponential Moving Average) by [Adam Block](https://huggingface.co/abblock) and [Cyril
    Zhang](https://huggingface.co/cyrilzhang). Code from https://github.com/abblock/bema under MIT license.

    BEMA computes model weights that scale like:

    $$
    \theta_t' = \alpha_t \cdot (\theta_t - \theta_0) + \text{EMA}_t
    $$

    where  \\( \theta_t \\) is the current model weights, \\( \theta_0 \\) is a snapshot of the model weights taken at
    step `update_after`, \\( \text{EMA}_t \\) is the exponential moving average of the model weights, and
    \\( \alpha_t \\) is a scaling factor that decays with the number of steps \\( t \\) as

    $$
    \alpha_t = (\rho + \gamma \cdot t)^{-\eta}.
    $$

    The EMA is computed as:

    $$
    \text{EMA}_t = (1 - \beta_t) \cdot \text{EMA}_{t-1} + \beta_t \cdot \theta_t
    $$

    where \\( \beta_t \\) is a decay factor that decays with the number of steps \\( t \\) as

    $$
    \beta_t = (\rho + \gamma \cdot t)^{-\kappa}.
    $$

    Only trainable parameters (`requires_grad=True`) are averaged; frozen parameters keep the values they had at the
    start of training. The averaged model is written to `{output_dir}/bema` at the end of training.

    Args:
        update_freq (`int`, *optional*, defaults to `400`):
            Update the BEMA weights every X steps. Denoted as \\( \phi \\) in the paper.
        ema_power (`float`, *optional*, defaults to `0.5`):
            Power for the EMA decay factor. Denoted \\( \kappa \\) in the paper. To disable EMA, set this to `0.0`.
        bias_power (`float`, *optional*, defaults to `0.2`):
            Power for the BEMA scaling factor. Denoted \\( \eta \\) in the paper. To disable BEMA, set this to `0.0`.
        lag (`int`, *optional*, defaults to `10`):
            Initial offset in the weight decay schedule that controls early-stage smoothness by acting as a virtual
            starting age for the updates. Denoted as \\( \rho \\) in the paper.
        update_after (`int`, *optional*, defaults to `0`):
            Burn-in time before starting to update the BEMA weights. Denoted \\( \tau \\) in the paper.
        multiplier (`float`, *optional*, defaults to `1.0`):
            Multiplier applied to the step count in both decay schedules, whose base is \\( \rho + \gamma \cdot t \\).
            Denoted as \\( \gamma \\) in the paper.
        min_ema_multiplier (`float`, *optional*, defaults to `0.0`):
            Minimum value for the EMA decay factor.
        device (`str`, *optional*, defaults to `"cpu"`):
            Device to use for the BEMA buffers, e.g. `"cpu"` or `"cuda"`. Note that in most cases, this device SHOULD
            BE DIFFERENT from the device used for training in order to avoid OOM.

    Example:

    ```python
    from trl import BEMACallback

    trainer = Trainer(..., callbacks=[BEMACallback()])
    ```
    """

    def __init__(
        self,
        update_freq: int = 400,
        ema_power: float = 0.5,
        bias_power: float = 0.2,
        lag: int = 10,
        update_after: int = 0,
        multiplier: float = 1.0,
        min_ema_multiplier: float = 0.0,
        device: str = "cpu",
    ):
        # User-provided hyperparams
        self.update_freq = update_freq
        self.ema_power = ema_power
        self.bias_power = bias_power
        self.lag = lag
        self.update_after = update_after
        self.multiplier = multiplier
        self.min_ema_multiplier = min_ema_multiplier
        self.device = device

        # Internal state. All lists below are index-aligned: entry i describes the same trainable parameter.
        self.param_names = []  # names of the tracked (trainable) parameters
        self.thetat_params = []  # references to training model params (θₜ)
        self.theta0_params = []  # θ₀ buffers (on self.device)
        self.ema_params = []  # EMA buffers (on self.device)
        self.running_params = []  # matching parameters of `running_model`, looked up by name
        self.running_model = None  # a copy of the model to run BEMA on

    @staticmethod
    def _unwrap_model(model):
        """
        Helper function to unwrap model from various wrappers including DataParallel, DistributedDataParallel,
        DeepSpeed, and FSDP.
        """
        # Handle DeepSpeed
        if hasattr(model, "module") and hasattr(model, "engine"):
            # DeepSpeed engine
            return model.module

        # Handle FSDP
        if hasattr(model, "_fsdp_wrapped_module"):
            # FSDP wrapped model
            return model._fsdp_wrapped_module

        # Handle DataParallel/DistributedDataParallel
        if hasattr(model, "module"):
            return model.module

        return model

    @torch.no_grad()
    def on_train_begin(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs
    ):
        model = self._unwrap_model(model)

        # Create a new instance and load state_dict
        self.running_model = type(model)(model.config).to(self.device)
        self.running_model.load_state_dict(model.state_dict())
        running_named_params = dict(self.running_model.named_parameters())

        # Start from empty buffers so that calling `train()` again does not accumulate stale entries.
        self.param_names = []
        self.thetat_params = []
        self.theta0_params = []
        self.ema_params = []
        self.running_params = []

        # Cache trainable parameters once in a fixed order. Frozen parameters are skipped, so the running model's
        # parameters are matched by name rather than by position.
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            self.param_names.append(name)
            self.thetat_params.append(param)
            self.running_params.append(running_named_params[name])

            # Clone θ₀ and EMA onto `self.device` (not necessarily the training device)
            theta0 = param.detach().clone().to(self.device)
            self.theta0_params.append(theta0)
            self.ema_params.append(theta0.clone())  # initialize EMA with θ₀

    def _ema_beta(self, step: int) -> float:
        """Compute the EMA decay factor βₜ = max((ρ + γ·t)⁻ᵏᵃᵖᵖᵃ, min_ema_multiplier)."""
        beta = (self.lag + self.multiplier * step) ** (-self.ema_power)
        return max(beta, self.min_ema_multiplier)

    def _bema_alpha(self, step: int) -> float:
        """Compute the BEMA scaling factor αₜ = (ρ + γ·t)⁻ᵉᵗᵃ."""
        return (self.lag + self.multiplier * step) ** (-self.bias_power)

    def _update_bema_weights(self, step: int):
        beta = self._ema_beta(step)
        alpha = self._bema_alpha(step)

        # Compute EMA + BEMA in-place and write directly to the matching running_model parameter
        for thetat, theta0, ema, run_param in zip(
            self.thetat_params, self.theta0_params, self.ema_params, self.running_params, strict=True
        ):
            thetat = thetat.detach().to(self.device)
            ema.mul_(1 - beta).add_(thetat, alpha=beta)  # EMA update: ema = (1 - beta) * ema + beta * θₜ
            run_param.copy_(ema + alpha * (thetat - theta0))  # BEMA update: run_param = ema + alpha * (θₜ - θ₀)

    @torch.no_grad()
    def on_step_end(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs
    ):
        step = state.global_step

        # If we haven't reached the update_after step, skip the BEMA update
        if step < self.update_after:
            return

        # Snapshot θ₀ and EMA at first update
        if step == self.update_after:
            for thetat_param, theta0_param, ema_param in zip(
                self.thetat_params, self.theta0_params, self.ema_params, strict=True
            ):
                theta0_param.copy_(thetat_param)
                ema_param.copy_(thetat_param)

        # Update BEMA weights every `update_freq` steps
        elif (step - self.update_after) % self.update_freq == 0:
            self._update_bema_weights(step)
            logger.info(f"Updated BEMA weights at step {step}")

    @torch.no_grad()
    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Saves the running model as of the last BEMA update (updates happen every `update_freq` steps).
        if state.is_world_process_zero:
            save_directory = f"{args.output_dir}/bema"
            self.running_model.save_pretrained(save_directory)
            logger.info(f"Saved BEMA model to {save_directory}")