diff --git "a/unsloth_compiled_cache/UnslothDPOTrainer.py" "b/unsloth_compiled_cache/UnslothDPOTrainer.py"
new file mode 100644
--- /dev/null
+++ "b/unsloth_compiled_cache/UnslothDPOTrainer.py"
@@ -0,0 +1,2874 @@
+"""
+2026.3.2
+2026.3.4
+5.3.0
+0.24.0
+__UNSLOTH_VERSIONING__
+"""
+
+# Unsloth auto generated code
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from torch import Tensor
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from unsloth_zoo.temporary_patches.common import torch_compile
+from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
+from trl.trainer.dpo_trainer import (Any, AutoProcessor, BaseImageProcessor, BaseTrainer, Callable, DPOConfig, DPOTrainer, DataCollator, DataCollatorForPreference, DataLoader, Dataset, EvalLoopOutput, F, FDivergenceConstants, FDivergenceType, FeatureExtractionMixin, IterableDataset, Literal, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, Optional, PartialState, Path, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RunningMoments, SyncRefModelCallback, TrainerCallback, Union, autocast, cap_exp, contextmanager, create_model_from_path, create_reference_model, dataclass, defaultdict, disable_dropout_in_model, empty_cache, flush_left, flush_right, get_peft_model, inspect, is_comet_available, is_liger_kernel_available, is_mlflow_available, is_peft_available, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, nullcontext, pad, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_fsdp, prepare_model_for_kbit_training, random, selective_log_softmax, shift_tokens_right, textwrap, torch, tqdm, wandb, warnings, Any, AutoProcessor, BaseImageProcessor, Callable, DPOConfig, DPOTrainer, DataCollator, DataCollatorForPreference, Dataset, EvalLoopOutput, F, FDivergenceConstants, FeatureExtractionMixin, IterableDataset, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RunningMoments, SyncRefModelCallback, TrainerCallback, Union, create_model_from_path, create_reference_model, defaultdict, disable_dropout_in_model, is_comet_available, is_liger_kernel_available, is_mlflow_available, is_peft_available, is_wandb_available, logger, nn, pad, prepare_deepspeed, prepare_fsdp, torch, wandb, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, torch)
+
+
+import os
+import math
+import logging
+from typing import *
+from dataclasses import dataclass, field
+from packaging.version import Version
+import torch
+import numpy as np
+from contextlib import nullcontext
+from torch.nn import functional as F
+import inspect
+from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
+from transformers.training_args import ParallelMode
+from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
+
+# Wrap trainer with padding to right and enable training mode
+# Also patches W&B since multiple runs must use wandb.finish()
+import functools
+from types import MethodType
try:
    from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
except Exception:
    # unsloth_zoo may be absent or partially installed; fall back to a no-op.
    # `except Exception` (not bare `except:`) so KeyboardInterrupt/SystemExit propagate.
    def reset_unsloth_gradient_checkpointing_buffers(): pass
def prepare_for_training_mode(f):
    """Decorator for Trainer `train`-style methods.

    Before calling `f`, switches the wrapped trainer's model into training mode
    (honoring `args.gradient_checkpointing`); afterwards restores the previous
    train/eval state, resets Unsloth gradient-checkpointing buffers, and calls
    `wandb.finish()` so a later run does not overwrite this run's W&B logs.

    Fix: the cleanup `try` blocks previously used bare `except:`, which also
    swallows KeyboardInterrupt/SystemExit; they now catch `Exception` only.
    """
    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        # Remember the model's current train/eval state so it can be restored
        _was_training = None
        # Get gradient checkpointing setting from training arguments
        use_gc = getattr(self.args, 'gradient_checkpointing', True)
        if hasattr(self, 'model') and hasattr(self.model, "training"):
            _was_training = self.model.training
        if hasattr(self, 'model') and hasattr(self.model, "for_training"):
            self.model.for_training(use_gradient_checkpointing=use_gc)
        output = f(self, *args, **kwargs)
        # Restore previous mode when possible
        if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
            if _was_training is False:
                self.model.for_inference()
            elif _was_training is True and hasattr(self.model, "for_training"):
                self.model.for_training(use_gradient_checkpointing=use_gc)
        # Reset gradient checkpointing buffers to free memory while staying ready for next run
        try:
            reset_unsloth_gradient_checkpointing_buffers()
        except Exception:
            # Best-effort cleanup; never fail the training call over buffer resets
            pass
        # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
        try:
            import wandb
            wandb.finish()
        except Exception:
            # wandb may be missing or uninitialized; ignore
            pass
        return output
    return wrapper
pass
+
# Inductor options used by the @torch.compile'd helpers in this file:
# conservative settings (no max-autotune, no CUDA graphs) for broad compatibility.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_hidden_states_selective_log_softmax(
    hidden_states: torch.Tensor,
    lm_head: torch.Tensor,
    index: torch.Tensor,
    chunks: int = 4,
    logit_scale_multiply: float = 0.0,
    logit_scale_divide: float = 0.0,
    logit_softcapping: float = 0.0,
    temperature: float = 1.0,
) -> torch.Tensor:
    """Per-token log-probabilities of `index` tokens, projecting hidden states
    through `lm_head` in `chunks` pieces to bound peak logit memory."""
    # All Unsloth Zoo code licensed under AGPL3
    batch, seq_len = hidden_states.shape[0], hidden_states.shape[1]
    hidden_flat = hidden_states.reshape(-1, hidden_states.shape[-1])
    index_flat = index.reshape(-1)

    logp_pieces = []
    for hs_piece, idx_piece in zip(
        torch.chunk(hidden_flat, chunks=chunks, dim=0),
        torch.chunk(index_flat, chunks=chunks, dim=0),
    ):
        # Project this slice of hidden states to vocabulary logits
        piece_logits = hs_piece.to(lm_head.dtype) @ lm_head.t()

        # Optional scaling / soft-capping applied in the same order as the model
        if logit_scale_multiply != 0.0:
            piece_logits = piece_logits * logit_scale_multiply
        if logit_scale_divide != 0.0:
            piece_logits = piece_logits / logit_scale_divide
        if logit_softcapping != 0.0:
            piece_logits = piece_logits * torch.tanh(piece_logits / logit_softcapping)

        piece_logits = piece_logits.to(torch.float32)

        if temperature != 1.0:
            piece_logits = piece_logits / temperature

        # log p(token) = logit(token) - logsumexp(all logits)
        picked = torch.gather(piece_logits, dim=-1, index=idx_piece.unsqueeze(-1)).squeeze(-1)
        logp_pieces.append(picked - torch.logsumexp(piece_logits, dim=-1))

    return torch.concat(logp_pieces).reshape((batch, seq_len))
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """Memory-efficient selective_log_softmax: processes logits in 4 chunks."""
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_index = index.reshape(-1)
    pieces = []
    for logit_chunk, index_chunk in zip(
        torch.chunk(flat_logits, chunks=4, dim=0),
        torch.chunk(flat_index, chunks=4, dim=0),
    ):
        # Same computation as selective_log_softmax(logit_chunk, index_chunk)
        logit_chunk = logit_chunk.to(torch.float32)
        chosen = torch.gather(logit_chunk, dim=-1, index=index_chunk.unsqueeze(-1)).squeeze(-1)
        pieces.append(chosen - torch.logsumexp(logit_chunk, dim=-1))
    return torch.concat(pieces).reshape((logits.shape[0], logits.shape[1]))
+
def calculate_pad_tokens_in_prompt(
    input_ids: torch.Tensor,
    logits_to_keep: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Count left-pad tokens in the prompt portion of each sequence (everything
    before the last `logits_to_keep` positions), e.g. [pad, pad, pad, cat] -> 3.
    """
    if logits_to_keep >= input_ids.shape[1]:
        raise ValueError("logits_to_keep must be smaller than the sequence length.")
    # Prompt region = everything except the trailing completion tokens
    prompt_region = input_ids[:, :-logits_to_keep]
    return (prompt_region == pad_token_id).sum(dim=1)
+
def create_completion_attention_mask(
    completion_input_ids: torch.Tensor,
    left_pad_tokens_per_prompt: torch.Tensor,
    max_left_pad: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Build a completion mask that zeroes out both leftover prompt tokens and pad
    tokens. E.g. for [p,p,p,c,c,c,pad,pad,pad] (p = prompt tokens picked up by
    tensor slicing, c = completion tokens) the result is [0,0,0,1,1,1,0,0,0].
    """
    seq_len = completion_input_ids.shape[1]
    device = completion_input_ids.device

    # Number of leading positions per row that still belong to the prompt
    leading_prompt_tokens = max_left_pad - left_pad_tokens_per_prompt

    positions = torch.arange(seq_len, device=device).unsqueeze(0)
    past_prompt = positions >= leading_prompt_tokens.unsqueeze(1)

    not_pad = completion_input_ids != pad_token_id
    return past_prompt & not_pad
+
def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
    """
    Shift every non-pad token in each row to the left, pushing all pad tokens
    to the right end of the row.
    """
    is_token = tensor != pad_id
    # stable=True preserves the relative order of tokens (the boolean sort key is unordered)
    order = torch.argsort(is_token, dim=1, descending=True, stable=True)
    return torch.gather(tensor, 1, order)
+
def align_logprobs_with_mask(
    logprob_tensor: torch.Tensor,
    attention_mask: torch.Tensor,
    pad_value: float = 0.0
) -> torch.Tensor:
    """
    Re-align a log-probability tensor to an attention mask by shifting each row
    right by that row's left-padding amount; positions not covered by a source
    value are filled with `pad_value`.
    """
    device = logprob_tensor.device
    batch_size, src_len = logprob_tensor.shape
    target_len = attention_mask.shape[1]

    aligned = torch.full(
        attention_mask.shape,
        fill_value=pad_value,
        dtype=logprob_tensor.dtype,
        device=device,
    )

    # Index of the first attended position per row == number of left pads
    shift = torch.argmax(attention_mask, dim=1)

    # Destination (row, col) for every source element
    dest_cols = shift.unsqueeze(1) + torch.arange(src_len, device=device)
    dest_rows = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_cols)

    # Drop destinations that fall beyond the mask's sequence length, then
    # scatter the surviving values with one advanced-indexing assignment.
    in_bounds = dest_cols < target_len
    aligned[dest_rows[in_bounds], dest_cols[in_bounds]] = logprob_tensor[in_bounds]

    return aligned
+
def autotune_batch_and_chunks(
    total_input_rows,
    seq_len,
    hidden_size,
    vocab_size,
    dtype_bytes=16,
    multiplier=None
):
    """
    Pick the largest batch size (and a logit chunk multiplier) whose estimated
    hidden-state + logit memory fits within ~80% of the currently free
    accelerator memory. Returns (4, multiplier) when nothing fits, which
    likely means the GPU will OOM.
    """
    final_m = max(4, seq_len // 4096) if multiplier is None else multiplier

    if torch.cuda.is_available():
        free_bytes, _ = torch.cuda.mem_get_info()
        limit_gb = (free_bytes / (1024**3)) * 0.80
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # XPU lacks mem_get_info: approximate free memory as total - reserved
        total_mem = torch.xpu.get_device_properties(0).total_memory
        reserved_mem = torch.xpu.memory_reserved()
        free_bytes = total_mem - reserved_mem
        limit_gb = (free_bytes / (1024**3)) * 0.80
    else:
        # No accelerator visible: assume 8GB available
        limit_gb = 8.0

    gb = 1024**3

    # Candidate batch sizes, largest first, so the first fit is the best fit
    candidates = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)

    hidden_gb = (candidates * seq_len * hidden_size * dtype_bytes) / gb
    logits_gb = (
        ((candidates / total_input_rows) * candidates * seq_len * vocab_size * dtype_bytes) / gb
    ) / final_m

    fits = (hidden_gb + logits_gb) <= limit_gb
    fit_indices = torch.nonzero(fits, as_tuple=False)

    if fit_indices.shape[0] == 0:
        # Even the smallest batch blows the budget -> the GPU will OOM
        return 4, final_m

    best_batch = int(candidates[fit_indices[0].item()].item())
    return best_batch, final_m
