# Copyright 2020-2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import asyncio import hashlib import importlib.resources as pkg_resources import os import random import socket import threading import types from collections.abc import Mapping, Sequence, Sized from contextlib import contextmanager from importlib.metadata import version from itertools import accumulate from typing import TypeVar import numpy as np import pandas as pd import torch import torch.nn.functional as F import transformers from accelerate import PartialState, logging from huggingface_hub import ModelCard, ModelCardData from torch.utils.data import Sampler from transformers import ( AutoConfig, BitsAndBytesConfig, PretrainedConfig, PreTrainedModel, is_comet_available, is_trackio_available, ) from transformers.models.auto.auto_factory import _BaseAutoModelClass from transformers.utils import ( is_peft_available, is_rich_available, is_torch_xpu_available, ) from ..trainer.model_config import ModelConfig if is_rich_available(): from rich.console import Console from rich.panel import Panel from rich.table import Table from rich.text import Text if is_comet_available(): import comet_ml if is_peft_available(): from peft import LoraConfig, PeftConfig, PeftModel logger = logging.get_logger(__name__) def _is_port_free(port: int, host: str = "127.0.0.1") -> bool: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind((host, port)) return True except OSError: return False def _find_free_port() -> int: candidates = (29500, 23456, 12355, 12345) for p in candidates: if _is_port_free(p): return p with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) return s.getsockname()[1] def ensure_master_addr_port(addr: str | None = None, port: int | None = None) -> None: """ Ensure `MASTER_ADDR`/`MASTER_PORT` are set safely. - Respects existing environment variables. - Defaults `MASTER_ADDR` to localhost if unset. - Chooses a free TCP port if `MASTER_PORT` is unset to avoid collisions. - If `MASTER_PORT` is set to `"0"` or `"auto"`, it is resolved to a free port. """ os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR") or addr or "localhost" env_port = os.environ.get("MASTER_PORT", "").strip().lower() if port is None and env_port not in {"", "0", "auto"}: try: port = int(env_port) except ValueError: pass os.environ["MASTER_PORT"] = str(_find_free_port() if port in (None, 0) else port) def pad( tensors: list[torch.Tensor], padding_value: int = 0, padding_side: str = "right", pad_to_multiple_of: int | None = None, ) -> torch.Tensor: """ Pads a list of tensors to the same shape along the first dimension. Args: tensors (`list[torch.Tensor]`): List of input tensors to pad. padding_value (`int`): Value to use for padding. Default is 0. padding_side (`str`): Side on which to add padding. Must be 'left' or 'right'. Default is 'right'. pad_to_multiple_of (`int`, *optional*): If set will pad the sequence to a multiple of the provided value. Returns: `torch.Tensor`: A single tensor containing the padded tensors. Examples: ```python >>> import torch >>> pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])]) tensor([[1, 2, 3], [4, 5, 0]]) >>> pad([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])]) tensor([[[1, 2], [3, 4]], [[5, 6], [0, 0]]]) ``` """ # Determine the maximum shape for each dimension output_shape = np.max([t.shape for t in tensors], 0).tolist() # Apply pad_to_multiple_of to the first (sequence) dimension if pad_to_multiple_of is not None: remainder = output_shape[0] % pad_to_multiple_of if remainder != 0: output_shape[0] += pad_to_multiple_of - remainder # Create an output tensor filled with the padding value output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device) for i, t in enumerate(tensors): if padding_side == "left": seq_start = output_shape[0] - t.shape[0] elif padding_side == "right": seq_start = 0 else: raise ValueError("padding_side must be 'left' or 'right'") # Define the slices seq_slice = slice(seq_start, seq_start + t.shape[0]) slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:]) output[i][slices] = t return output def disable_dropout_in_model(model: torch.nn.Module) -> None: for module in model.modules(): if isinstance(module, torch.nn.Dropout): module.p = 0 def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConfig | None: if model_args.load_in_4bit: quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=model_args.dtype, # For consistency with model weights, we use the same value as `dtype` bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage, ) elif model_args.load_in_8bit: quantization_config = BitsAndBytesConfig( load_in_8bit=True, ) else: quantization_config = None return quantization_config def get_kbit_device_map() -> dict[str, int] | None: if torch.cuda.is_available() or is_torch_xpu_available(): return {"": PartialState().local_process_index} else: return None def get_peft_config(model_args: ModelConfig) -> "PeftConfig | None": if model_args.use_peft is False: return None if not is_peft_available(): raise ValueError( "You need to have PEFT library installed in your environment, make sure to install `peft`. " "Make sure to run `pip install -U peft`." ) peft_config = LoraConfig( task_type=model_args.lora_task_type, r=model_args.lora_r, target_modules=model_args.lora_target_modules, target_parameters=model_args.lora_target_parameters, lora_alpha=model_args.lora_alpha, lora_dropout=model_args.lora_dropout, bias="none", use_rslora=model_args.use_rslora, use_dora=model_args.use_dora, modules_to_save=model_args.lora_modules_to_save, ) return peft_config def generate_model_card( base_model: str | None, model_name: str, hub_model_id: str, dataset_name: str | None, tags: list[str], wandb_url: str | None, trackio_url: str | None, trainer_name: str, trainer_citation: str | None = None, template_file: str | None = None, paper_title: str | None = None, paper_id: str | None = None, comet_url: str | None = None, ) -> ModelCard: """ Generate a [`~huggingface_hub.ModelCard`] from a template. Args: base_model (`str` or `None`): Base model name. model_name (`str`): Model name. hub_model_id (`str`): Hub model ID as `username/model_id`. dataset_name (`str` or `None`): Dataset name. tags (`list[str]`): Tags. wandb_url (`str` or `None`): Weights & Biases run URL. trackio_url (`str` or `None`): Trackio Space URL. comet_url (`str` or `None`): Comet experiment URL. trainer_name (`str`): Trainer name. trainer_citation (`str` or `None`, defaults to `None`): Trainer citation as a BibTeX entry. template_file (`str` *optional*): Template file name located in the `trl/templates` directory. Defaults to `lm_model_card.md`. paper_title (`str` or `None`, defaults to `None`): Paper title. paper_id (`str` or `None`, defaults to `None`): ArXiv paper ID as `YYMM.NNNNN`. Returns: [`~huggingface_hub.ModelCard`]: A ModelCard object. """ card_data = ModelCardData( base_model=base_model, datasets=dataset_name, library_name="transformers", licence="license", model_name=model_name, tags=["generated_from_trainer", *tags], ) template_file = template_file or "lm_model_card.md" card = ModelCard.from_template( card_data, template_path=str(pkg_resources.files("trl").joinpath(f"templates/{template_file}")), base_model=base_model, model_name=model_name, hub_model_id=hub_model_id, dataset_name=dataset_name, wandb_url=wandb_url, trackio_url=trackio_url, comet_url=comet_url, trainer_name=trainer_name, trainer_citation=trainer_citation, paper_title=paper_title, paper_id=paper_id, trl_version=version("trl"), transformers_version=version("transformers"), pytorch_version=version("torch"), datasets_version=version("datasets"), tokenizers_version=version("tokenizers"), ) return card def get_comet_experiment_url() -> str | None: """ If Comet integration is enabled, return the URL of the current Comet experiment; otherwise, return `None`. """ if not is_comet_available(): return None if comet_ml.get_running_experiment() is not None: return comet_ml.get_running_experiment().url return None def get_trackio_space_url() -> str | None: """ If Trackio integration is enabled, return the URL of the current Trackio Space; otherwise, return `None`. """ if not is_trackio_available(): return None from trackio import context_vars run = context_vars.current_run.get() if run is None: return None space_id = run._space_id if space_id is None: return None space_id = space_id.replace("/", "-") project = run.project name = run.name return f"https://{space_id}.hf.space?project={project}&runs={name}&sidebar=collapsed" def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None: """ If Comet integration is enabled logs a table to the Comet experiment if it is currently running. Args: name (`str`): Table name. table (`pandas.DataFrame`): The Pandas DataFrame containing the table to log. """ if not is_comet_available(): raise ModuleNotFoundError("The comet-ml is not installed. Please install it first: pip install comet-ml") experiment = comet_ml.get_running_experiment() if experiment is not None: experiment.log_table(tabular_data=table, filename=name) def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, ...]: """ Shift non-zero elements in the mask and corresponding tensors to the left. This function operates on a binary mask and any number of additional tensors with the same dimensions as the mask. For each row, non-zero values are shifted to the leftmost positions. Then, columns that contain only zeros across all rows are truncated from the mask and tensors. Visually, this operation can be represented as follows: ``` [[0, 0, x, x, x, x], -> [[x, x, x, x], [0, x, x, x, 0, 0]] [x, x, x, 0]] ``` Args: mask (`torch.Tensor`): 2D tensor (binary mask) with shape `(N, M)`. *tensors (`torch.Tensor`): One or more 2D tensors with the same shape as `mask`. These tensors will be processed alongside `mask`, with non-zero values shifted and excess zero columns truncated in the same manner. Returns: `torch.Tensor`: Updated binary mask with non-zero values flushed to the left and trailing zero columns removed. `*torch.Tensor` Updated tensors, processed in the same way as the mask. Example: ```python >>> mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]]) >>> tensor = torch.tensor([[9, 9, 2, 3, 4], [9, 5, 6, 9, 9]]) >>> new_mask, new_tensor = flush_left(mask, tensor) >>> print(new_mask) tensor([[1, 1, 1], [1, 1, 0]]) >>> print(new_tensor) tensor([[2, 3, 4], [5, 6, 0]]) ``` """ _, M = mask.shape # Create copy of mask and tensors mask_copy = mask.clone() tensors = [t.clone() for t in tensors] # Shift non-zero values to the left first_non_zero = mask_copy.argmax(dim=1) pos = torch.arange(M, device=mask_copy.device).unsqueeze(0) idx_roll = (pos + first_non_zero.unsqueeze(1)) % M mask_roll = mask_copy.gather(1, idx_roll) rolled_tensors = [t.gather(1, idx_roll) for t in tensors] # Truncate trailing columns that are all zeros in mask_roll col_sums = mask_roll.sum(dim=0) empty_cols = col_sums == 0 first_empty_col = int(empty_cols.to(torch.int8).argmax()) if empty_cols.any() else M flushed_mask = mask_roll[:, :first_empty_col] flushed_tensors = [t[:, :first_empty_col] for t in rolled_tensors] if not flushed_tensors: return flushed_mask return flushed_mask, *flushed_tensors def selective_log_softmax(logits, index) -> torch.Tensor: """ A memory-efficient implementation of the common `log_softmax -> gather` operation. This function is equivalent to the following naive implementation: ```python # for index with shape (...): logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) # for index with shape (..., K): logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index) ``` Args: logits (`torch.Tensor`): Logits tensor of shape `(..., num_classes)`. index (`torch.Tensor`): Index tensor of shape `(..., K)` or `(...)`, specifying the positions to gather from the log-softmax output. When the last case is used, `K` log-probabilities are gathered per position (e.g. for top-K) Returns: `torch.Tensor`: Gathered log probabilities with the same shape as `index`. """ squeeze = index.ndim == logits.ndim - 1 if squeeze: index = index.unsqueeze(-1) if logits.dtype in [torch.float32, torch.float64]: selected_logits = torch.gather(logits, dim=-1, index=index) # loop to reduce peak mem consumption logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) per_token_logps = selected_logits - logsumexp_values.unsqueeze(-1) # log_softmax(x_i) = x_i - logsumexp(x) else: # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach per_token_logps = [] for row_logits, row_labels in zip(logits, index, strict=True): # loop to reduce peak mem consumption row_logps = F.log_softmax(row_logits, dim=-1) row_per_token_logps = row_logps.gather(dim=-1, index=row_labels) per_token_logps.append(row_per_token_logps) per_token_logps = torch.stack(per_token_logps) if squeeze: per_token_logps = per_token_logps.squeeze(-1) return per_token_logps def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor: """ Compute the Shannon entropy (in nats) for each row of *logits* in a memory-efficient way. Instead of materializing the full softmax for all rows at once, the logits are flattened to shape (N, num_classes), where N is the product of all leading dimensions. Computation is then performed in chunks of size `chunk_size` along this flattened dimension, reducing peak memory usage. The result is reshaped back to match the input's leading dimensions. Args: logits (`torch.Tensor`): Logits tensor of shape `(..., num_classes)`. Entropy is taken along the last axis; all leading dimensions are preserved in the output. chunk_size (`int`, *optional*, defaults to `128`): Number of rows from the flattened logits to process per iteration. Smaller values reduce memory usage at the cost of more iterations. Returns: `torch.Tensor`: Entropy values with shape `logits.shape[:-1]`. """ original_shape = logits.shape[:-1] # all dims except num_classes num_classes = logits.shape[-1] # Flatten all leading dimensions into one flat_logits = logits.reshape(-1, num_classes) entropies = [] for chunk in flat_logits.split(chunk_size, dim=0): logps = F.log_softmax(chunk, dim=-1) chunk_entropy = -(torch.exp(logps) * logps).sum(-1) entropies.append(chunk_entropy) entropies = torch.cat(entropies, dim=0) return entropies.reshape(original_shape) def print_prompt_completions_sample( prompts: list, completions: list, rewards: dict[str, list[float]], advantages: list[float], step: int, num_samples: int = None, extra: dict[str, list] | None = None, ) -> None: """ Print out a sample of model completions to the console with multiple reward metrics. This function creates a nicely formatted table showing prompt-completion pairs, useful for monitoring model outputs during training. It requires the `rich` library to be installed. Args: prompts (`list`): List of prompts. Can be either strings or lists of messages. completions (`list`): List of completions corresponding to the prompts. Can be either strings or lists of messages. rewards (`dict[str, list[float]]`): Dictionary where keys are reward names and values are lists of rewards. advantages (`list[float]`): List of advantages corresponding to the prompts and completions. step (`int`): Current training step number, used in the output title. num_samples (`int`, *optional*): Number of random samples to display. If `None` (default), all items will be displayed. extra (`dict[str, list]`, *optional*): Additional columns to display after the advantage column. Keys are column names and values are lists of per-completion data (strings or any value convertible to string). Typically populated via `log_extra` in reward functions. If `None` (default), no extra columns are shown. Example: ```python >>> from trl.trainer.utils import print_prompt_completions_sample >>> prompts = ["The sky is", "The sun is"] >>> completions = [" blue.", " in the sky."] >>> rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]} >>> advantages = [0.987, 0.654] >>> extra = {"source": ["dataset_A", "dataset_B"]} >>> print_prompt_completions_sample(prompts, completions, rewards, advantages, 42, extra=extra) ╭────────────────────────────────── Step 42 ───────────────────────────────────╮ │ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┓ │ │ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ source ┃ │ │ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━┩ │ │ │ The sky is │ blue. │ 0.12 │ 0.79 │ 0.99 │ dataset_A │ │ │ ├────────────┼──────────────┼─────────────┼────────┼───────────┼───────────┤ │ │ │ The sun is │ in the sky. │ 0.46 │ 0.10 │ 0.65 │ dataset_B │ │ │ └────────────┴──────────────┴─────────────┴────────┴───────────┴───────────┘ │ ╰──────────────────────────────────────────────────────────────────────────────╯ ``` """ if not is_rich_available(): raise ImportError( "The function `print_prompt_completions_sample` requires the `rich` library. Please install it with " "`pip install rich`." ) console = Console() table = Table(show_header=True, header_style="bold white", expand=True) extra = extra or {} # Add columns table.add_column("Prompt", style="bright_yellow") table.add_column("Completion", style="bright_green") for reward_name in rewards.keys(): table.add_column(reward_name, style="bold cyan", justify="right") table.add_column("Advantage", style="bold magenta", justify="right") for extra_name in extra.keys(): table.add_column(extra_name, style="bright_white") def format_entry(entry) -> Text: t = Text() if isinstance(entry, list) and all(isinstance(m, dict) for m in entry): for j, msg in enumerate(entry): role = msg.get("role", "") if "content" in msg or "reasoning_content" in msg or "thinking" in msg: # Chat message t.append(f"{role.upper()}\n", style="bold red") reasoning = msg.get("reasoning_content") or msg.get("thinking") if reasoning: t.append(reasoning, style="italic dim white") t.append("\n") if "content" in msg: t.append(msg["content"]) elif "name" in msg and "args" in msg: # Tool call t.append(f"{role.upper()}\n", style="bold red") t.append(f"{msg['name']}({msg['args']})") else: # Fallback t.append(str(msg)) if j < len(entry) - 1: t.append("\n\n") else: t.append(str(entry)) return t # Some basic input validation if num_samples is not None: if num_samples >= len(prompts): num_samples = None elif num_samples <= 0: return # Subsample data if num_samples is specified if num_samples is not None: indices = random.sample(range(len(prompts)), num_samples) prompts = [prompts[i] for i in indices] completions = [completions[i] for i in indices] rewards = {key: [val[i] for i in indices] for key, val in rewards.items()} advantages = [advantages[i] for i in indices] extra = {key: [val[i] for i in indices] for key, val in extra.items()} for i in range(len(prompts)): reward_values = [f"{rewards[key][i]:.2f}" for key in rewards.keys()] # 2 decimals extra_values = [format_entry(extra[key][i]) for key in extra.keys()] table.add_row( format_entry(prompts[i]), format_entry(completions[i]), *reward_values, f"{advantages[i]:.2f}", *extra_values, ) table.add_section() # Adds a separator between rows panel = Panel(table, expand=False, title=f"Step {step}", border_style="bold white") console.print(panel) class RepeatSampler(Sampler): """ Sampler that repeats the indices of a dataset in a structured manner. Args: data_source (`Sized`): Dataset to sample from. mini_repeat_count (`int`): Number of times to repeat each index per batch. batch_size (`int`, *optional*, defaults to `1`): Number of unique indices per batch. repeat_count (`int`, *optional*, defaults to `1`): Number of times to repeat the full sampling process. shuffle (`bool`, *optional*, defaults to `True`): Whether to shuffle the dataset. seed (`int`, *optional*): Random seed for reproducibility (only affects this sampler). Example: ```python >>> sampler = RepeatSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4) >>> list(sampler) [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6] ``` ```txt mini_repeat_count = 3 - - - [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, | 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, | 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, | repeat_count = 2 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, | 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, | 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, ...] | --------- --------- --------- --------- --------- --------- --------- --------- --------- --------- --------- --------- batch_size = 12 ``` """ def __init__( self, data_source: Sized, mini_repeat_count: int, batch_size: int = 1, repeat_count: int = 1, shuffle: bool = True, seed: int | None = None, ): self.data_source = data_source self.mini_repeat_count = mini_repeat_count self.batch_size = batch_size self.repeat_count = repeat_count self.num_samples = len(data_source) self.shuffle = shuffle self.seed = seed if shuffle: self.generator = torch.Generator() # Create a local random generator if seed is not None: self.generator.manual_seed(seed) def __iter__(self): if self.shuffle: # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7) indexes = torch.randperm(self.num_samples, generator=self.generator).tolist() else: indexes = list(range(self.num_samples)) # [2, 4, 3, 1, 0, 6, 5] # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3) indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)] # [[2, 4, 3], [1, 0, 6], [5]] # -> [[2, 4, 3], [1, 0, 6]] indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size] for chunk in indexes: for _ in range(self.repeat_count): for index in chunk: for _ in range(self.mini_repeat_count): yield index def __len__(self) -> int: return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count # torch.nanstd doesn't exist, so we define it here def nanstd(tensor: torch.Tensor, dim: int | tuple[int, ...] | None = None, keepdim: bool = False) -> torch.Tensor: """ Compute the standard deviation of a tensor, ignoring NaNs. Args: tensor (`torch.Tensor`): Input tensor. dim (`int` or `tuple[int, ...]`, *optional*): Dimension(s) to reduce. Defaults to all dimensions. keepdim (`bool`, *optional*, defaults to `False`): Whether to keep reduced dimensions. Returns: `torch.Tensor`: Standard deviation of the tensor, ignoring NaNs. """ # Compute variance ignoring NaNs mean = torch.nanmean(tensor, dim=dim, keepdim=True) variance = torch.nanmean((tensor - mean) ** 2, dim=dim, keepdim=True) count = torch.sum(~torch.isnan(tensor), dim=dim, keepdim=True) # count of non-NaN values correction = count / (count - 1) correction = torch.where(count > 1, correction, torch.full_like(correction, float("nan"))) variance *= correction # Bessel's correction std = torch.sqrt(variance) if keepdim: return std if dim is None: return std.squeeze() if isinstance(dim, int): return std.squeeze(dim) dims = [(d if d >= 0 else d + std.ndim) for d in dim] for d in sorted(dims, reverse=True): std = std.squeeze(d) return std def split_tensor_dict( tensor_dict: dict[str, torch.Tensor | None], num_chunks: int ) -> list[dict[str, torch.Tensor | None]]: """ Splits a dictionary of tensors along the first dimension into `num_chunks` equal parts. Example: ```python >>> x = torch.arange(12).reshape(6, 2) >>> y = torch.arange(6).reshape(6, 1) >>> tensor_dict = {"x": x, "y": y} >>> split_tensor_dict(tensor_dict, 3) [ {"x": tensor([[0, 1], [2, 3]]), "y": tensor([[0], [1]])}, {"x": tensor([[4, 5], [6, 7]]), "y": tensor([[2], [3]])}, {"x": tensor([[ 8, 9], [10, 11]]), "y": tensor([[4], [5]])} ] ``` """ first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None) chunk_size = first_tensor.shape[0] // num_chunks chunks = [] for i in range(num_chunks): chunk_dict = {} for key, tensor in tensor_dict.items(): if tensor is not None and (isinstance(tensor, list) or tensor.ndim > 0): chunk_dict[key] = tensor[i * chunk_size : (i + 1) * chunk_size] elif tensor is not None and tensor.ndim == 0: chunk_dict[key] = tensor else: chunk_dict[key] = None chunks.append(chunk_dict) return chunks def shuffle_sequence_dict(seq_dict: dict[str, Sequence | None]) -> dict[str, Sequence | None]: """ Shuffles all sequence-like values in a dictionary along the first dimension in unison. Example: ```python >>> x = torch.arange(6).reshape(3, 2) >>> y = ["a", "b", "c"] >>> seq_dict = {"x": x, "y": y} >>> shuffle_sequence_dict(seq_dict) {'x': tensor([[2, 3], [0, 1], [4, 5]]), 'y': ['b', 'a', 'c']} ``` """ # Determine batch size from the first non-None sequence batch_size = len(next(v for v in seq_dict.values() if v is not None)) permutation = torch.randperm(batch_size) def permute(v: Sequence | None) -> Sequence | None: if v is None: return None if isinstance(v, torch.Tensor) and v.ndim == 0: return v if isinstance(v, torch.Tensor) and v.ndim >= 1: return v[permutation] return [v[i] for i in permutation] return {key: permute(val) for key, val in seq_dict.items()} def nanmin(tensor: torch.Tensor) -> torch.Tensor: """ Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors. Args: tensor (`torch.Tensor`): Input tensor of shape `(N,)`. Returns: `torch.Tensor`: Minimum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN. """ if torch.isnan(tensor).all(): return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) return torch.min(tensor[~torch.isnan(tensor)]) def nanmax(tensor: torch.Tensor) -> torch.Tensor: """ Compute the maximum value of a tensor, ignoring NaNs. This function only supports 1D tensors. Args: tensor (`torch.Tensor`): Input tensor of shape `(N,)`. Returns: `torch.Tensor`: Maximum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN. """ if torch.isnan(tensor).all(): return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) return torch.max(tensor[~torch.isnan(tensor)]) def identity(x): """Do we really need docs for this?""" return x def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor | list[torch.Tensor]]: """ Splits `batch["pixel_values"]` into a list of tensors, one per sample, based on `batch["num_images"]`. For models with `image_grid_thw` (e.g. Qwen), the grid dimensions determine how many rows of `pixel_values` belong to each image. For models with `image_position_ids` instead (e.g. Gemma), `pixel_values` is indexed directly by image count. """ if "pixel_values" not in batch or "num_images" not in batch: return batch num_images = batch["num_images"] pixel_values = batch["pixel_values"] # [total, feature_dim] if "image_grid_thw" in batch: lengths = batch["image_grid_thw"].prod(-1).tolist() # [num_images] if sum(lengths) != pixel_values.size(0): raise ValueError( f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}" ) boundaries = [0, *accumulate(num_images)] image_grid_thw = batch["image_grid_thw"] # [total, 3] sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(num_images))] split_pixel_values = list(torch.split(pixel_values, sections, dim=0)) split_image_grid_thw = list(torch.split(image_grid_thw, num_images, dim=0)) return {**batch, "pixel_values": split_pixel_values, "image_grid_thw": split_image_grid_thw} if "image_position_ids" in batch: image_position_ids = batch["image_position_ids"] # [total] split_pixel_values = list(torch.split(pixel_values, num_images, dim=0)) split_image_position_ids = list(torch.split(image_position_ids, num_images, dim=0)) return {**batch, "pixel_values": split_pixel_values, "image_position_ids": split_image_position_ids} return batch def unsplit_pixel_values_by_grid(batch: dict[str, torch.Tensor | list[torch.Tensor]]) -> dict[str, torch.Tensor]: """ Opposite of `split_pixel_values_by_grid`. Merges a list of tensors in `batch["pixel_values"]` back into a single tensor along the first dimension. """ pixel_values = batch.get("pixel_values") if isinstance(pixel_values, list): merged = torch.cat(pixel_values, dim=0) batch = {**batch, "pixel_values": merged} image_grid_thw = batch.get("image_grid_thw") if isinstance(image_grid_thw, list): merged = torch.cat(image_grid_thw, dim=0) batch = {**batch, "image_grid_thw": merged} image_position_ids = batch.get("image_position_ids") if isinstance(image_position_ids, list): merged = torch.cat(image_position_ids, dim=0) batch = {**batch, "image_position_ids": merged} return batch TListOrMapping = TypeVar("TListOrMapping", list, Mapping) # This function is intentionally not used internally. It is provided as a utility for users whose datasets contain # `None` values inserted by tabular backends (e.g., Arrow/Parquet) for missing keys in nested structures. This # situation arises when loading datasets created before `datasets` v4.7.0 (which introduced the Json dtype), or when # datasets created after that version were saved without using the Json feature. In both cases, users can apply this # function via `dataset = dataset.with_transform(remove_none_values)` before training to strip the spurious `None` # values. See the migration guide for more details. def remove_none_values(example: TListOrMapping) -> TListOrMapping: """ Recursively removes entries with `None` values from a nested structure (list or dictionary). Args: example (`list` or `Mapping`): Input nested structure (list or dictionary) from which to remove `None`. Examples: ```python >>> dataset = dataset.with_transform(remove_none_values) ``` ```python >>> [ ... { ... "a": {"aa": None, "ab": 1}, ... "b": "my_string", ... } ... ] >>> remove_none_values(example) [{'a': {'ab': 1}, 'b': 'my_string'}] ``` """ if isinstance(example, list): return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example] elif isinstance(example, Mapping): return { key: remove_none_values(value) if isinstance(value, (dict, list)) else value for key, value in example.items() if value is not None } else: raise TypeError("Input must be a list or a dictionary.") def create_model_from_path( model_id: str, architecture: _BaseAutoModelClass | None = None, **kwargs ) -> PreTrainedModel: """ Create a model from a given path using the specified initialization arguments. Args: model_id (`str`): Path to the model. Can be either a local directory or a model identifier from the Hugging Face Hub. architecture (`_BaseAutoModelClass` or `None`, *optional*): Model architecture class to instantiate. The model is initialized using the `from_pretrained` method of this class. If `None`, the architecture will be inferred from the model's configuration. kwargs (`dict`): Initialization keyword arguments to pass to the model's `from_pretrained` method. When `'dtype'` is specified, it can be either a `torch.dtype` or one of the strings: `'bfloat16'`, `'float16'`, `'float32'`, or `'auto'`. If not explicitly set, `dtype` defaults to `'float32'`. Returns: [`~transformers.PreTrainedModel`]: The instantiated model. """ dtype = kwargs.get("dtype", "float32") if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: pass # dtype is already a torch.dtype or "auto" or None elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]: kwargs["dtype"] = getattr(torch, dtype) else: raise ValueError( "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing " f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." ) kwargs["device_map"] = kwargs.get("device_map", "auto") if architecture is None: config = AutoConfig.from_pretrained(model_id) architecture = getattr(transformers, config.architectures[0]) model = architecture.from_pretrained(model_id, **kwargs) return model def hash_module(module: torch.nn.Module) -> str: h = hashlib.sha256() for _, tensor in sorted(module.state_dict().items()): tensor = tensor.cpu() h.update(str(tensor.dtype).encode()) if tensor.dtype in [torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2]: tensor = tensor.to(torch.float32) h.update(tensor.numpy().tobytes()) return h.hexdigest() def get_config_model_id(config: PretrainedConfig) -> str: """ Retrieve the model identifier from a given model configuration. Args: config ([`~transformers.PreTrainedConfig`]): Configuration from which to extract the model identifier. Returns: `str`: The model identifier associated with the model configuration. """ return getattr(config, "_name_or_path", "") @contextmanager def use_adapter(model: "PeftModel", adapter_name: str | None): """ Context manager to temporarily set and reset the active adapter in a PEFT model. Args: model ([`~peft.PeftModel`]): PEFT model to manage. adapter_name (`str` or `None`): Name of the adapter to set as active. If `None`, the context manager will disable all adapters. Example: ```python >>> from trl.trainer.utils import use_adapter >>> from peft import AutoPeftModelForCausalLM >>> import torch >>> model = AutoPeftModelForCausalLM.from_pretrained("path/to/model") >>> input_ids = torch.tensor([[1, 2, 3]]) >>> with use_adapter(model, "adapter_name"): ... outputs = model(input_ids) ``` """ if not is_peft_available(): raise ImportError( "You're trying to use a PEFT adapter but PEFT is not installed. Please install it with `pip install peft`." ) if adapter_name is None: with model.disable_adapter(): yield else: previous_adapter = model.active_adapter model.set_adapter(adapter_name) try: yield finally: model.set_adapter(previous_adapter) def start_event_loop_in_daemon( name: str | None = None, ) -> tuple[threading.Thread, asyncio.AbstractEventLoop, threading.Event]: """ This function creates a new daemon thread that runs the provided event loop. Args: name (`str`, *optional*): Name of the thread. If `None`, the default thread naming will be used. Returns: `threading.Thread`: The thread running the event loop. `asyncio.AbstractEventLoop`: The event loop being run in the thread. `threading.Event`: An event that is set when the loop is ready. """ loop = asyncio.new_event_loop() loop_ready_event = threading.Event() def run_loop(): asyncio.set_event_loop(loop) loop_ready_event.set() loop.run_forever() thread = threading.Thread(target=run_loop, name=name, daemon=True) thread.start() return thread, loop, loop_ready_event def shutdown_event_loop_in_daemon( thread: threading.Thread | None, loop: asyncio.AbstractEventLoop | None, ) -> None: """ Shutdown an asyncio event loop running in a separate thread. This function stops the event loop and waits for the associated thread to finish execution. Args: thread (`threading.Thread`): The thread running the event loop. loop (`asyncio.AbstractEventLoop`): The asyncio event loop to shut down. """ if loop is None or thread is None: return loop.call_soon_threadsafe(loop.stop) thread.join(timeout=5) class _ChunkedLogProbFunction(torch.autograd.Function): """Compute per-token log-probs and entropy without materializing [N, V] logits. Processes the lm_head in chunks and uses online logsumexp """ @staticmethod def forward( ctx, last_hidden: torch.Tensor, # [N, H] weight: torch.Tensor, # [V, H] targets: torch.Tensor, # [N] temperature: float, chunk_size: int, logit_scale: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor]: device = last_hidden.device N, _ = last_hidden.shape vocab, _ = weight.shape inv_t = logit_scale / temperature # NOTE(@aminediro): always acc in fp32 for stability max_old = torch.full((N,), float("-inf"), device=device, dtype=torch.float32) sum_exp = torch.zeros((N,), device=device, dtype=torch.float32) x_sum_exp = torch.zeros((N,), device=device, dtype=torch.float32) target_logit = torch.zeros((N,), device=device, dtype=torch.float32) # Pre-allocate reusable buffers to avoid per-chunk allocation mm_buf = torch.empty((N, chunk_size), device=device, dtype=last_hidden.dtype) logits_buf = torch.empty((N, chunk_size), device=device, dtype=torch.float32) for start in range(0, vocab, chunk_size): end = min(start + chunk_size, vocab) C = end - start # using fp16=True, the model's hidden states get cast to float16 by autocast, but the mm_buf is allocated # with last_hidden.dtype (float16) while w_chunk (the lm_head weights) is not auto casted w_chunk = weight[start:end].to(last_hidden.dtype) # [C, H] torch.mm(last_hidden, w_chunk.t(), out=mm_buf[:, :C]) logits_chunk = logits_buf[:, :C] logits_chunk.copy_(mm_buf[:, :C]) logits_chunk.mul_(inv_t) # [N, C] # Online logsumexp update chunk_max = logits_chunk.amax(dim=-1) # [N] max_new = torch.maximum(max_old, chunk_max) rescale = torch.exp(max_old - max_new) chunk_exp = torch.exp(logits_chunk - max_new.unsqueeze(-1)) # [N, C] sum_exp = sum_exp * rescale + chunk_exp.sum(dim=-1) x_sum_exp = x_sum_exp * rescale + (chunk_exp * logits_chunk).sum(dim=-1) max_old = max_new # Gather target logits for labels in this chunk in_chunk_cond = (targets >= start) & (targets < end) local_idx = torch.clamp(targets - start, 0, end - start - 1) # take the new logit if target_idx is in this chunk bounds else 0 target_logit += logits_chunk[torch.arange(N, device=device), local_idx] * in_chunk_cond log_z = max_old + torch.log(sum_exp) logprobs = target_logit - log_z entropy = log_z - x_sum_exp / sum_exp ctx.save_for_backward(last_hidden, weight, targets, log_z) ctx.temperature = temperature ctx.chunk_size = chunk_size ctx.logit_scale = logit_scale return logprobs, entropy @staticmethod def backward(ctx, grad_logprobs: torch.Tensor, grad_entropy: torch.Tensor): # type: ignore hidden, weight, labels, log_z = ctx.saved_tensors temperature: float = ctx.temperature chunk_size: int = ctx.chunk_size logit_scale: float = ctx.logit_scale inv_t = logit_scale / temperature N, _ = hidden.shape vocab = weight.shape[0] # NOTE(@aminediro): always acc in fp32 even if input is not grad_hidden = torch.zeros(hidden.shape, device=hidden.device, dtype=torch.float32) grad_weight = torch.zeros(weight.shape, device=weight.device, dtype=torch.float32) # Pre-allocate reusable buffers to avoid per-chunk allocation mm_buf = torch.empty((N, chunk_size), device=hidden.device, dtype=hidden.dtype) logits_buf = torch.empty((N, chunk_size), device=hidden.device, dtype=torch.float32) g = grad_logprobs.to(torch.float32) # [N] row_idx = torch.arange(N, device=hidden.device) for start in range(0, vocab, chunk_size): end = min(start + chunk_size, vocab) C = end - start w_chunk = weight[start:end] # [C, H] torch.mm(hidden, w_chunk.t(), out=mm_buf[:, :C]) logits_chunk = logits_buf[:, :C] logits_chunk.copy_(mm_buf[:, :C]) logits_chunk.mul_(inv_t) # [N, C] probs = torch.exp(logits_chunk - log_z.unsqueeze(-1)) # [N, C] # dL/d(logits) = g * (1_[label] - p) grad_logits = (-g).unsqueeze(-1) * probs # [N, C] in_chunk_cond = (labels >= start) & (labels < end) local_idx = torch.clamp(labels - start, 0, end - start - 1) # If label in chunk add g to grad else it stays the same grad_logits[row_idx, local_idx] += g * in_chunk_cond grad_logits = grad_logits * inv_t grad_hidden.add_(grad_logits @ w_chunk.float()) grad_weight[start:end].add_(grad_logits.t() @ hidden.float()) return grad_hidden.to(hidden.dtype), grad_weight.to(weight.dtype), None, None, None, None def patch_chunked_lm_head(model: torch.nn.Module, chunk_size: int, temperature: float) -> None: if getattr(model.config, "final_logit_softcapping", None) is not None: raise NotImplementedError( "The model uses `final_logit_softcapping` which is not yet supported. Please open an issue if you " "want your model to be supported." ) def _chunked_forward( self: torch.nn.Module, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, completion_mask: torch.Tensor | None = None, use_cache: bool = False, **kwargs, ) -> dict[str, torch.Tensor]: assert labels is not None, "requires labels to not be None for logprob computation" outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=use_cache, **kwargs) # NOTE(@aminediro): supporting Cohere2 models logit_scale = getattr(self.config, "logit_scale", 1.0) hidden_states = outputs.last_hidden_state # [B, S+1, H] # Shift: predict next token hidden_states = hidden_states[:, :-1, :] # [B, S-1, H] labels = labels[:, 1:] # [B, S-1] b, s, h = hidden_states.shape hidden_flat = hidden_states.reshape(b * s, h).contiguous() targets_flat = labels.reshape(b * s).contiguous() # Filter to completion tokens only to avoid expensive matmuls on prompt tokens and tool results valid_mask = None if completion_mask is not None: completion_mask = completion_mask[:, 1:] # same shift as labels valid_mask = completion_mask.bool().reshape(b * s) hidden_flat = hidden_flat[valid_mask] # [N_valid, H] targets_flat = targets_flat[valid_mask] # [N_valid] logprobs_valid, entropy_valid = _ChunkedLogProbFunction.apply( hidden_flat, self.lm_head.weight, targets_flat, temperature, chunk_size, logit_scale ) if valid_mask is not None: logprobs = torch.zeros(b * s, device=logprobs_valid.device, dtype=logprobs_valid.dtype) entropy = torch.zeros(b * s, device=entropy_valid.device, dtype=entropy_valid.dtype) logprobs[valid_mask] = logprobs_valid entropy[valid_mask] = entropy_valid else: logprobs = logprobs_valid entropy = entropy_valid return { "log_probs": logprobs.reshape(b, s), "entropy": entropy.reshape(b, s), } model.forward = types.MethodType(_chunked_forward, model)