+
def sanitize_logprob(logprob):
    """Local port of trl.scripts.vllm_serve.sanitize_logprob.
    Filters NaN logprobs from vLLM outputs."""
    value = logprob.logprob
    # Finite (non-NaN) values pass straight through
    if not math.isnan(value):
        return value
    logging.getLogger(__name__).warning(
        f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
    )
    return None
+@dataclass
+class UnslothDPOConfig(DPOConfig):
+ """
+
+ Configuration class for the [`DPOTrainer`].
+
+ This class includes only the parameters that are specific to DPO training. For a full list of training arguments,
+ please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
+ differ from those in [`~transformers.TrainingArguments`].
+
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+ command line.
+
+ Parameters:
+ > Parameters that control the model and reference model
+
+ model_init_kwargs (`dict[str, Any]`, *optional*):
+ Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of the
+ [`DPOTrainer`] is provided as a string.
+ ref_model_init_kwargs (`dict[str, Any]`, *optional*):
+ Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument of the
+ [`DPOTrainer`] is provided as a string.
+ model_adapter_name (`str`, *optional*):
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
+ ref_adapter_name (`str`, *optional*):
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
+ force_use_ref_model (`bool`, *optional*, defaults to `False`):
+ If you provide a PEFT model as the active model and wish to use a different model for the `ref_model`, set
+ this flag to `True`.
+ disable_dropout (`bool`, *optional*, defaults to `True`):
+ Whether to disable dropout in the model and reference model.
+ use_logits_to_keep (`bool`, *optional*, defaults to `False`):
+ If `True`, only a specified number of logits are computed in the forward pass. This can be useful for
+ saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios
+ when working with very long prompts where labels are ignored (-100).
+
+ > Parameters that control the data preprocessing
+
+ dataset_num_proc (`int`, *optional*):
+ Number of processes to use for processing the dataset.
+ pad_token (`str`, *optional*):
+ Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
+ it falls back to `processing_class.eos_token`.
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
+ Padding value to use for labels.
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
+ Maximum length of the prompt.
+ max_completion_length (`int`, *optional*):
+ Maximum length of the completion.
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
+ Maximum length of the full sequence (prompt + completion).
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
+ Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and
+ `"keep_start"`.
+ padding_free (`bool`, *optional*, defaults to `False`):
+ Whether to perform forward passes without padding by flattening all sequences in the batch into a single
+ continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
+ supported with the `flash_attention_2` attention implementation, which can efficiently handle the flattened
+ batch structure.
+ precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
+ Whether to precompute the log probabilities from the reference model. Setting this to `True` allows
+ training without needing the reference model during training, which can help reduce GPU memory usage. If
+ set to `False` (default), the reference model will be used during training to compute log probabilities
+ on-the-fly.
+ precompute_ref_batch_size (`int`, *optional*):
+ Batch size to use when precomputing reference model log probabilities. This can be set higher than the
+ training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for
+ training and `per_device_eval_batch_size` for evaluation.
+ tools (`Optional[list[Union[dict, Callable]]]`, *optional*):
+ List of tools (callable functions) that will be accessible to the model. If the template does not support
+ function calling, this argument will have no effect.
+
+ > Parameters that control the training
+
+ loss_type (`str` or `list[str]`, *optional*, defaults to `"sigmoid"`):
+ Type of loss to use. Possible values are:
+
+ - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
+ - `"hinge"`: hinge loss on the normalized likelihood from the
+ [SLiC](https://huggingface.co/papers/2305.10425) paper.
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
+ - `"exo_pair"`: pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper.
+ - `"nca_pair"`: pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper.
+ - `"robust"`: unbiased estimate of the DPO loss that is robust to preference noise from the [Robust
+ DPO](https://huggingface.co/papers/2403.00409) paper.
+ - `"bco_pair"`: pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper.
+ - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675)
+ paper.
+ - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
+ - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882)
+ paper.
+ - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the
+ [DiscoPOP](https://huggingface.co/papers/2406.08414) paper.
+ - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
+ - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
+ - `"sft"`: Negative log-likelihood loss (standard supervised fine-tuning loss).
+
+ Multiple loss types can be combined using comma separation (e.g., `["sigmoid", "bco_pair", "sft"]` for
+ [MPO](https://huggingface.co/papers/2411.10442)). The `loss_weights` parameter can be used to specify
+ corresponding weights for each loss type.
+
+ use_liger_loss (`bool`, *optional*, defaults to `False`):
+ Whether to use Liger loss.
+ base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
+ Name of the attribute in the model that contains the base model. This is used to get the base model from
+ the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`.
+ beta (`float`, *optional*, defaults to `0.1`):
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
+ reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
+ the [paper](https://huggingface.co/papers/2310.12036).
+ f_divergence_type ([`FDivergenceType`] or `str`, *optional*, defaults to `FDivergenceType.REVERSE_KL`):
+ Type of f-divergence regularization function to compute divergence between policy and reference model.
+ f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`):
+ α coefficient in the α-divergence u^-α regularization function for DPO loss.
+ reference_free (`bool`, *optional*, defaults to `False`):
+ Whether to ignore the provided reference model and implicitly use a reference model that assigns equal
+ probability to all responses.
+ label_smoothing (`float`, *optional*, defaults to `0.0`):
+ Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and [Robust
+ DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`.
+ use_weighting (`bool`, *optional*, defaults to `False`):
+ Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827).
+ rpo_alpha (`float`, *optional*):
+ α parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the
+ weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the
+ DPO loss. The paper recommends `rpo_alpha=1.0`.
+ ld_alpha (`float`, *optional*):
+ α parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting
+ of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose
+ part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between
+ `0.0` and `1.0`.
+ discopop_tau (`float`, *optional*, defaults to `0.05`):
+ τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls
+ the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`.
+ loss_weights (`list[float]`, *optional*):
+ List of loss weights for multi-loss combinations. Used when combining multiple loss types. Example: `[0.8,
+ 0.2, 1.0]` for [MPO](https://huggingface.co/papers/2411.10442). If not provided, defaults to equal weights
+ (`1.0`) for all loss types.
+ sync_ref_model (`bool`, *optional*, defaults to `False`):
+ Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
+ the `ref_model_mixup_alpha` parameter. This synchronization originates from the
+ [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
+ ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
+ α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
+ between the current policy and the previous reference policy during updates. The reference policy is
+ updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
+ must set `sync_ref_model=True`.
+ ref_model_sync_steps (`int`, *optional*, defaults to `512`):
+ τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
+ frequently the current policy is synchronized with the reference policy. To use this parameter, you must
+ set `sync_ref_model=True`.
+
+ > Parameters that control the logging
+
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
+ Whether to generate and log completions from both the model and the reference model to W&B or Comet during
+ evaluation.
+
+ > Deprecated parameters
+
+ padding_value:
+
+
+
+ This parameter is deprecated and will be removed in version 0.25.0. Use `pad_token` (`str`) instead.
+
+
+
+ """
+ vllm_sampling_params: Optional[Any] = field(
+ default = None,
+ metadata = {'help': 'vLLM SamplingParams'},
+ )
+ unsloth_num_chunks : Optional[int] = field(
+ default = -1,
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+ )
+ unsloth_logit_chunk_multiplier : Optional[int] = field(
+ default = None,
+ metadata = {'help': 'Multiplier for chunked logit computations.'},
+ )
+ unsloth_grpo_mini_batch : Optional[int] = field(
+ default = None,
+ metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
+ )
+ max_seq_length : Optional[int] = field(
+ default = None,
+ metadata = {'help': 'Maximum sequence length to truncate to.'},
+ )
+ def __init__(
+ self,
+ output_dir = None,
+ per_device_train_batch_size = 4,
+ num_train_epochs = 3.0,
+ max_steps = -1,
+ learning_rate = 5e-05,
+ lr_scheduler_type = 'linear',
+ lr_scheduler_kwargs = None,
+ warmup_steps = 0.1,
+ optim = 'adamw_8bit',
+ optim_args = None,
+ weight_decay = 0.01,
+ adam_beta1 = 0.9,
+ adam_beta2 = 0.999,
+ adam_epsilon = 1e-08,
+ optim_target_modules = None,
+ gradient_accumulation_steps = 2,
+ average_tokens_across_devices = True,
+ max_grad_norm = 1.0,
+ label_smoothing_factor = 0.0,
+ bf16 = False,
+ fp16 = False,
+ bf16_full_eval = False,
+ fp16_full_eval = False,
+ tf32 = None,
+ gradient_checkpointing = True,
+ gradient_checkpointing_kwargs = None,
+ torch_compile = False,
+ torch_compile_backend = None,
+ torch_compile_mode = None,
+ use_liger_kernel = False,
+ liger_kernel_config = None,
+ use_cache = False,
+ neftune_noise_alpha = None,
+ torch_empty_cache_steps = 250,
+ auto_find_batch_size = False,
+ logging_strategy = 'steps',
+ logging_steps = 1,
+ logging_first_step = False,
+ log_on_each_node = True,
+ logging_nan_inf_filter = False,
+ include_num_input_tokens_seen = False,
+ log_level = 'passive',
+ log_level_replica = 'warning',
+ disable_tqdm = None,
+ report_to = 'none',
+ run_name = None,
+ project = 'huggingface',
+ trackio_space_id = 'trackio',
+ eval_strategy = 'no',
+ eval_steps = None,
+ eval_delay = 0,
+ per_device_eval_batch_size = 4,
+ prediction_loss_only = False,
+ eval_on_start = False,
+ eval_do_concat_batches = True,
+ eval_use_gather_object = False,
+ eval_accumulation_steps = 2,
+ batch_eval_metrics = False,
+ save_only_model = False,
+ save_strategy = 'steps',
+ save_steps = 500,
+ save_on_each_node = False,
+ save_total_limit = None,
+ enable_jit_checkpoint = False,
+ push_to_hub = False,
+ hub_token = None,
+ hub_private_repo = None,
+ hub_model_id = None,
+ hub_strategy = 'every_save',
+ hub_always_push = False,
+ hub_revision = None,
+ load_best_model_at_end = False,
+ metric_for_best_model = None,
+ greater_is_better = None,
+ ignore_data_skip = False,
+ restore_callback_states_from_checkpoint = False,
+ full_determinism = False,
+ seed = 3407,
+ data_seed = 3407,
+ use_cpu = False,
+ accelerator_config = None,
+ parallelism_config = None,
+ dataloader_drop_last = False,
+ dataloader_num_workers = 0,
+ dataloader_pin_memory = True,
+ dataloader_persistent_workers = False,
+ dataloader_prefetch_factor = None,
+ remove_unused_columns = True,
+ label_names = None,
+ train_sampling_strategy = 'random',
+ length_column_name = 'length',
+ ddp_find_unused_parameters = None,
+ ddp_bucket_cap_mb = None,
+ ddp_broadcast_buffers = None,
+ ddp_backend = None,
+ ddp_timeout = 1800,
+ fsdp = None,
+ fsdp_config = None,
+ deepspeed = None,
+ debug = '',
+ skip_memory_metrics = True,
+ do_train = False,
+ do_eval = False,
+ do_predict = False,
+ resume_from_checkpoint = None,
+ warmup_ratio = None,
+ logging_dir = None,
+ local_rank = -1,
+ model_init_kwargs = None,
+ ref_model_init_kwargs = None,
+ model_adapter_name = None,
+ ref_adapter_name = None,
+ force_use_ref_model = False,
+ disable_dropout = True,
+ use_logits_to_keep = False,
+ dataset_num_proc = None,
+ pad_token = None,
+ label_pad_token_id = -100,
+ max_prompt_length = 512,
+ max_completion_length = None,
+ max_length = 1024,
+ truncation_mode = 'keep_end',
+ padding_free = None,
+ precompute_ref_log_probs = False,
+ precompute_ref_batch_size = None,
+ tools = None,
+ use_liger_loss = False,
+ base_model_attribute_name = 'model',
+ beta = 0.1,
+ f_alpha_divergence_coef = 1.0,
+ reference_free = False,
+ label_smoothing = 0.0,
+ use_weighting = False,
+ rpo_alpha = None,
+ ld_alpha = None,
+ discopop_tau = 0.05,
+ loss_weights = None,
+ sync_ref_model = False,
+ ref_model_mixup_alpha = 0.6,
+ ref_model_sync_steps = 512,
+ generate_during_eval = False,
+ padding_value = None,
+ vllm_sampling_params = None,
+ unsloth_num_chunks = -1,
+ unsloth_logit_chunk_multiplier = None,
+ unsloth_grpo_mini_batch = None,
+ max_seq_length = None,
+ **kwargs,
+ ):
+ if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
+ if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+ if num_train_epochs is None:
+ num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+ output_dir = 'unsloth_training_checkpoints'
+ save_strategy = 'no'
+ import multiprocessing as _mp
+ if _mp.get_start_method() != 'fork':
+ dataset_num_proc = None
+ elif dataset_num_proc is None:
+ import psutil
+ dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
+ memory_gb_left = psutil.virtual_memory().available / (1024**3)
+ if memory_gb_left <= 2: dataset_num_proc = 1
+ else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
+
+ super().__init__(
+ output_dir = output_dir,
+ per_device_train_batch_size = per_device_train_batch_size,
+ num_train_epochs = num_train_epochs,
+ max_steps = max_steps,
+ learning_rate = learning_rate,
+ lr_scheduler_type = lr_scheduler_type,
+ lr_scheduler_kwargs = lr_scheduler_kwargs,
+ warmup_steps = warmup_steps,
+ optim = optim,
+ optim_args = optim_args,
+ weight_decay = weight_decay,
+ adam_beta1 = adam_beta1,
+ adam_beta2 = adam_beta2,
+ adam_epsilon = adam_epsilon,
+ optim_target_modules = optim_target_modules,
+ gradient_accumulation_steps = gradient_accumulation_steps,
+ average_tokens_across_devices = average_tokens_across_devices,
+ max_grad_norm = max_grad_norm,
+ label_smoothing_factor = label_smoothing_factor,
+ bf16 = bf16,
+ fp16 = fp16,
+ bf16_full_eval = bf16_full_eval,
+ fp16_full_eval = fp16_full_eval,
+ tf32 = tf32,
+ gradient_checkpointing = gradient_checkpointing,
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+ torch_compile = torch_compile,
+ torch_compile_backend = torch_compile_backend,
+ torch_compile_mode = torch_compile_mode,
+ use_liger_kernel = use_liger_kernel,
+ liger_kernel_config = liger_kernel_config,
+ use_cache = use_cache,
+ neftune_noise_alpha = neftune_noise_alpha,
+ torch_empty_cache_steps = torch_empty_cache_steps,
+ auto_find_batch_size = auto_find_batch_size,
+ logging_strategy = logging_strategy,
+ logging_steps = logging_steps,
+ logging_first_step = logging_first_step,
+ log_on_each_node = log_on_each_node,
+ logging_nan_inf_filter = logging_nan_inf_filter,
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
+ log_level = log_level,
+ log_level_replica = log_level_replica,
+ disable_tqdm = disable_tqdm,
+ report_to = report_to,
+ run_name = run_name,
+ project = project,
+ trackio_space_id = trackio_space_id,
+ eval_strategy = eval_strategy,
+ eval_steps = eval_steps,
+ eval_delay = eval_delay,
+ per_device_eval_batch_size = per_device_eval_batch_size,
+ prediction_loss_only = prediction_loss_only,
+ eval_on_start = eval_on_start,
+ eval_do_concat_batches = eval_do_concat_batches,
+ eval_use_gather_object = eval_use_gather_object,
+ eval_accumulation_steps = eval_accumulation_steps,
+ batch_eval_metrics = batch_eval_metrics,
+ save_only_model = save_only_model,
+ save_strategy = save_strategy,
+ save_steps = save_steps,
+ save_on_each_node = save_on_each_node,
+ save_total_limit = save_total_limit,
+ enable_jit_checkpoint = enable_jit_checkpoint,
+ push_to_hub = push_to_hub,
+ hub_token = hub_token,
+ hub_private_repo = hub_private_repo,
+ hub_model_id = hub_model_id,
+ hub_strategy = hub_strategy,
+ hub_always_push = hub_always_push,
+ hub_revision = hub_revision,
+ load_best_model_at_end = load_best_model_at_end,
+ metric_for_best_model = metric_for_best_model,
+ greater_is_better = greater_is_better,
+ ignore_data_skip = ignore_data_skip,
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+ full_determinism = full_determinism,
+ seed = seed,
+ data_seed = data_seed,
+ use_cpu = use_cpu,
+ accelerator_config = accelerator_config,
+ parallelism_config = parallelism_config,
+ dataloader_drop_last = dataloader_drop_last,
+ dataloader_num_workers = dataloader_num_workers,
+ dataloader_pin_memory = dataloader_pin_memory,
+ dataloader_persistent_workers = dataloader_persistent_workers,
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
+ remove_unused_columns = remove_unused_columns,
+ label_names = label_names,
+ train_sampling_strategy = train_sampling_strategy,
+ length_column_name = length_column_name,
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
+ ddp_backend = ddp_backend,
+ ddp_timeout = ddp_timeout,
+ fsdp = fsdp,
+ fsdp_config = fsdp_config,
+ deepspeed = deepspeed,
+ debug = debug,
+ skip_memory_metrics = skip_memory_metrics,
+ do_train = do_train,
+ do_eval = do_eval,
+ do_predict = do_predict,
+ resume_from_checkpoint = resume_from_checkpoint,
+ warmup_ratio = warmup_ratio,
+ logging_dir = logging_dir,
+ local_rank = local_rank,
+ model_init_kwargs = model_init_kwargs,
+ ref_model_init_kwargs = ref_model_init_kwargs,
+ model_adapter_name = model_adapter_name,
+ ref_adapter_name = ref_adapter_name,
+ force_use_ref_model = force_use_ref_model,
+ disable_dropout = disable_dropout,
+ use_logits_to_keep = use_logits_to_keep,
+ dataset_num_proc = dataset_num_proc,
+ pad_token = pad_token,
+ label_pad_token_id = label_pad_token_id,
+ max_prompt_length = max_prompt_length,
+ max_completion_length = max_completion_length,
+ max_length = max_length,
+ truncation_mode = truncation_mode,
+ padding_free = padding_free,
+ precompute_ref_log_probs = precompute_ref_log_probs,
+ precompute_ref_batch_size = precompute_ref_batch_size,
+ tools = tools,
+ use_liger_loss = use_liger_loss,
+ base_model_attribute_name = base_model_attribute_name,
+ beta = beta,
+ f_alpha_divergence_coef = f_alpha_divergence_coef,
+ reference_free = reference_free,
+ label_smoothing = label_smoothing,
+ use_weighting = use_weighting,
+ rpo_alpha = rpo_alpha,
+ ld_alpha = ld_alpha,
+ discopop_tau = discopop_tau,
+ loss_weights = loss_weights,
+ sync_ref_model = sync_ref_model,
+ ref_model_mixup_alpha = ref_model_mixup_alpha,
+ ref_model_sync_steps = ref_model_sync_steps,
+ generate_during_eval = generate_during_eval,
+ padding_value = padding_value,**kwargs)
+ self.vllm_sampling_params = vllm_sampling_params
+ self.unsloth_num_chunks = unsloth_num_chunks
+ if unsloth_grpo_mini_batch is not None:
+ if self.generation_batch_size >= unsloth_grpo_mini_batch:
+ self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
+ else:
+ raise ValueError(
+ f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
+ f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
+ )
+ self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
+ self.max_seq_length = max_seq_length
+
+pass
+
class _UnslothDPOTrainer(BaseTrainer):
    """Generated mirror of TRL's DPO trainer implementation.

    NOTE(review): this file is auto-generated; the user-facing
    `UnslothDPOTrainer` presumably subclasses this class elsewhere in the
    module — confirm against the rest of the generated file.
    """

    # Model-card tags; attached to the model in `__init__` via `add_model_tags`.
    _tag_names = ["trl", "dpo"]
    _name = "DPO"
    # Paper metadata (title / arXiv id / BibTeX) used for model-card generation.
    _paper = {
        "title": "Direct Preference Optimization: Your Language Model is Secretly a Reward Model",
        "id": "2305.18290",
        # docstyle-ignore
        "citation": textwrap.dedent("""\
            @inproceedings{rafailov2023direct,
                title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
                author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
                year = 2023,
                booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
                url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
                editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
            }"""),
    }
+
    def __init__(
        self,
        model: Union[str, nn.Module, PreTrainedModel],
        ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[DPOConfig] = None,
        data_collator: Optional[DataCollator] = None,  # type: ignore
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
    ):
        """Set up the DPO trainer.

        Mirrors `trl.DPOTrainer.__init__`: resolves `model`/`ref_model` (loading
        from a path when given as strings), picks a processing class and pad
        token, optionally wraps the model with PEFT, validates the loss
        configuration, tokenizes the datasets, calls `BaseTrainer.__init__`,
        and finally prepares the reference model for the active distributed
        backend (DeepSpeed / FSDP / plain Accelerate).

        Parameters mirror `trl.trainer.dpo_trainer.DPOTrainer.__init__`; see
        the TRL documentation for full parameter descriptions.
        """
        # Args
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = DPOConfig(f"{model_name}-DPO")

        # Model and reference model
        if isinstance(model, str):
            model = create_model_from_path(model, **args.model_init_kwargs or {})
        else:
            if args.model_init_kwargs is not None:
                logger.warning(
                    "You passed `model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. "
                    "The `model_init_kwargs` will be ignored."
                )
        model_id = model.config._name_or_path
        if isinstance(ref_model, str):
            ref_model = create_model_from_path(ref_model, **args.ref_model_init_kwargs or {})
        else:
            if args.ref_model_init_kwargs is not None:
                logger.warning(
                    "You passed `ref_model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. "
                    "The `ref_model_init_kwargs` will be ignored."
                )
        if ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, you can simply omit the `ref_model` argument and it will be created for you."
            )

        # Processing class
        if processing_class is None:
            processing_class = AutoProcessor.from_pretrained(model_id)

        # Handle pad token for processors or tokenizers
        if isinstance(processing_class, ProcessorMixin):
            tokenizer = processing_class.tokenizer
            self._is_vlm = True
        elif isinstance(processing_class, PreTrainedTokenizerBase):
            tokenizer = processing_class
            self._is_vlm = False
        else:
            raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")

        # Get the pad token: if not provided, use the one from the processing class or the eos token
        # if the processing class does not have a pad token.
        if args.padding_value is not None:  # deprecated, will be removed in 0.26.0.
            warnings.warn(
                "The `padding_value` argument is deprecated and will be removed in version 0.26.0. Please use "
                "`pad_token` (str) instead."
            )
            self.pad_token_id = args.padding_value
        else:
            pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
            self.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
            if self.pad_token_id is None:
                raise ValueError(
                    f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
                    "in the vocabulary before using it as a padding token."
                )

        # PEFT configuration and model wrapping
        model = self._prepare_peft_model(model, ref_model, peft_config, args)

        if args.generate_during_eval and not (is_wandb_available() or is_comet_available() or is_mlflow_available()):
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed."
                " Please install `wandb`, `mlflow` or `comet-ml` to resolve."
            )

        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys()
        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
        self.model_adapter_name = args.model_adapter_name
        self.ref_adapter_name = args.ref_adapter_name
        self.reference_free = args.reference_free

        # Reference-model resolution: explicit `ref_model` > implicit (PEFT adapters
        # disabled / precomputed log-probs, i.e. None) > frozen copy of `model`.
        if ref_model:
            self.ref_model = ref_model
        elif self.is_peft_model or args.precompute_ref_log_probs:
            # The `model` with adapters turned off will be used as the reference model
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(model)

        # Disable dropout in the model and reference model
        if args.disable_dropout:
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        # Liger kernel
        if args.use_liger_loss:
            if not is_liger_kernel_available():
                raise ImportError(
                    "You set `use_liger_loss=True` but the liger kernel is not available. "
                    "Please install liger-kernel first: `pip install liger-kernel`"
                )
            if args.loss_type not in ["sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"]:
                # NOTE(review): the message below is missing a closing "]" after `nca_pair` — cosmetic only.
                raise ValueError(
                    "You set `use_liger_loss=True` but the loss type is not from `[sigmoid, apo_zero, apo_down, sppo_hard, nca_pair`. "
                    "Please set `loss_type='[sigmoid | apo_zero | apo_down | sppo_hard | nca_pair]'` to use the liger kernel."
                )
            # NOTE(review): `LigerFusedLinearDPOLoss` is not imported in this chunk;
            # presumably imported elsewhere in this generated module — confirm.
            self.dpo_loss_fn = LigerFusedLinearDPOLoss(
                ignore_index=args.label_pad_token_id,
                beta=args.beta,
                use_ref_model=not args.reference_free,
                average_log_prob=False,
                loss_type=args.loss_type,
            )
        # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
        # input tensor associated with the key "input_ids". However, in DPO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
        # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
        # of the input, floating-point operations will not be computed." To suppress this warning, we set the
        # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
        # that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Data collator
        if data_collator is None:
            data_collator = DataCollatorForPreference(pad_token_id=self.pad_token_id)

        self.generate_during_eval = args.generate_during_eval
        self.label_pad_token_id = args.label_pad_token_id
        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = args.max_completion_length
        self.max_length = args.max_length
        self.truncation_mode = args.truncation_mode
        self.precompute_ref_log_probs = args.precompute_ref_log_probs
        self.use_logits_to_keep = args.use_logits_to_keep

        if args.padding_free:
            if model.config._attn_implementation != "flash_attention_2":
                logger.warning(
                    "Padding-free training is enabled, but the attention implementation is not set to "
                    "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
                    "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
                    "other implementations may lead to unexpected behavior. To ensure compatibility, set "
                    "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
                    "attention mechanism can handle flattened sequences."
                )
        self.padding_free = args.padding_free

        # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
        # keep track of first called to avoid computation of future calls
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False

        self.beta = args.beta
        self.label_smoothing = args.label_smoothing
        self.loss_type = args.loss_type if isinstance(args.loss_type, list) else [args.loss_type]
        self.loss_weights = args.loss_weights
        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
        self.use_weighting = args.use_weighting
        self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
        if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
            logger.warning(
                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
                "loss.",
            )
        for loss_type in self.loss_type:
            if (
                loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"]
                and args.label_smoothing > 0
            ):
                logger.warning(
                    f"You are using the {loss_type} loss type that does not support label smoothing. The "
                    "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this "
                    "warning.",
                )
            if loss_type == "kto_pair":
                raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.")

        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        self.f_divergence_type = args.f_divergence_type
        self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
        self.dataset_num_proc = args.dataset_num_proc

        # Dataset preparation
        train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
        if eval_dataset is not None:
            if isinstance(eval_dataset, dict):
                eval_dataset = {
                    key: self._prepare_dataset(dataset, processing_class, args, key)
                    for key, dataset in eval_dataset.items()
                }
            else:
                eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        # Deepspeed Zero-3 does not support precompute_ref_log_probs
        if self.is_deepspeed_enabled:
            if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
                raise ValueError(
                    "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
                )

        if self.ref_model is None:
            if not (self.is_peft_model or self.precompute_ref_log_probs):
                raise ValueError(
                    "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
                )
            if args.sync_ref_model:
                raise ValueError(
                    "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`."
                )
        else:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            elif self.is_fsdp_enabled:
                self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        if args.sync_ref_model:
            if self.precompute_ref_log_probs:
                raise ValueError(
                    "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
                )

            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))

        if "bco_pair" in self.loss_type:
            self.running = RunningMoments(self.accelerator)
+
+ @property
+ def padding_value(self):
+ warnings.warn(
+ "The `padding_value` property is deprecated and will be removed in version 0.26.0. Please use "
+ "`pad_token_id` instead.",
+ )
+ return self.pad_token_id
+
+ @padding_value.setter
+ def padding_value(self, value):
+ warnings.warn(
+ "The `padding_value` property is deprecated and will be removed in version 0.26.0. Please use "
+ "`pad_token_id` instead.",
+ )
+ self.pad_token_id = value
+
+ def _prepare_peft_model(
+ self, model: PreTrainedModel, ref_model: PreTrainedModel, peft_config: Any, args: DPOConfig
+ ) -> PreTrainedModel:
+ """Prepares a model for PEFT training."""
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
+ # has been called in order to properly call autocast if needed.
+ self._peft_has_been_casted_to_bf16 = False
+
+ if not is_peft_available() and peft_config is not None:
+ raise ValueError(
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
+ )
+ elif is_peft_available() and peft_config is not None:
+ # if model is a peft model and we have a peft_config, we merge and unload it first
+ if isinstance(model, PeftModel):
+ model = model.merge_and_unload()
+
+ if ref_model is not None and not args.force_use_ref_model:
+ raise ValueError(
+ "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference"
+ " model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init."
+ " if you want to use a different ref_model."
+ )
+
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+ _support_gc_kwargs = hasattr(
+ args, "gradient_checkpointing_kwargs"
+ ) and "gradient_checkpointing_kwargs" in list(
+ inspect.signature(prepare_model_for_kbit_training).parameters
+ )
+
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+ if _support_gc_kwargs:
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
+
+ else:
+ model = self._prepare_gradient_checkpointing(model, args)
+
+ # get peft model with the given config
+ model = get_peft_model(model, peft_config)
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+ peft_module_casting_to_bf16(model)
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
+ self._peft_has_been_casted_to_bf16 = True
+
+ else:
+ model = self._prepare_gradient_checkpointing(model, args)
+
+ return model
+
+ def _prepare_gradient_checkpointing(self, model: PreTrainedModel, args: DPOConfig):
+ """Prepare the gradienting checkpointing for the model."""
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
+ # fail or completely fail.
+ if args.gradient_checkpointing:
+ # For backward compatibility with older versions of transformers
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ return model
+
+ def _prepare_dataset(
+ self,
+ dataset: Union[Dataset, IterableDataset],
+ processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
+ args: DPOConfig,
+ dataset_name: str,
+ ) -> Union[Dataset, IterableDataset]:
+ # Build the kwargs for the `map` function
+ map_kwargs = {}
+ if isinstance(dataset, Dataset): # IterableDataset does not support num_proc nor writer_batch_size
+ map_kwargs["num_proc"] = args.dataset_num_proc
+ map_kwargs["writer_batch_size"] = 10
+
+ with PartialState().main_process_first():
+ # Extract prompt if needed
+ if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
+ map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
+ dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
+
+ # Apply the chat template if needed
+ if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
+ map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
+ dataset = dataset.map(
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs
+ )
+
+ # Tokenize the dataset
+ if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
+ map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+ dataset = dataset.map(
+ self.tokenize_row if not self.is_vision_model else self.process_row,
+ remove_columns=["chosen", "rejected"],
+ fn_kwargs={
+ "processing_class": processing_class,
+ "max_prompt_length": args.max_prompt_length,
+ "max_completion_length": args.max_completion_length,
+ # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
+ "add_special_tokens": False,
+ },
+ **map_kwargs,
+ )
+
+ return dataset
+
+ @staticmethod
+ def tokenize_row(
+ features: dict[str, str],
+ processing_class: PreTrainedTokenizerBase,
+ max_prompt_length: Optional[int] = None,
+ max_completion_length: Optional[int] = None,
+ add_special_tokens: bool = True,
+ ) -> dict[str, list[int]]:
+ """
+ Tokenize a row of the dataset.
+
+ Args:
+ features (`dict[str, str]`):
+ Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`.
+ processing_class ([`~transformers.PreTrainedTokenizerBase`]):
+ Processing class used to process the data.
+ max_prompt_length (`int` or `None`):
+ Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated.
+ max_completion_length (`int` or `None`):
+ Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
+ add_special_tokens (`bool`):
+ Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`,
+ the prompt sequence will have a bos token prepended and an eos token appended. In any case, the
+ completion sequences will have an eos token appended.
+
+ Returns:
+ `dict[str, list[int]]`:
+ Tokenized sequences with the keys `"prompt_input_ids"`, `"chosen_input_ids"`, and
+ `"rejected_input_ids".
+
+ Example:
+ ```python
+ >>> from transformers import GPT2Tokenizer
+
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}
+ >>> DPOTrainer.tokenize_row(
+ ... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False
+ ... )
+ {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]}
+ ```
+ """
+ tokenizer = processing_class # the processing class is a tokenizer
+ prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
+ chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"]
+ rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"]
+
+ # Add special tokens (typically for encoder-decoder models)
+ if add_special_tokens:
+ if tokenizer.bos_token_id is not None:
+ prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
+ if tokenizer.eos_token_id is not None:
+ prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
+ chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
+ rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]
+
+ # Truncate prompt and completion sequences
+ if max_prompt_length is not None:
+ prompt_input_ids = prompt_input_ids[-max_prompt_length:]
+ if max_completion_length is not None:
+ chosen_input_ids = chosen_input_ids[:max_completion_length]
+ rejected_input_ids = rejected_input_ids[:max_completion_length]
+
+ return {
+ "prompt_input_ids": prompt_input_ids,
+ "chosen_input_ids": chosen_input_ids,
+ "rejected_input_ids": rejected_input_ids,
+ }
+
+ @staticmethod
+ def process_row(
+ features: dict[str, str],
+ processing_class: PreTrainedTokenizerBase,
+ max_prompt_length: Optional[int] = None,
+ max_completion_length: Optional[int] = None,
+ add_special_tokens: bool = True,
+ ) -> dict[str, list[int]]:
+ """
+ Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information.
+ """
+ processor, tokenizer = processing_class, processing_class.tokenizer # the processing class is a processor
+ processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False)
+
+ prompt_input_ids = processed_features["input_ids"][0]
+ pixel_values = processed_features["pixel_values"][0]
+ chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"]
+ rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"]
+
+ # Add special tokens (typically for encoder-decoder models)
+ if add_special_tokens:
+ if tokenizer.bos_token_id is not None:
+ prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
+ if tokenizer.eos_token_id is not None:
+ prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
+ chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
+ rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]
+
+ # Truncate prompt and completion sequences
+ if max_prompt_length is not None:
+ prompt_input_ids = prompt_input_ids[-max_prompt_length:]
+ if max_completion_length is not None:
+ chosen_input_ids = chosen_input_ids[:max_completion_length]
+ rejected_input_ids = rejected_input_ids[:max_completion_length]
+
+ output = {
+ "prompt_input_ids": prompt_input_ids,
+ "pixel_values": pixel_values,
+ "chosen_input_ids": chosen_input_ids,
+ "rejected_input_ids": rejected_input_ids,
+ }
+
+ if "pixel_attention_mask" in processed_features:
+ output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
+ if "image_sizes" in processed_features:
+ output["image_sizes"] = processed_features["image_sizes"][0]
+ if "token_type_ids" in processed_features:
+ output["token_type_ids"] = processed_features["token_type_ids"][0]
+
+ return output
+
+ def _set_signature_columns_if_needed(self):
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
+ # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+ # Instead, we set them to the columns expected by `DataCollatorForPreference`, hence the override.
+ if self._signature_columns is None:
+ self._signature_columns = [
+ "prompt_input_ids",
+ "chosen_input_ids",
+ "rejected_input_ids",
+ "image_sizes",
+ "token_type_ids",
+ "ref_chosen_logps",
+ "ref_rejected_logps",
+ ]
+
+ def get_train_dataloader(self) -> DataLoader:
+ """
+ Returns the training [`~torch.utils.data.DataLoader`].
+
+ Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
+ """
+
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
+ batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size
+ dataloader_params = {
+ "batch_size": batch_size,
+ "collate_fn": self.data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "shuffle": False,
+ }
+
+ # prepare dataloader
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
+
+ ref_chosen_logps = []
+ ref_rejected_logps = []
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
+ ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch)
+ ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics(
+ (ref_chosen_logp, ref_rejected_logp)
+ )
+ ref_chosen_logps.append(ref_chosen_logp.cpu())
+ ref_rejected_logps.append(ref_rejected_logp.cpu())
+
+ # Unnecessary cache clearing to avoid OOM
+ empty_cache()
+ self.accelerator.free_memory()
+
+ all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy()
+ all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy()
+
+ self.train_dataset = self.train_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps)
+ self.train_dataset = self.train_dataset.add_column(
+ name="ref_rejected_logps", column=all_ref_rejected_logps
+ )
+
+ self._precomputed_train_ref_log_probs = True
+
+ return super().get_train_dataloader()
+
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+ """
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+ Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
+
+ Args:
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
+ """
+ if eval_dataset is None and self.eval_dataset is None:
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
+ batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size
+ dataloader_params = {
+ "batch_size": batch_size,
+ "collate_fn": self.data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "shuffle": False,
+ }
+
+ # prepare dataloader
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
+
+ ref_chosen_logps = []
+ ref_rejected_logps = []
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
+ ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch)
+ ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics(
+ (ref_chosen_logp, ref_rejected_logp)
+ )
+ ref_chosen_logps.append(ref_chosen_logp.cpu())
+ ref_rejected_logps.append(ref_rejected_logp.cpu())
+
+ all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy()
+ all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy()
+
+ eval_dataset = eval_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps)
+ eval_dataset = eval_dataset.add_column(name="ref_rejected_logps", column=all_ref_rejected_logps)
+
+ # Save calculated ref_chosen_logps and ref_rejected_logps to the eval_dataset for subsequent runs
+ if self.eval_dataset is not None:
+ self.eval_dataset = eval_dataset
+ self._precomputed_eval_ref_log_probs = True
+
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
+
+ @contextmanager
+ def null_ref_context(self):
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
+ with (
+ self.accelerator.unwrap_model(self.model).disable_adapter()
+ if self.is_peft_model and not self.ref_adapter_name
+ else nullcontext()
+ ):
+ if self.ref_adapter_name:
+ self.model.set_adapter(self.ref_adapter_name)
+ yield
+ if self.ref_adapter_name:
+ self.model.set_adapter(self.model_adapter_name or "default")
+
+ def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset."""
+ compte_ref_context_manager = (
+ autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
+ )
+ with torch.no_grad(), compte_ref_context_manager:
+ if self.ref_model is None:
+ with self.null_ref_context():
+ ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True)
+ else:
+ ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True)
+ return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"]
+
+ @staticmethod
+ def concatenated_inputs(
+ batch: dict[str, Union[list, torch.LongTensor]], padding_value: int
+ ) -> dict[str, torch.LongTensor]:
+ """
+ Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt and
+ completion sequences.
+
+ Args:
+ batch (`dict[str, Union[list, torch.LongTensor]]`):
+ A batch of input data. The batch must contain the following keys:
+
+ - `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input
+ IDs.
+ - `"chosen_input_ids"`: Tensor of shape `(batch_size, chosen_length)` representing the chosen
+ completion input IDs.
+ - `"rejected_input_ids"`: Tensor of shape `(batch_size, rejected_length)` representing the rejected
+ completion input IDs.
+ - `"prompt_pixel_values"` (optional): Tensor for pixel values, if available.
+ - `"prompt_pixel_attention_mask"` (optional): Tensor for pixel attention masks, if available.
+
+ padding_value (`int`):
+ The padding value to use for the concatenated completion sequences (`chosen_input_ids` and
+ `rejected_input_ids`).
+
+ Returns:
+ `dict[str, torch.LongTensor]`: A dictionary containing:
+
+ - `"prompt_input_ids"`: Concatenated prompt input IDs of shape `(2 * batch_size, prompt_length)`.
+ - `"completion_input_ids"`: Concatenated chosen and rejected completion input IDs of shape `(2 *
+ batch_size, max_completion_length)`.
+ - `"prompt_attention_mask"`: Concatenated prompt attention masks of shape `(2 * batch_size,
+ prompt_length)`.
+ - `"completion_attention_mask"`: Concatenated chosen and rejected attention masks of shape `(2 *
+ batch_size, max_completion_length)`.
+ - `"pixel_values"` (optional): Concatenated pixel values if `"prompt_pixel_values"` are present.
+ - `"pixel_attention_mask"` (optional): Concatenated pixel attention masks if
+ `"prompt_pixel_attention_mask"` are present.
+
+ Notes:
+ The completion input IDs and attention masks are padded to the maximum completion length of the chosen or
+ rejected sequences.
+ """
+ output = {}
+
+ # For the prompt, the input_ids are the same for both the chosen and rejected responses
+ output["prompt_input_ids"] = torch.cat([batch["prompt_input_ids"], batch["prompt_input_ids"]], dim=0)
+ output["prompt_attention_mask"] = torch.cat(
+ [batch["prompt_attention_mask"], batch["prompt_attention_mask"]], dim=0
+ )
+ if "pixel_values" in batch:
+ output["pixel_values"] = torch.cat([batch["pixel_values"], batch["pixel_values"]], dim=0)
+
+ if "pixel_attention_mask" in batch:
+ output["pixel_attention_mask"] = torch.cat(
+ [batch["pixel_attention_mask"], batch["pixel_attention_mask"]], dim=0
+ )
+ if "image_sizes" in batch:
+ output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
+ if "token_type_ids" in batch:
+ output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))
+
+ # Concatenate the chosen and rejected completions
+ max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
+ output["completion_input_ids"] = torch.cat(
+ (
+ pad_to_length(batch["chosen_input_ids"], max_completion_length, pad_value=padding_value),
+ pad_to_length(batch["rejected_input_ids"], max_completion_length, pad_value=padding_value),
+ ),
+ )
+ output["completion_attention_mask"] = torch.cat(
+ (
+ pad_to_length(batch["chosen_attention_mask"], max_completion_length, pad_value=0),
+ pad_to_length(batch["rejected_attention_mask"], max_completion_length, pad_value=0),
+ ),
+ )
+
+ return output
+
+ def dpo_loss(
+ self,
+ chosen_logps: torch.FloatTensor,
+ rejected_logps: torch.FloatTensor,
+ ref_chosen_logps: torch.FloatTensor,
+ ref_rejected_logps: torch.FloatTensor,
+ loss_type: str = "sigmoid",
+ model_output: dict[str, torch.FloatTensor] = None,
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ """
+ Compute the DPO loss for a batch of policy and reference model log probabilities.
+
+ Args:
+ chosen_logps (`torch.FloatTensor`):
+ Log probabilities of the model for the chosen responses. Shape: `(batch_size,)`.
+ rejected_logps (`torch.FloatTensor`):
+ Log probabilities of the model for the rejected responses. Shape: `(batch_size,)`.
+ ref_chosen_logps (`torch.FloatTensor`):
+ Log probabilities of the reference model for the chosen responses. Shape: `(batch_size,)`.
+ ref_rejected_logps (`torch.FloatTensor`):
+ Log probabilities of the reference model for the rejected responses. Shape: `(batch_size,)`.
+ loss_type (`str`, defaults to `"sigmoid"`):
+ The type of loss to compute. One of:
+ - `"sigmoid"`: Sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
+ - `"hinge"`: Hinge loss on the normalized likelihood from the
+ [SLiC](https://huggingface.co/papers/2305.10425) paper.
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
+ - `"exo_pair"`: Pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper.
+ - `"nca_pair"`: Pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper.
+ - `"robust"`: Unbiased estimate of the DPO loss that is robust to preference noise from the [Robust
+ DPO](https://huggingface.co/papers/2403.00409) paper.
+ - `"bco_pair"`: Pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper.
+ - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675)
+ paper.
+ - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
+ - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882)
+ paper.
+ - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the
+ [DiscoPOP](https://huggingface.co/papers/2406.08414) paper.
+ - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
+ - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
+ - `"sft"`: Negative log-likelihood loss (standard supervised fine-tuning loss).
+ model_output (`dict[str, torch.FloatTensor]`, *optional*):
+ The output of the model's forward pass. This is used to compute auxiliary losses if enabled.
+
+ Returns:
+ A tuple of three tensors: `(losses, chosen_rewards, rejected_rewards)`. The losses tensor contains the DPO
+ loss for each example in the batch. The `chosen_rewards` and `rejected_rewards` tensors contain the rewards
+ for the chosen and rejected responses, respectively.
+ """
+ device = self.accelerator.device
+
+ # Get the log ratios for the chosen and rejected responses
+ chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device)
+ rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device)
+
+ if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE:
+ # The alpha-divergence formula: (1 - u^-alpha) / alpha
+ # The divergence difference between the chosen and rejected sample is:
+ # (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha
+ # = (u[l]^-alpha - u[w]^-alpha) / alpha
+ # where u[w] and u[l] are the policy/reference probability ratios
+ # for the chosen and rejected samples, respectively.
+ alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
+ if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params:
+ alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY])
+ logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef
+ else:
+ logratios = chosen_logps - rejected_logps
+ if self.reference_free:
+ ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device)
+ else:
+ ref_logratios = ref_chosen_logps - ref_rejected_logps
+
+ logratios = logratios.to(self.accelerator.device)
+ ref_logratios = ref_logratios.to(self.accelerator.device)
+ logits = logratios - ref_logratios
+
+ if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE:
+ # The js-divergence formula: log(2 * u / (1 + u))
+ # The divergence difference between the chosen and rejected sample is:
+ # log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l]))
+ # = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l]))
+ # where u[w] and u[l] are the policy/reference probability ratios
+ # for the chosen and rejected samples, respectively.
+ logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)
+
+ # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
+ # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the
+ # labels and calculates a conservative DPO loss.
+ if loss_type == "sigmoid":
+ losses = (
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+ )
+
+ elif loss_type == "robust":
+ losses = (
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ + F.logsigmoid(-self.beta * logits) * self.label_smoothing
+ ) / (1 - 2 * self.label_smoothing)
+
+ elif loss_type == "exo_pair":
+ # eqn (16) of the EXO paper: https://huggingface.co/papers/2402.00856
+ import math
+
+ if self.label_smoothing == 0:
+ self.label_smoothing = 1e-3
+ losses = (self.beta * logits).sigmoid() * (
+ F.logsigmoid(self.beta * logits) - math.log(1 - self.label_smoothing)
+ ) + (-self.beta * logits).sigmoid() * (F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing))
+
+ elif loss_type == "hinge":
+ losses = torch.relu(1 - self.beta * logits)
+
+ elif loss_type == "ipo":
+ # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
+ losses = (logits - 1 / (2 * self.beta)) ** 2
+
+ elif loss_type == "bco_pair":
+ chosen_logratios = chosen_logps - ref_chosen_logps
+ rejected_logratios = rejected_logps - ref_rejected_logps
+ chosen_rewards = self.beta * chosen_logratios
+ rejected_rewards = self.beta * rejected_logratios
+ rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
+ self.running.update(rewards)
+ delta = self.running.mean
+ losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
+ -(self.beta * rejected_logratios - delta)
+ )
+
+ elif loss_type == "sppo_hard":
+ # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+ # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+ # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+ # set to 1 for the winner and 0 for the loser.
+ a = chosen_logps - ref_chosen_logps
+ b = rejected_logps - ref_rejected_logps
+ losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2
+
+ elif loss_type == "nca_pair":
+ chosen_rewards = (chosen_logps - ref_chosen_logps) * self.beta
+ rejected_rewards = (rejected_logps - ref_rejected_logps) * self.beta
+ losses = (
+ -F.logsigmoid(chosen_rewards)
+ - 0.5 * F.logsigmoid(-chosen_rewards)
+ - 0.5 * F.logsigmoid(-rejected_rewards)
+ )
+
+ elif loss_type == "aot_pair":
+ chosen_logratios = chosen_logps - ref_chosen_logps
+ rejected_logratios = rejected_logps - ref_rejected_logps
+ chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0)
+ rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0)
+ delta = chosen_logratios_sorted - rejected_logratios_sorted
+ losses = (
+ -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing)
+ - F.logsigmoid(-self.beta * delta) * self.label_smoothing
+ )
+
+ elif loss_type == "aot":
+ logratios = chosen_logps - rejected_logps
+ ref_logratios = ref_chosen_logps - ref_rejected_logps
+ logratios_sorted, _ = torch.sort(logratios, dim=0)
+ ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0)
+ delta = logratios_sorted - ref_logratios_sorted
+ losses = (
+ -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing)
+ - F.logsigmoid(-self.beta * delta) * self.label_smoothing
+ )
+
+ elif loss_type == "apo_zero":
+ # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+ # Use this loss when you believe the chosen outputs are better than your model's default output
+ losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood
+ losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood
+ losses = losses_chosen + losses_rejected
+
+ elif loss_type == "apo_down":
+ # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+ # Use this loss when you believe the chosen outputs are worse than your model's default output.
+ # Decrease chosen likelihood and decrease rejected likelihood more
+ losses_chosen = F.sigmoid(self.beta * chosen_logratios)
+ losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios))
+ losses = losses_chosen + losses_rejected
+
+ elif loss_type == "discopop":
+ # Eqn (5) of the DiscoPOP paper (https://huggingface.co/papers/2406.08414)
+ # This loss was discovered with LLM discovery
+ logratios = chosen_logps - rejected_logps
+ ref_logratios = ref_chosen_logps - ref_rejected_logps
+ logits = logratios - ref_logratios
+ logits = logits * self.beta
+ # Modulate the mixing coefficient based on the log ratio magnitudes
+ log_ratio_modulation = torch.sigmoid(logits / self.args.discopop_tau)
+ logistic_component = -F.logsigmoid(logits)
+ exp_component = torch.exp(-logits)
+ # Blend between logistic and exponential component based on log ratio modulation
+ losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation
+
+ elif loss_type == "sft":
+ # SFT loss is the negative log likelihood loss on chosen responses
+ # This acts as the generation loss component in MPO
+ sft_loss = model_output["nll_loss"]
+ # Create losses tensor with same shape as other losses (per-sample)
+ batch_size = chosen_logps.shape[0]
+ losses = sft_loss.expand(batch_size)
+ # For SFT, we don't have preference rewards, so use zeros
+ chosen_rewards = torch.zeros_like(chosen_logps)
+ rejected_rewards = torch.zeros_like(rejected_logps)
+
+ else:
+ raise ValueError(
+ f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', "
+ "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'discopop', 'apo_zero', "
+ "'apo_down', 'sft']"
+ )
+
+ chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
+ rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
+
+ return losses, chosen_rewards, rejected_rewards
+
    def _compute_loss_liger(
        self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
    ) -> dict[str, torch.Tensor]:
        """
        Compute the DPO loss for `batch` via the fused Liger loss function `self.dpo_loss_fn`.

        Rather than materializing full vocabulary logits, this runs the (unwrapped) base model
        only up to the last hidden states, then hands the LM-head weights plus policy/reference
        hidden states to `self.dpo_loss_fn`, which returns the loss together with per-sequence
        log-probs, mean logits, the NLL loss and reward outputs.

        NOTE(review): assumes `self.dpo_loss_fn` is the Liger fused linear+DPO kernel configured
        elsewhere in this trainer — confirm against the trainer's __init__.

        Returns:
            dict with keys "loss", "chosen_logps", "rejected_logps", "mean_chosen_logits",
            "mean_rejected_logits", "nll_loss", "chosen_rewards", "rejected_rewards", and
            "aux_loss" when router auxiliary losses are enabled.
        """
        unwrapped_model = self.accelerator.unwrap_model(model)
        # Chosen and rejected examples are stacked into one 2*batch_size batch.
        concatenated_batch = self.concatenated_inputs(batch, padding_value=self.pad_token_id)

        model_kwargs = {}
        if self.aux_loss_enabled:
            model_kwargs["output_router_logits"] = True

        # Add the pixel values and attention masks for vision models
        if "pixel_values" in concatenated_batch:
            model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
        if "pixel_attention_mask" in concatenated_batch:
            model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
        if "image_sizes" in concatenated_batch:
            model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]

        prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
        completion_attention_mask = concatenated_batch["completion_attention_mask"]

        if self.is_encoder_decoder:
            # Encoder-decoder path: run encoder and decoder separately so we can reuse the
            # decoder's last hidden states with the fused loss.
            # 1. Get encoder outputs
            encoder_outputs = unwrapped_model.get_encoder()(
                concatenated_batch["prompt_input_ids"],
                attention_mask=concatenated_batch["prompt_attention_mask"],
                return_dict=True,
            )
            # 2. Prepare decoder inputs
            decoder_input_ids = shift_tokens_right(
                concatenated_batch["completion_input_ids"],
                unwrapped_model.config.decoder_start_token_id,
            )
            # 3. Get decoder outputs
            decoder_outputs = unwrapped_model.get_decoder()(
                input_ids=decoder_input_ids,
                attention_mask=concatenated_batch["completion_attention_mask"],
                encoder_hidden_states=encoder_outputs.last_hidden_state,
                encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
                use_cache=False,
            )
            hidden_states = decoder_outputs.last_hidden_state

            # Reference hidden states: explicit ref model if given, otherwise the policy with
            # adapters disabled (null_ref_context); None when reference-free.
            ref_hidden_states = None
            if not self.reference_free and self.ref_model is not None:
                unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
                ref_encoder_outputs = unwrapped_ref_model.get_encoder()(
                    concatenated_batch["prompt_input_ids"],
                    attention_mask=concatenated_batch["prompt_attention_mask"],
                    return_dict=True,
                )
                ref_decoder_outputs = unwrapped_ref_model.get_decoder()(
                    input_ids=decoder_input_ids,
                    attention_mask=concatenated_batch["completion_attention_mask"],
                    encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
                    encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
                    use_cache=False,
                )
                ref_hidden_states = ref_decoder_outputs.last_hidden_state
            elif not self.reference_free:
                with self.null_ref_context():
                    ref_encoder_outputs = unwrapped_model.get_encoder()(
                        concatenated_batch["prompt_input_ids"],
                        attention_mask=concatenated_batch["prompt_attention_mask"],
                        return_dict=True,
                    )
                    ref_decoder_outputs = unwrapped_model.get_decoder()(
                        input_ids=decoder_input_ids,
                        attention_mask=concatenated_batch["completion_attention_mask"],
                        encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
                        encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
                        use_cache=False,
                    )
                    ref_hidden_states = ref_decoder_outputs.last_hidden_state

            labels = concatenated_batch["completion_input_ids"]
            loss_mask = completion_attention_mask.bool()
        else:
            # For decoder-only models
            input_ids = torch.cat(
                (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1
            )
            attention_mask = torch.cat(
                (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]),
                dim=1,
            )
            # Mask the prompt but not the completion for the loss
            loss_mask = torch.cat(
                (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
                dim=1,
            )

            # Flush and truncate
            if self.max_length is not None and self.max_length < attention_mask.size(1):
                if self.truncation_mode == "keep_start":
                    # Flush left to reduce the memory usage
                    # [[0, 0, x, x, x, x], -> [[x, x, x, x],
                    #  [0, x, x, x, 0, 0]]     [x, x, x, 0]]
                    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
                    attention_mask = attention_mask[:, : self.max_length]
                    input_ids = input_ids[:, : self.max_length]
                    loss_mask = loss_mask[:, : self.max_length]
                elif self.truncation_mode == "keep_end":
                    # Flush right before truncating left, then flush left
                    # [[0, 0, x, x, x, x], -> [[0, 0, x, x],
                    #  [0, x, x, x, 0, 0]]     [0, x, x, x]]
                    attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
                    input_ids = input_ids[:, -self.max_length :]
                    attention_mask = attention_mask[:, -self.max_length :]
                    loss_mask = loss_mask[:, -self.max_length :]
                    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
                else:
                    raise ValueError(
                        f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
                        "'keep_start']."
                    )
            else:
                # Flush left to reduce the memory usage
                # [[0, 0, x, x, x, x], -> [[x, x, x, x],
                #  [0, x, x, x, 0, 0]]     [x, x, x, 0]]
                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)

            # Add logits_to_keep optimization: skip computing logits for positions that can
            # never contribute to the loss (everything before the first unmasked label).
            if self.use_logits_to_keep:
                first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
                logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1
                model_kwargs["logits_to_keep"] = logits_to_keep

            model_kwargs["output_hidden_states"] = True

            # Add padding-free training support: flatten the batch into one packed row and
            # drive the attention pattern through position_ids instead of an attention mask.
            if self.padding_free:
                input_ids = input_ids[attention_mask.bool()].unsqueeze(0)
                loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0)
                position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1
                model_kwargs["position_ids"] = position_ids
            else:
                model_kwargs["attention_mask"] = attention_mask

            # Get the base model outputs (before LM head)
            if hasattr(unwrapped_model, "get_decoder") and unwrapped_model.get_decoder() is not None:
                base_model = unwrapped_model.get_decoder()
            else:
                base_attr = getattr(unwrapped_model, "base_model_prefix", self.args.base_model_attribute_name)
                base_model = getattr(unwrapped_model, base_attr, unwrapped_model)

            outputs = base_model(
                input_ids,
                use_cache=False,
                **model_kwargs,
            )
            # Drop the last position: its logits would predict a token past the sequence end.
            hidden_states = outputs.last_hidden_state[:, :-1]

            # Get reference hidden states if needed
            ref_hidden_states = None
            if not self.reference_free and self.ref_model is not None:
                unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
                if hasattr(unwrapped_ref_model, "get_decoder") and unwrapped_ref_model.get_decoder() is not None:
                    ref_base_model = unwrapped_ref_model.get_decoder()
                else:
                    ref_attr = getattr(unwrapped_ref_model, "base_model_prefix", self.args.base_model_attribute_name)
                    ref_base_model = getattr(unwrapped_ref_model, ref_attr, unwrapped_ref_model)

                ref_outputs = ref_base_model(
                    input_ids,
                    use_cache=False,
                    **model_kwargs,
                )
                ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]
            elif not self.reference_free:
                # No separate reference model: reuse the policy with adapters disabled.
                if hasattr(unwrapped_model, "get_decoder") and unwrapped_model.get_decoder() is not None:
                    ref_base_model = unwrapped_model.get_decoder()
                else:
                    ref_attr = getattr(unwrapped_model, "base_model_prefix", self.args.base_model_attribute_name)
                    ref_base_model = getattr(unwrapped_model, ref_attr, unwrapped_model)
                with self.null_ref_context():
                    ref_outputs = ref_base_model(
                        input_ids,
                        use_cache=False,
                        **model_kwargs,
                    )
                    ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]

            # Prompt (and padding) positions become label_pad_token_id so they are ignored.
            masked_input_ids = torch.where(loss_mask != 0, input_ids, self.label_pad_token_id)
            labels = masked_input_ids[:, 1:]  # Shift right for causal LM

        # Get the LM head
        lm_head = unwrapped_model.get_output_embeddings()

        # Get reference model weights if needed
        ref_weight = None
        ref_bias = None
        if not self.reference_free:
            if self.ref_model is not None:
                unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
                ref_lm_head = unwrapped_ref_model.get_output_embeddings()
            else:
                with self.null_ref_context():
                    ref_lm_head = unwrapped_model.get_output_embeddings()
            ref_weight = ref_lm_head.weight
            ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None

        # Compute loss using Liger kernel
        loss_output = self.dpo_loss_fn(
            lm_head.weight,
            hidden_states,
            labels,
            bias=lm_head.bias if hasattr(lm_head, "bias") else None,
            ref_input=ref_hidden_states if not self.reference_free else None,
            ref_weight=ref_weight if not self.reference_free else None,
            ref_bias=ref_bias if not self.reference_free else None,
        )
        (
            loss,
            (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs),
        ) = loss_output

        output = {
            "loss": loss,
            "chosen_logps": chosen_logps,
            "rejected_logps": rejected_logps,
            "mean_chosen_logits": chosen_logits_mean,
            "mean_rejected_logits": rejected_logits_mean,
            "nll_loss": nll_loss,
            "chosen_rewards": aux_outputs[0],
            "rejected_rewards": aux_outputs[1],
        }
        if self.aux_loss_enabled:
            output["aux_loss"] = outputs.aux_loss

        return output
+
+ def concatenated_forward(
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False
+ ) -> dict[str, torch.Tensor]:
+ """
+ Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
+
+ Args:
+ model:
+ Model to run the forward pass on.
+ batch:
+ Batch of input data.
+ is_ref_model:
+ Whether this method is being called for the reference model. If `True`, length desensitization is not
+ applied.
+ """
+ num_examples = batch["prompt_input_ids"].shape[0]
+
+ concatenated_batch = self.concatenated_inputs(batch, padding_value=self.pad_token_id)
+
+ model_kwargs = {"use_cache": False}
+ if self.aux_loss_enabled:
+ model_kwargs["output_router_logits"] = True
+
+ # Add the pixel values and attention masks for vision models
+ if "pixel_values" in concatenated_batch:
+ model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
+ if "pixel_attention_mask" in concatenated_batch:
+ model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
+ if "image_sizes" in concatenated_batch:
+ model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
+
+ prompt_input_ids = concatenated_batch["prompt_input_ids"]
+ prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
+ completion_input_ids = concatenated_batch["completion_input_ids"]
+ completion_attention_mask = concatenated_batch["completion_attention_mask"]
+ if self.is_encoder_decoder:
+ labels = completion_input_ids
+ labels[completion_attention_mask == 0] = self.label_pad_token_id
+ outputs = model(
+ input_ids=prompt_input_ids,
+ attention_mask=prompt_attention_mask,
+ labels=labels, # we need the labels for the logits to be returned
+ **model_kwargs,
+ )
+ logits = outputs.logits
+ loss_mask = completion_attention_mask.bool()
+ else:
+ # Concatenate the prompt and completion inputs
+ input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
+ attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
+ if "token_type_ids" in concatenated_batch:
+ prompt_token_type_ids = concatenated_batch["token_type_ids"]
+ token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0)
+ # Mask the prompt but not the completion for the loss
+ loss_mask = torch.cat(
+ (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
+ dim=1,
+ )
+
+ # Flush and truncate
+ if self.max_length is not None and self.max_length < attention_mask.size(1):
+ if self.truncation_mode == "keep_start":
+ # Flush left to reduce the memory usage
+ # [[0, 0, x, x, x, x], -> [[x, x, x, x],
+ # [0, x, x, x, 0, 0]] [x, x, x, 0]]
+ if "token_type_ids" in concatenated_batch:
+ attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+ attention_mask, input_ids, loss_mask, token_type_ids
+ )
+ else:
+ attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+ attention_mask = attention_mask[:, : self.max_length]
+ input_ids = input_ids[:, : self.max_length]
+ loss_mask = loss_mask[:, : self.max_length]
+ elif self.truncation_mode == "keep_end":
+ # Flush right before truncating left, then flush left
+ # [[0, 0, x, x, x, x], -> [[0, 0, x, x],
+ # [0, x, x, x, 0, 0]] [0, x, x, x]]
+ if "token_type_ids" in concatenated_batch:
+ attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+ attention_mask, input_ids, loss_mask, token_type_ids
+ )
+ token_type_ids = token_type_ids[:, -self.max_length :]
+ else:
+ attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
+ input_ids = input_ids[:, -self.max_length :]
+ attention_mask = attention_mask[:, -self.max_length :]
+ loss_mask = loss_mask[:, -self.max_length :]
+ if "token_type_ids" in concatenated_batch:
+ attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+ attention_mask, input_ids, loss_mask, token_type_ids
+ )
+ else:
+ attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+ else:
+ raise ValueError(
+ f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
+ "'keep_start']."
+ )
+ else:
+ # Flush left to reduce the memory usage
+ # [[0, 0, x, x, x, x], -> [[x, x, x, x],
+ # [0, x, x, x, 0, 0]] [x, x, x, 0]]
+ if "token_type_ids" in concatenated_batch:
+ attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+ attention_mask, input_ids, loss_mask, token_type_ids
+ )
+ else:
+ attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+
+ if "token_type_ids" in concatenated_batch:
+ model_kwargs["token_type_ids"] = token_type_ids
+
+ if self.use_logits_to_keep:
+ # Compute logits_to_keep based on loss_mask pattern:
+ # [[0, 0, 0, x, x, x, x],
+ # [0, 0, 0, x, x, x, 0]]
+ # ^ start computing logits from here ([:, -(7-3+1):])
+ first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
+ logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label
+ model_kwargs["logits_to_keep"] = logits_to_keep
+
+ model_kwargs["output_hidden_states"] = True
+
+ if self.padding_free:
+ # Flatten the input_ids, position_ids, and loss_mask
+ # input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]]
+ # [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]]
+ input_ids = input_ids[attention_mask.bool()].unsqueeze(0)
+ loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0)
+ position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1
+ model_kwargs["position_ids"] = position_ids
+ else:
+ model_kwargs["attention_mask"] = attention_mask
+
+ outputs = model(input_ids, **model_kwargs)
+ logits = outputs.logits
+
+ # Offset the logits by one to align with the labels
+ labels = torch.roll(input_ids, shifts=-1, dims=1)
+ loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool()
+
+ if self.use_logits_to_keep:
+ # Align labels with logits
+ # logits: -, -, [x2, x3, x4, x5, x6]
+ # ^ --------- ^ after logits[:, :-1, :]
+ # labels: [y0, y1, y2, y3, y4, y5, y6]
+ # ^ --------- ^ with logits_to_keep=4, [:, -4:]
+ # loss_mask: [0, 0, 0, 1, 1, 1, 1]
+ labels = labels[:, -logits_to_keep:]
+ loss_mask = loss_mask[:, -logits_to_keep:]
+
+ if logits.shape[:2] != labels.shape[:2]:
+ # for LLaVA, the returned logits include the image tokens (placed before the text tokens)
+ seq_len = labels.shape[1]
+ logits = logits[:, -seq_len:]
+
+ # Compute the log probabilities of the labels
+ labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later
+ per_token_logps = selective_log_softmax(logits, labels)
+ per_token_logps[~loss_mask] = 0
+ per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1)
+
+ if self.padding_free:
+ # Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len])
+ batch_size, seq_len = attention_mask.shape
+ per_token_logps_ = torch.zeros(
+ batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype
+ )
+ per_token_logps_[attention_mask.bool()] = per_token_logps
+ per_token_logps = per_token_logps_
+
+ all_logps = per_token_logps[:, 1:].sum(-1)
+
+ output = {}
+
+ if self.use_weighting:
+ with torch.no_grad():
+ # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827
+ logprobs = F.log_softmax(logits, dim=-1)
+ weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) # same as sum(probs**2) in log space
+ per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
+ all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)
+ chosen_weights = all_weights[:num_examples]
+ rejected_weights = all_weights[num_examples:]
+ output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1)
+
+ if self.args.rpo_alpha is not None or "sft" in self.loss_type:
+ # Only use the chosen logits for the RPO loss or SFT loss
+ chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples]
+ chosen_labels = labels[:num_examples, :-1] if not self.is_encoder_decoder else labels[:num_examples]
+
+ # Compute the log probabilities of the labels
+ output["nll_loss"] = F.cross_entropy(
+ torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0
+ )
+
+ if "ipo" in self.loss_type:
+ all_logps = all_logps / loss_mask.sum(-1)
+
+ if self.args.ld_alpha is not None and not is_ref_model:
+ # Compute response lengths based on loss_mask
+ completion_lengths = loss_mask.sum(dim=1)
+
+ chosen_lengths = completion_lengths[:num_examples]
+ rejected_lengths = completion_lengths[num_examples:]
+ public_lengths = torch.min(chosen_lengths, rejected_lengths) # l_p in the paper
+ public_lengths = torch.cat([public_lengths, public_lengths], dim=0)
+
+ seq_len = per_token_logps.size(1)
+ position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
+
+ ld_mask = position_ids < public_lengths.unsqueeze(1)
+ mask = position_ids < completion_lengths.unsqueeze(1)
+
+ front_mask = (ld_mask & mask).float()
+ rear_mask = (~ld_mask & mask).float()
+ front_logps = (per_token_logps * front_mask).sum(dim=1)
+ rear_logps = (per_token_logps * rear_mask).sum(dim=1)
+
+ all_logps = front_logps + self.args.ld_alpha * rear_logps
+
+ output["chosen_logps"] = all_logps[:num_examples]
+ output["rejected_logps"] = all_logps[num_examples:]
+
+ # Compute the mean logits
+ if self.padding_free:
+ # position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]).
+ # There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens,
+ # and the second half to the rejected tokens.
+ # To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id.
+ split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples]
+ mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean()
+ mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean()
+ else:
+ mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean()
+ mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean()
+
+ output["mean_chosen_logits"] = mean_chosen_logits
+ output["mean_rejected_logits"] = mean_rejected_logits
+
+ if self.aux_loss_enabled:
+ output["aux_loss"] = outputs.aux_loss
+
+ return output
+
    def get_batch_loss_metrics(
        self,
        model: Union[PreTrainedModel, nn.Module],
        batch: dict[str, Union[list, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """
        Compute the DPO loss and other metrics for the given batch of inputs for train or test.

        Args:
            model: Policy model used for the forward pass.
            batch: Collated batch of tokenized chosen/rejected pairs. May carry
                precomputed `ref_chosen_logps` / `ref_rejected_logps`, in which case
                no reference-model forward pass is needed.
            train_eval: Selects the metric key prefix ("" for train, "eval_" for eval).

        Returns:
            Tuple of the mean loss (scalar tensor) and a dict of gathered scalar metrics.
        """
        metrics = {}

        if self.args.use_liger_loss:
            # Fused Liger-kernel path: loss and rewards come directly from the
            # fused forward, skipping the separate dpo_loss computation below.
            model_output = self._compute_loss_liger(model, batch)
            losses = model_output["loss"]
            chosen_rewards = model_output["chosen_rewards"]
            rejected_rewards = model_output["rejected_rewards"]
        else:
            model_output = self.concatenated_forward(model, batch)

            # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model
            if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
                ref_chosen_logps = batch["ref_chosen_logps"]
                ref_rejected_logps = batch["ref_rejected_logps"]
            else:
                ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)

            # Initialize combined losses
            losses = 0
            chosen_rewards = 0
            rejected_rewards = 0

            # Compute losses for each loss type. `self.loss_type` is iterable:
            # multiple loss types are combined as a weighted sum.
            for idx, loss_type in enumerate(self.loss_type):
                # Compute individual loss using standard DPO loss function
                _losses, _chosen_rewards, _rejected_rewards = self.dpo_loss(
                    model_output["chosen_logps"],
                    model_output["rejected_logps"],
                    ref_chosen_logps,
                    ref_rejected_logps,
                    loss_type,
                    model_output,
                )

                # Add weighted contributions (weight defaults to 1.0 when no
                # per-loss weights were configured).
                weight = self.loss_weights[idx] if self.loss_weights else 1.0
                losses = losses + _losses * weight
                chosen_rewards = chosen_rewards + _chosen_rewards * weight
                rejected_rewards = rejected_rewards + _rejected_rewards * weight

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        if self.args.rpo_alpha is not None:
            losses = losses + self.args.rpo_alpha * model_output["nll_loss"]  # RPO loss from V3 of the paper

        if self.use_weighting:
            # WPO: scale the loss by the policy weights from concatenated_forward.
            losses = losses * model_output["policy_weights"]

        if self.aux_loss_enabled:
            # MoE auxiliary (load-balancing) loss contribution.
            losses = losses + self.aux_loss_coef * model_output["aux_loss"]

        # Gather per-process values across ranks before reducing to scalars.
        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
        metrics[f"{prefix}rewards/margins"] = (
            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
        )
        metrics[f"{prefix}logps/chosen"] = (
            self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item()
        )
        metrics[f"{prefix}logps/rejected"] = (
            self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item()
        )
        metrics[f"{prefix}logits/chosen"] = (
            self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item()
        )
        metrics[f"{prefix}logits/rejected"] = (
            self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item()
        )
        if self.args.rpo_alpha is not None or "sft" in self.loss_type:
            metrics[f"{prefix}nll_loss"] = (
                self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item()
            )
        if self.aux_loss_enabled:
            metrics[f"{prefix}aux_loss"] = (
                self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item()
            )

        return losses.mean(), metrics
+
+ def compute_loss(
+ self,
+ model: Union[PreTrainedModel, nn.Module],
+ inputs: dict[str, Union[torch.Tensor, Any]],
+ return_outputs=False,
+ num_items_in_batch=None,
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]:
+ compute_loss_context_manager = (
+ autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
+ )
+ with compute_loss_context_manager:
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
+
+ # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
+ loss = loss.to(self.args.device)
+ # force log the metrics
+ self.store_metrics(metrics, train_eval="train")
+
+ if return_outputs:
+ return loss, metrics
+
+ return loss
+
+ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
+ """Generate samples from the model and reference model for the given batch of inputs."""
+
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
+ # the torch amp context manager as some hidden states are silently casted to full precision.
+ generate_context_manager = (
+ autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
+ )
+
+ with generate_context_manager:
+ policy_output = model.generate(
+ input_ids=batch["prompt_input_ids"],
+ attention_mask=batch["prompt_attention_mask"],
+ max_length=self.max_length,
+ do_sample=True,
+ pad_token_id=self.pad_token_id,
+ )
+
+ # if ref_output in batch use that otherwise use the reference model
+ if "ref_output" in batch:
+ ref_output = batch["ref_output"]
+ else:
+ if self.ref_model is None:
+ with self.null_ref_context():
+ ref_output = self.model.generate(
+ input_ids=batch["prompt_input_ids"],
+ attention_mask=batch["prompt_attention_mask"],
+ max_length=self.max_length,
+ do_sample=True,
+ pad_token_id=self.pad_token_id,
+ )
+ else:
+ ref_output = self.ref_model.generate(
+ input_ids=batch["prompt_input_ids"],
+ attention_mask=batch["prompt_attention_mask"],
+ max_length=self.max_length,
+ do_sample=True,
+ pad_token_id=self.pad_token_id,
+ )
+
+ policy_output = pad_to_length(policy_output, self.max_length, self.pad_token_id)
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
+
+ ref_output = pad_to_length(ref_output, self.max_length, self.pad_token_id)
+ ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True)
+
+ return policy_output_decoded, ref_output_decoded
+
+ def prediction_step(
+ self,
+ model: Union[PreTrainedModel, nn.Module],
+ inputs: dict[str, Union[torch.Tensor, Any]],
+ prediction_loss_only: bool,
+ ignore_keys: Optional[list[str]] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ if ignore_keys is None:
+ if hasattr(model, "config"):
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
+ else:
+ ignore_keys = []
+
+ prediction_context_manager = (
+ autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
+ )
+
+ with torch.no_grad(), prediction_context_manager:
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
+
+ # force log the metrics
+ self.store_metrics(metrics, train_eval="eval")
+
+ if prediction_loss_only:
+ return loss.detach(), None, None
+
+ # logits for the chosen and rejected samples from model
+ logits_dict = {
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
+ }
+ logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
+ logits = torch.tensor(logits, device=self.accelerator.device)
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
+
+ return (loss.detach(), logits, labels)
+
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
+ for key, value in metrics.items():
+ self._stored_metrics[train_eval][key].append(value)
+
+ def evaluation_loop(
+ self,
+ dataloader: DataLoader,
+ description: str,
+ prediction_loss_only: Optional[bool] = None,
+ ignore_keys: Optional[list[str]] = None,
+ metric_key_prefix: str = "eval",
+ ) -> EvalLoopOutput:
+ """
+ Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
+ `Trainer.evaluate()` and `Trainer.predict()`.
+
+ Works both with or without labels.
+ """
+
+ # Sample and save to game log if requested (for one batch to save time)
+ if self.generate_during_eval:
+ # Generate random indices within the range of the total number of samples
+ num_samples = len(dataloader.dataset)
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
+
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
+ random_batch_dataset = dataloader.dataset.select(random_indices)
+ random_batch = self.data_collator(random_batch_dataset)
+ random_batch = self._prepare_inputs(random_batch)
+
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch)
+
+ table = pd.DataFrame(
+ columns=["Prompt", "Policy", "Ref Model"],
+ data=[
+ [prompt, pol[len(prompt) :], ref[len(prompt) :]]
+ for prompt, pol, ref in zip(
+ random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded
+ )
+ ],
+ )
+ if "wandb" in self.args.report_to and self.accelerator.is_main_process:
+ wandb.log({"game_log": wandb.Table(data=table)})
+
+ if "comet_ml" in self.args.report_to:
+ log_table_to_comet_experiment(
+ name="game_log.csv",
+ table=table,
+ )
+
+ if "mlflow" in self.args.report_to and self.accelerator.is_main_process:
+ mlflow.log_table(data=table, artifact_file="game_log.json")
+
+ # Base evaluation
+ initial_output = super().evaluation_loop(
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
+ )
+
+ return initial_output
+
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+ """
+ Log `logs` on the various objects watching training, including stored metrics.
+
+ Args:
+ logs (`dict[str, float]`):
+ The values to log.
+ start_time (`float`, *optional*):
+ Start time of the training.
+ """
+ # logs either has 'loss' or 'eval_loss'
+ train_eval = "train" if "loss" in logs else "eval"
+ # Add averaged stored metrics to logs
+ for key, metrics in self._stored_metrics[train_eval].items():
+ logs[key] = torch.tensor(metrics).mean().item()
+ del self._stored_metrics[train_eval]
+ return super().log(logs, start_time)
+
+ # Ensure the model card is saved along with the checkpoint
+ def _save_checkpoint(self, model, trial):
+ if self.args.hub_model_id is None:
+ model_name = Path(self.args.output_dir).name
+ else:
+ model_name = self.args.hub_model_id.split("/")[-1]
+ self.create_model_card(model_name=model_name)
+ super()._save_checkpoint(model, trial)
class UnslothDPOTrainer(_UnslothDPOTrainer):
    """

    Trainer for Direct Preference Optimization (DPO) method.

    This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
              path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
              using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
              `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        ref_model ([`PreTrainedModelWrapper`]):
            Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
            and loss. If no reference model is provided, the trainer will create a reference model with the same
            architecture as the model to be optimized.
        args ([`DPOConfig`], *optional*):
            Configuration for this trainer. If `None`, a default configuration is used.
        data_collator ([`~transformers.DataCollator`], *optional*):
            Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
            Will default to [`DataCollatorForPreference`].
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. DPO supports [preference](#preference) type and. The format of the samples can
            be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
            Processing class used to process the data. If `None`, the processing class is loaded from the model's name
            with [`~transformers.AutoTokenizer.from_pretrained`].
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
            a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to
            `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered
            after the last eval batch to signal that the function needs to calculate and return the global summary
            statistics rather than accumulating the batch-level statistics.
        callbacks (list of [`~transformers.TrainerCallback`], *optional*):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
            in [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
            A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
            `args`. Incompatible with the `optimizers` argument.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocess the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
        peft_config ([`~peft.PeftConfig`], *optional*):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.

    """
    def __init__(
        self,
        model,
        ref_model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        compute_metrics = None,
        callbacks = None,
        optimizer_cls_and_kwargs = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        **kwargs
    ):
        # NOTE(review): this auto-generated __init__ references names (`os`, `Version`,
        # `ParallelMode`, `MethodType`, `prepare_for_training_mode`, `UnslothDPOConfig`,
        # `DataCollatorForSeq2Seq`, `TransformersDataCollatorForLanguageModeling`,
        # `tokenizer`) that are presumably imported/defined earlier in this generated
        # file, outside this view -- confirm before refactoring.
        if args is None: args = UnslothDPOConfig()
        # --- Resolve the model's compute dtype and reconcile it with the fp16/bf16 flags.
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
        if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().weight.dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        # A float16 model cannot train under bf16 autocast and vice versa.
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        # --- Configure accelerate's mixed-precision mode consistently with the dtype.
        if force_float32:
            # Forced float32 training
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Mixed precision training
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
            # args.mixed_precision is a new argument which needs to be set now
        elif mixed_precision_dtype == 'bfloat16':
            # Both False since bfloat16 full finetuning doesn't do any autocasting.
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now

        # --- Enable periodic evaluation when an eval dataset is given but no strategy set.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        # --- Keep eval batch/accumulation sizes in line with the training settings.
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        # --- Align full-eval precision flags with the training precision.
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        # --- Metric callbacks need the raw logits, so ask Unsloth to return them.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        # --- Ensure `model.warnings_issued` exists and is a dict (transformers expects one).
        if model is not None:
            _warnings_issued = getattr(model, 'warnings_issued', None)
            if _warnings_issued is None:
                model.warnings_issued = {}
            elif not isinstance(_warnings_issued, dict):
                try:
                    model.warnings_issued = dict(_warnings_issued)
                except Exception:
                    model.warnings_issued = {}
        # --- Clamp args.max_seq_length to what the model supports.
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
            elif args_max_seq_length is not None and model_max_seq_length is not None:
                if args_max_seq_length > model_max_seq_length:
                    print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
                          'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
                    args.max_seq_length = model_max_seq_length
        # --- Switch the model into training mode (enables gradient checkpointing etc.).
        if model is not None and hasattr(model, 'for_training'):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        # --- Right padding is required for training; patch tokenizer/processor in place.
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        # --- Pick a data collator matching whether the dataset is pre-labelled,
        # and fall back to the inner tokenizer when the processor itself cannot pad.
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(
                    __tokenizer,
                    mlm = False,
                    mlm_probability = 0.0,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(
                    __tokenizer,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
        else:
            # Vision collator: the dataset is already prepared, so skip TRL's preparation.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(
                        __tokenizer.tokenizer,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(
                        __tokenizer.tokenizer,
                        mlm = False,
                        mlm_probability = 0.0,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('dpo_trainer', other_metrics)
        # --- Drop raw text columns once the dataset is fully tokenized.
        if hasattr(train_dataset, 'column_names'):
            column_names = set(train_dataset.column_names)
            check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',
                     'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',
                     'prompt_input_ids', 'prompt_attention_mask']
            if all(x in column_names for x in check):
                train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])
            del check, column_names

        # [TODO] Fix up DataParallel multiplying batch sizes
        # [TODO] DDP works, but DP seems to not work? [TODO]
        if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
            if getattr(args, "_n_gpu", 1) != 1:
                args._n_gpu = 1
        if "model" in locals() and hasattr(model, "for_training"):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        super().__init__(
            model = model,
            ref_model = ref_model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,**kwargs)
        # --- Post-init fixups: return to inference mode and remove duplicate NEFTune hooks.
        if "model" in locals() and hasattr(model, "for_inference"):
            model.for_inference()
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
        if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
        # --- Expose the accelerator's grad scaler on every nested inner model.
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
        pass
        # --- Wrap `train` so it flips the model into training mode on entry.
        if hasattr(self, 'train'):
            self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
        pass
        # --- Copy the chat template onto the vLLM tokenizer if it lacks one.
        if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
            _vllm_tok = self.llm.get_tokenizer()
            _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
            if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
                _vllm_tok.chat_template = _pc.chat_template
        pass
+
+pass
+
+
if hasattr(logger, "addFilter"):
    import logging

    class HideLoggingMessage(logging.Filter):
        """Logging filter that drops any record whose message contains `text`."""

        def __init__(self, text):
            # Fix: initialize the base logging.Filter (the original skipped
            # super().__init__(), leaving the Filter attributes unset).
            super().__init__()
            self.text = text

        def filter(self, x):
            # Keep the record only when the unwanted substring is absent.
            return self.text not in x.getMessage()

    # Silence the noisy "`use_cache=True`" warning emitted during training.
    logger.addFilter(HideLoggingMessage("`use_cache=True`"))