|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import asyncio
|
| import hashlib
|
| import importlib.resources as pkg_resources
|
| import os
|
| import random
|
| import socket
|
| import threading
|
| import types
|
| from collections.abc import Mapping, Sequence, Sized
|
| from contextlib import contextmanager
|
| from importlib.metadata import version
|
| from itertools import accumulate
|
| from typing import TypeVar
|
|
|
| import numpy as np
|
| import pandas as pd
|
| import torch
|
| import torch.nn.functional as F
|
| import transformers
|
| from accelerate import PartialState, logging
|
| from huggingface_hub import ModelCard, ModelCardData
|
| from torch.utils.data import Sampler
|
| from transformers import (
|
| AutoConfig,
|
| BitsAndBytesConfig,
|
| PretrainedConfig,
|
| PreTrainedModel,
|
| is_comet_available,
|
| is_trackio_available,
|
| )
|
| from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
| from transformers.utils import (
|
| is_peft_available,
|
| is_rich_available,
|
| is_torch_xpu_available,
|
| )
|
|
|
| from ..trainer.model_config import ModelConfig
|
|
|
|
|
| if is_rich_available():
|
| from rich.console import Console
|
| from rich.panel import Panel
|
| from rich.table import Table
|
| from rich.text import Text
|
|
|
| if is_comet_available():
|
| import comet_ml
|
|
|
| if is_peft_available():
|
| from peft import LoraConfig, PeftConfig, PeftModel
|
|
|
|
|
| logger = logging.get_logger(__name__)
|
|
|
|
|
def _is_port_free(port: int, host: str = "127.0.0.1") -> bool:
    """
    Check whether a TCP port can currently be bound on `host`.

    Args:
        port (`int`):
            TCP port number to probe.
        host (`str`, *optional*, defaults to `"127.0.0.1"`):
            Address the probe socket is bound to.

    Returns:
        `bool`:
            `True` if the bind succeeded, `False` if an `OSError` was raised while probing.
    """
    try:
        # The probe socket only lives for the duration of the `with` block.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            probe.bind((host, port))
    except OSError:
        return False
    return True
|
|
|
|
|
def _find_free_port() -> int:
    """
    Pick a TCP port that is currently free.

    A fixed list of candidate ports is tried first, in order. If none of them is free, a socket is bound to port `0`
    and the port number assigned to it is returned.

    Returns:
        `int`:
            A port number that was free at the time of the check.
    """
    preferred_ports = (29500, 23456, 12355, 12345)
    available = next((candidate for candidate in preferred_ports if _is_port_free(candidate)), None)
    if available is not None:
        return available

    # None of the preferred ports is usable: bind to port 0 and read back the assigned port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
|
|
|
|
|
| def ensure_master_addr_port(addr: str | None = None, port: int | None = None) -> None:
|
| """
|
| Ensure `MASTER_ADDR`/`MASTER_PORT` are set safely.
|
|
|
| - Respects existing environment variables.
|
| - Defaults `MASTER_ADDR` to localhost if unset.
|
| - Chooses a free TCP port if `MASTER_PORT` is unset to avoid collisions.
|
| - If `MASTER_PORT` is set to `"0"` or `"auto"`, it is resolved to a free port.
|
| """
|
| os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR") or addr or "localhost"
|
|
|
| env_port = os.environ.get("MASTER_PORT", "").strip().lower()
|
| if port is None and env_port not in {"", "0", "auto"}:
|
| try:
|
| port = int(env_port)
|
| except ValueError:
|
| pass
|
|
|
| os.environ["MASTER_PORT"] = str(_find_free_port() if port in (None, 0) else port)
|
|
|
|
|
| def pad(
|
| tensors: list[torch.Tensor],
|
| padding_value: int = 0,
|
| padding_side: str = "right",
|
| pad_to_multiple_of: int | None = None,
|
| ) -> torch.Tensor:
|
| """
|
| Pads a list of tensors to the same shape along the first dimension.
|
|
|
| Args:
|
| tensors (`list[torch.Tensor]`):
|
| List of input tensors to pad.
|
| padding_value (`int`):
|
| Value to use for padding. Default is 0.
|
| padding_side (`str`):
|
| Side on which to add padding. Must be 'left' or 'right'. Default is 'right'.
|
| pad_to_multiple_of (`int`, *optional*):
|
| If set will pad the sequence to a multiple of the provided value.
|
|
|
| Returns:
|
| `torch.Tensor`:
|
| A single tensor containing the padded tensors.
|
|
|
| Examples:
|
| ```python
|
| >>> import torch
|
|
|
| >>> pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
|
| tensor([[1, 2, 3],
|
| [4, 5, 0]])
|
|
|
| >>> pad([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])])
|
| tensor([[[1, 2],
|
| [3, 4]],
|
| [[5, 6],
|
| [0, 0]]])
|
| ```
|
| """
|
|
|
| output_shape = np.max([t.shape for t in tensors], 0).tolist()
|
|
|
|
|
| if pad_to_multiple_of is not None:
|
| remainder = output_shape[0] % pad_to_multiple_of
|
| if remainder != 0:
|
| output_shape[0] += pad_to_multiple_of - remainder
|
|
|
|
|
| output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device)
|
|
|
| for i, t in enumerate(tensors):
|
| if padding_side == "left":
|
| seq_start = output_shape[0] - t.shape[0]
|
| elif padding_side == "right":
|
| seq_start = 0
|
| else:
|
| raise ValueError("padding_side must be 'left' or 'right'")
|
|
|
|
|
| seq_slice = slice(seq_start, seq_start + t.shape[0])
|
| slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:])
|
| output[i][slices] = t
|
|
|
| return output
|
|
|
|
|
def disable_dropout_in_model(model: torch.nn.Module) -> None:
    """
    Set the drop probability of every `torch.nn.Dropout` module inside `model` to zero, in place.

    Args:
        model (`torch.nn.Module`):
            Model whose `torch.nn.Dropout` submodules (including nested ones) are updated.
    """
    dropout_layers = (layer for layer in model.modules() if isinstance(layer, torch.nn.Dropout))
    for layer in dropout_layers:
        layer.p = 0
|
|
|
|
|
def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConfig | None:
    """
    Build the `BitsAndBytesConfig` requested by `model_args`, if any.

    4-bit loading is checked first and therefore takes precedence over 8-bit loading when both flags are set.

    Args:
        model_args ([`ModelConfig`]):
            Model configuration. The fields read are `load_in_4bit`, `load_in_8bit`, `dtype`, `bnb_4bit_quant_type`,
            `use_bnb_nested_quant` and `bnb_4bit_quant_storage`.

    Returns:
        [`~transformers.BitsAndBytesConfig`] or `None`:
            The quantization config, or `None` when neither 4-bit nor 8-bit loading is requested.
    """
    if model_args.load_in_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=model_args.dtype,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
            bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage,
        )
    if model_args.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    return None
|
|
|
|
|
def get_kbit_device_map() -> dict[str, int] | None:
    """
    Return a device map that targets the current process's local device, when an accelerator is present.

    Returns:
        `dict[str, int]` or `None`:
            A mapping with a single `""` key whose value is `PartialState().local_process_index` when CUDA or XPU is
            available, otherwise `None`.
    """
    accelerator_present = torch.cuda.is_available() or is_torch_xpu_available()
    return {"": PartialState().local_process_index} if accelerator_present else None
|
|
|
|
|
def get_peft_config(model_args: ModelConfig) -> "PeftConfig | None":
    """
    Build a LoRA configuration from `model_args`.

    Args:
        model_args ([`ModelConfig`]):
            Model configuration. `use_peft` decides whether a config is created; the `lora_*`, `use_rslora` and
            `use_dora` fields are forwarded to `LoraConfig`.

    Returns:
        `PeftConfig` or `None`:
            A `LoraConfig` (with `bias="none"`), or `None` when `model_args.use_peft` is `False`.

    Raises:
        `ValueError`: If PEFT is requested but the `peft` library is not installed.
    """
    # Identity check on purpose: only an explicit `False` disables PEFT.
    if model_args.use_peft is False:
        return None

    if not is_peft_available():
        raise ValueError(
            "You need to have PEFT library installed in your environment, make sure to install `peft`. "
            "Make sure to run `pip install -U peft`."
        )

    lora_kwargs = dict(
        task_type=model_args.lora_task_type,
        r=model_args.lora_r,
        target_modules=model_args.lora_target_modules,
        target_parameters=model_args.lora_target_parameters,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        bias="none",
        use_rslora=model_args.use_rslora,
        use_dora=model_args.use_dora,
        modules_to_save=model_args.lora_modules_to_save,
    )
    return LoraConfig(**lora_kwargs)
|
|
|
|
|
def generate_model_card(
    base_model: str | None,
    model_name: str,
    hub_model_id: str,
    dataset_name: str | None,
    tags: list[str],
    wandb_url: str | None,
    trackio_url: str | None,
    trainer_name: str,
    trainer_citation: str | None = None,
    template_file: str | None = None,
    paper_title: str | None = None,
    paper_id: str | None = None,
    comet_url: str | None = None,
) -> ModelCard:
    """
    Render a [`~huggingface_hub.ModelCard`] from one of the templates shipped in `trl/templates`.

    Args:
        base_model (`str` or `None`):
            Base model name.
        model_name (`str`):
            Model name.
        hub_model_id (`str`):
            Hub model ID as `username/model_id`.
        dataset_name (`str` or `None`):
            Dataset name.
        tags (`list[str]`):
            Tags; `"generated_from_trainer"` is always prepended.
        wandb_url (`str` or `None`):
            Weights & Biases run URL.
        trackio_url (`str` or `None`):
            Trackio Space URL.
        trainer_name (`str`):
            Trainer name.
        trainer_citation (`str` or `None`, defaults to `None`):
            Trainer citation as a BibTeX entry.
        template_file (`str`, *optional*):
            Template file name located in the `trl/templates` directory. Defaults to `lm_model_card.md`.
        paper_title (`str` or `None`, defaults to `None`):
            Paper title.
        paper_id (`str` or `None`, defaults to `None`):
            ArXiv paper ID as `YYMM.NNNNN`.
        comet_url (`str` or `None`, defaults to `None`):
            Comet experiment URL.

    Returns:
        [`~huggingface_hub.ModelCard`]:
            The rendered model card. The installed versions of `trl`, `transformers`, `torch`, `datasets` and
            `tokenizers` are passed to the template as well.
    """
    # NOTE(review): the keyword below is spelled `licence`, not `license`; it is kept as-is to preserve the generated
    # metadata — confirm whether the `license` field of `ModelCardData` was intended.
    metadata = ModelCardData(
        base_model=base_model,
        datasets=dataset_name,
        library_name="transformers",
        licence="license",
        model_name=model_name,
        tags=["generated_from_trainer", *tags],
    )

    if not template_file:
        template_file = "lm_model_card.md"
    template_path = str(pkg_resources.files("trl").joinpath(f"templates/{template_file}"))

    # Installed package versions exposed to the template.
    versions = {
        "trl_version": version("trl"),
        "transformers_version": version("transformers"),
        "pytorch_version": version("torch"),
        "datasets_version": version("datasets"),
        "tokenizers_version": version("tokenizers"),
    }

    return ModelCard.from_template(
        metadata,
        template_path=template_path,
        base_model=base_model,
        model_name=model_name,
        hub_model_id=hub_model_id,
        dataset_name=dataset_name,
        wandb_url=wandb_url,
        trackio_url=trackio_url,
        comet_url=comet_url,
        trainer_name=trainer_name,
        trainer_citation=trainer_citation,
        paper_title=paper_title,
        paper_id=paper_id,
        **versions,
    )
|
|
|
|
|
def get_comet_experiment_url() -> str | None:
    """
    Return the URL of the currently running Comet experiment, if there is one.

    Returns:
        `str` or `None`:
            The experiment URL, or `None` when Comet is not available or no experiment is running.
    """
    if not is_comet_available():
        return None

    # Look the experiment up once: a second lookup could return `None` if the experiment ended in between, which
    # would make the `.url` access fail.
    experiment = comet_ml.get_running_experiment()
    if experiment is None:
        return None
    return experiment.url
|
|
|
|
|
def get_trackio_space_url() -> str | None:
    """
    Return the URL of the Trackio Space associated with the current Trackio run, if there is one.

    Returns:
        `str` or `None`:
            URL of the form `https://{space}.hf.space?project=...&runs=...&sidebar=collapsed`, or `None` when Trackio
            is not available, no run is active, or the run has no Space ID.
    """
    if not is_trackio_available():
        return None

    # Imported lazily so that `trackio` is only required when it is installed.
    from trackio import context_vars

    run = context_vars.current_run.get()
    space_id = None if run is None else run._space_id
    if space_id is None:
        return None

    # "/" in the Space ID is replaced by "-" to form the subdomain.
    subdomain = space_id.replace("/", "-")
    return f"https://{subdomain}.hf.space?project={run.project}&runs={run.name}&sidebar=collapsed"
|
|
|
|
|
def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None:
    """
    Log a table to the running Comet experiment. Nothing is logged when no experiment is running.

    Args:
        name (`str`):
            Table name; passed to Comet as the file name.
        table (`pandas.DataFrame`):
            The Pandas DataFrame containing the table to log.

    Raises:
        `ModuleNotFoundError`: If `comet-ml` is not installed.
    """
    if not is_comet_available():
        raise ModuleNotFoundError("The comet-ml is not installed. Please install it first: pip install comet-ml")

    running = comet_ml.get_running_experiment()
    if running is None:
        return
    running.log_table(tabular_data=table, filename=name)
|
|
|
|
|
| def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, ...]:
|
| """
|
| Shift non-zero elements in the mask and corresponding tensors to the left.
|
|
|
| This function operates on a binary mask and any number of additional tensors with the same dimensions as the mask.
|
| For each row, non-zero values are shifted to the leftmost positions. Then, columns that contain only zeros across
|
| all rows are truncated from the mask and tensors. Visually, this operation can be represented as follows:
|
|
|
| ```
|
| [[0, 0, x, x, x, x], -> [[x, x, x, x],
|
| [0, x, x, x, 0, 0]] [x, x, x, 0]]
|
| ```
|
|
|
| Args:
|
| mask (`torch.Tensor`):
|
| 2D tensor (binary mask) with shape `(N, M)`.
|
| *tensors (`torch.Tensor`):
|
| One or more 2D tensors with the same shape as `mask`. These tensors will be processed alongside `mask`,
|
| with non-zero values shifted and excess zero columns truncated in the same manner.
|
|
|
| Returns:
|
| `torch.Tensor`:
|
| Updated binary mask with non-zero values flushed to the left and trailing zero columns removed.
|
| `*torch.Tensor`
|
| Updated tensors, processed in the same way as the mask.
|
|
|
| Example:
|
| ```python
|
| >>> mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]])
|
| >>> tensor = torch.tensor([[9, 9, 2, 3, 4], [9, 5, 6, 9, 9]])
|
| >>> new_mask, new_tensor = flush_left(mask, tensor)
|
| >>> print(new_mask)
|
| tensor([[1, 1, 1],
|
| [1, 1, 0]])
|
|
|
| >>> print(new_tensor)
|
| tensor([[2, 3, 4],
|
| [5, 6, 0]])
|
| ```
|
| """
|
| _, M = mask.shape
|
|
|
|
|
| mask_copy = mask.clone()
|
| tensors = [t.clone() for t in tensors]
|
|
|
|
|
| first_non_zero = mask_copy.argmax(dim=1)
|
| pos = torch.arange(M, device=mask_copy.device).unsqueeze(0)
|
| idx_roll = (pos + first_non_zero.unsqueeze(1)) % M
|
| mask_roll = mask_copy.gather(1, idx_roll)
|
| rolled_tensors = [t.gather(1, idx_roll) for t in tensors]
|
|
|
|
|
| col_sums = mask_roll.sum(dim=0)
|
| empty_cols = col_sums == 0
|
| first_empty_col = int(empty_cols.to(torch.int8).argmax()) if empty_cols.any() else M
|
| flushed_mask = mask_roll[:, :first_empty_col]
|
| flushed_tensors = [t[:, :first_empty_col] for t in rolled_tensors]
|
|
|
| if not flushed_tensors:
|
| return flushed_mask
|
| return flushed_mask, *flushed_tensors
|
|
|
|
|
| def selective_log_softmax(logits, index) -> torch.Tensor:
|
| """
|
| A memory-efficient implementation of the common `log_softmax -> gather` operation.
|
|
|
| This function is equivalent to the following naive implementation:
|
| ```python
|
| # for index with shape (...):
|
| logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
|
| # for index with shape (..., K):
|
| logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index)
|
| ```
|
|
|
| Args:
|
| logits (`torch.Tensor`):
|
| Logits tensor of shape `(..., num_classes)`.
|
| index (`torch.Tensor`):
|
| Index tensor of shape `(..., K)` or `(...)`, specifying the positions to gather from the log-softmax
|
| output. When the last case is used, `K` log-probabilities are gathered per position (e.g. for top-K)
|
|
|
| Returns:
|
| `torch.Tensor`:
|
| Gathered log probabilities with the same shape as `index`.
|
| """
|
| squeeze = index.ndim == logits.ndim - 1
|
| if squeeze:
|
| index = index.unsqueeze(-1)
|
|
|
| if logits.dtype in [torch.float32, torch.float64]:
|
| selected_logits = torch.gather(logits, dim=-1, index=index)
|
|
|
| logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
| per_token_logps = selected_logits - logsumexp_values.unsqueeze(-1)
|
| else:
|
|
|
| per_token_logps = []
|
| for row_logits, row_labels in zip(logits, index, strict=True):
|
| row_logps = F.log_softmax(row_logits, dim=-1)
|
| row_per_token_logps = row_logps.gather(dim=-1, index=row_labels)
|
| per_token_logps.append(row_per_token_logps)
|
| per_token_logps = torch.stack(per_token_logps)
|
|
|
| if squeeze:
|
| per_token_logps = per_token_logps.squeeze(-1)
|
|
|
| return per_token_logps
|
|
|
|
|
def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
    """
    Compute the Shannon entropy (in nats) for each row of *logits* in a memory-efficient way.

    Instead of materializing the full softmax for all rows at once, the logits are flattened to shape (N, num_classes),
    where N is the product of all leading dimensions. Computation is then performed in chunks of size `chunk_size`
    along this flattened dimension, reducing peak memory usage. The result is reshaped back to match the input's
    leading dimensions.

    Classes with zero probability (for example logits masked with `-inf`) contribute `0` to the entropy rather than
    `NaN`.

    Args:
        logits (`torch.Tensor`):
            Logits tensor of shape `(..., num_classes)`. Entropy is taken along the last axis; all leading dimensions
            are preserved in the output.
        chunk_size (`int`, *optional*, defaults to `128`):
            Number of rows from the flattened logits to process per iteration. Smaller values reduce memory usage at
            the cost of more iterations.

    Returns:
        `torch.Tensor`:
            Entropy values with shape `logits.shape[:-1]`.
    """
    original_shape = logits.shape[:-1]
    num_classes = logits.shape[-1]

    # Flatten all leading dimensions so the rows can be processed chunk by chunk.
    flat_logits = logits.reshape(-1, num_classes)

    entropies = []
    for chunk in flat_logits.split(chunk_size, dim=0):
        logps = F.log_softmax(chunk, dim=-1)
        # A log-prob of -inf has probability 0, and 0 * -inf is NaN. Clamping to the smallest finite value makes such
        # classes contribute exactly 0 (the limit of p * log p); finite log-probs are left unchanged.
        finite_logps = logps.clamp(min=torch.finfo(logps.dtype).min)
        chunk_entropy = -(torch.exp(logps) * finite_logps).sum(-1)
        entropies.append(chunk_entropy)

    entropies = torch.cat(entropies, dim=0)
    return entropies.reshape(original_shape)
|
|
|
|
|
def print_prompt_completions_sample(
    prompts: list,
    completions: list,
    rewards: dict[str, list[float]],
    advantages: list[float],
    step: int,
    num_samples: int = None,
    extra: dict[str, list] | None = None,
) -> None:
    """
    Print a table of prompt/completion pairs with their rewards and advantages to the console.

    The table is rendered with the `rich` library, which must be installed. Columns are, in order: prompt, completion,
    one column per reward, the advantage, then one column per entry of `extra`.

    Args:
        prompts (`list`):
            List of prompts. Can be either strings or lists of messages.
        completions (`list`):
            List of completions corresponding to the prompts. Can be either strings or lists of messages.
        rewards (`dict[str, list[float]]`):
            Dictionary where keys are reward names and values are lists of rewards.
        advantages (`list[float]`):
            List of advantages corresponding to the prompts and completions.
        step (`int`):
            Current training step number, used in the output title.
        num_samples (`int`, *optional*):
            Number of random samples to display. If `None` (default) or at least `len(prompts)`, all items are
            displayed. If smaller than `len(prompts)` and not positive, nothing is printed.
        extra (`dict[str, list]`, *optional*):
            Additional columns to display after the advantage column. Keys are column names and values are lists of
            per-completion data (strings or any value convertible to string). If `None` (default), no extra columns
            are shown.

    Raises:
        `ImportError`: If `rich` is not installed.

    Example:
    ```python
    >>> from trl.trainer.utils import print_prompt_completions_sample

    >>> prompts = ["The sky is", "The sun is"]
    >>> completions = [" blue.", " in the sky."]
    >>> rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
    >>> advantages = [0.987, 0.654]
    >>> extra = {"source": ["dataset_A", "dataset_B"]}
    >>> print_prompt_completions_sample(prompts, completions, rewards, advantages, 42, extra=extra)
    ╭────────────────────────────────── Step 42 ───────────────────────────────────╮
    │ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┓ │
    │ ┃ Prompt     ┃ Completion   ┃ Correctness ┃ Format ┃ Advantage ┃ source    ┃ │
    │ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━┩ │
    │ │ The sky is │  blue.       │        0.12 │   0.79 │      0.99 │ dataset_A │ │
    │ ├────────────┼──────────────┼─────────────┼────────┼───────────┼───────────┤ │
    │ │ The sun is │  in the sky. │        0.46 │   0.10 │      0.65 │ dataset_B │ │
    │ └────────────┴──────────────┴─────────────┴────────┴───────────┴───────────┘ │
    ╰──────────────────────────────────────────────────────────────────────────────╯
    ```
    """
    if not is_rich_available():
        raise ImportError(
            "The function `print_prompt_completions_sample` requires the `rich` library. Please install it with "
            "`pip install rich`."
        )
    console = Console()
    table = Table(show_header=True, header_style="bold white", expand=True)

    extra = extra or {}

    # Column layout: prompt, completion, rewards, advantage, extras.
    table.add_column("Prompt", style="bright_yellow")
    table.add_column("Completion", style="bright_green")
    for reward_name in rewards:
        table.add_column(reward_name, style="bold cyan", justify="right")
    table.add_column("Advantage", style="bold magenta", justify="right")
    for extra_name in extra:
        table.add_column(extra_name, style="bright_white")

    def format_entry(entry) -> Text:
        """Render a plain value, or a list of message dicts, as a `rich` `Text`."""
        rendered = Text()
        is_conversation = isinstance(entry, list) and all(isinstance(m, dict) for m in entry)
        if not is_conversation:
            rendered.append(str(entry))
            return rendered

        last_position = len(entry) - 1
        for position, message in enumerate(entry):
            role = message.get("role", "")
            if "content" in message or "reasoning_content" in message or "thinking" in message:
                # Regular message: role header, optional reasoning, then the content.
                rendered.append(f"{role.upper()}\n", style="bold red")
                reasoning = message.get("reasoning_content") or message.get("thinking")
                if reasoning:
                    rendered.append(reasoning, style="italic dim white")
                    rendered.append("\n")
                if "content" in message:
                    rendered.append(message["content"])
            elif "name" in message and "args" in message:
                # Message carrying a name and arguments, rendered as `name(args)`.
                rendered.append(f"{role.upper()}\n", style="bold red")
                rendered.append(f"{message['name']}({message['args']})")
            else:
                # Unknown message layout: fall back to its string form.
                rendered.append(str(message))
            if position < last_position:
                rendered.append("\n\n")
        return rendered

    # Subsample only when fewer rows than available are requested.
    if num_samples is not None and num_samples < len(prompts):
        if num_samples <= 0:
            return
        picked = random.sample(range(len(prompts)), num_samples)
        prompts = [prompts[i] for i in picked]
        completions = [completions[i] for i in picked]
        rewards = {key: [values[i] for i in picked] for key, values in rewards.items()}
        advantages = [advantages[i] for i in picked]
        extra = {key: [values[i] for i in picked] for key, values in extra.items()}

    for row, prompt in enumerate(prompts):
        reward_cells = [f"{values[row]:.2f}" for values in rewards.values()]
        extra_cells = [format_entry(values[row]) for values in extra.values()]
        table.add_row(
            format_entry(prompt),
            format_entry(completions[row]),
            *reward_cells,
            f"{advantages[row]:.2f}",
            *extra_cells,
        )
        table.add_section()  # adds a separator between rows

    panel = Panel(table, expand=False, title=f"Step {step}", border_style="bold white")
    console.print(panel)
|
|
|
|
|
| class RepeatSampler(Sampler):
|
| """
|
| Sampler that repeats the indices of a dataset in a structured manner.
|
|
|
| Args:
|
| data_source (`Sized`):
|
| Dataset to sample from.
|
| mini_repeat_count (`int`):
|
| Number of times to repeat each index per batch.
|
| batch_size (`int`, *optional*, defaults to `1`):
|
| Number of unique indices per batch.
|
| repeat_count (`int`, *optional*, defaults to `1`):
|
| Number of times to repeat the full sampling process.
|
| shuffle (`bool`, *optional*, defaults to `True`):
|
| Whether to shuffle the dataset.
|
| seed (`int`, *optional*):
|
| Random seed for reproducibility (only affects this sampler).
|
|
|
| Example:
|
| ```python
|
| >>> sampler = RepeatSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4)
|
| >>> list(sampler)
|
| [4, 4, 3, 3, 0, 0,
|
| 4, 4, 3, 3, 0, 0,
|
| 4, 4, 3, 3, 0, 0,
|
| 4, 4, 3, 3, 0, 0,
|
| 1, 1, 2, 2, 6, 6,
|
| 1, 1, 2, 2, 6, 6,
|
| 1, 1, 2, 2, 6, 6,
|
| 1, 1, 2, 2, 6, 6]
|
| ```
|
|
|
| ```txt
|
| mini_repeat_count = 3
|
| - - -
|
| [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
|
| 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
|
| 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, |
|
| repeat_count = 2
|
| 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
|
| 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
|
| 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, ...] |
|
| --------- --------- --------- ---------
|
| --------- --------- --------- ---------
|
| --------- --------- --------- ---------
|
| batch_size = 12
|
| ```
|
| """
|
|
|
| def __init__(
|
| self,
|
| data_source: Sized,
|
| mini_repeat_count: int,
|
| batch_size: int = 1,
|
| repeat_count: int = 1,
|
| shuffle: bool = True,
|
| seed: int | None = None,
|
| ):
|
| self.data_source = data_source
|
| self.mini_repeat_count = mini_repeat_count
|
| self.batch_size = batch_size
|
| self.repeat_count = repeat_count
|
| self.num_samples = len(data_source)
|
| self.shuffle = shuffle
|
| self.seed = seed
|
|
|
| if shuffle:
|
| self.generator = torch.Generator()
|
| if seed is not None:
|
| self.generator.manual_seed(seed)
|
|
|
| def __iter__(self):
|
| if self.shuffle:
|
|
|
| indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
|
| else:
|
| indexes = list(range(self.num_samples))
|
|
|
|
|
|
|
| indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
|
|
|
|
|
|
|
| indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
|
|
|
| for chunk in indexes:
|
| for _ in range(self.repeat_count):
|
| for index in chunk:
|
| for _ in range(self.mini_repeat_count):
|
| yield index
|
|
|
| def __len__(self) -> int:
|
| return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count
|
|
|
|
|
|
|
| def nanstd(tensor: torch.Tensor, dim: int | tuple[int, ...] | None = None, keepdim: bool = False) -> torch.Tensor:
|
| """
|
| Compute the standard deviation of a tensor, ignoring NaNs.
|
|
|
| Args:
|
| tensor (`torch.Tensor`):
|
| Input tensor.
|
| dim (`int` or `tuple[int, ...]`, *optional*):
|
| Dimension(s) to reduce. Defaults to all dimensions.
|
| keepdim (`bool`, *optional*, defaults to `False`):
|
| Whether to keep reduced dimensions.
|
|
|
| Returns:
|
| `torch.Tensor`:
|
| Standard deviation of the tensor, ignoring NaNs.
|
| """
|
|
|
| mean = torch.nanmean(tensor, dim=dim, keepdim=True)
|
| variance = torch.nanmean((tensor - mean) ** 2, dim=dim, keepdim=True)
|
| count = torch.sum(~torch.isnan(tensor), dim=dim, keepdim=True)
|
| correction = count / (count - 1)
|
| correction = torch.where(count > 1, correction, torch.full_like(correction, float("nan")))
|
| variance *= correction
|
| std = torch.sqrt(variance)
|
| if keepdim:
|
| return std
|
| if dim is None:
|
| return std.squeeze()
|
| if isinstance(dim, int):
|
| return std.squeeze(dim)
|
| dims = [(d if d >= 0 else d + std.ndim) for d in dim]
|
| for d in sorted(dims, reverse=True):
|
| std = std.squeeze(d)
|
| return std
|
|
|
|
|
| def split_tensor_dict(
|
| tensor_dict: dict[str, torch.Tensor | None], num_chunks: int
|
| ) -> list[dict[str, torch.Tensor | None]]:
|
| """
|
| Splits a dictionary of tensors along the first dimension into `num_chunks` equal parts.
|
|
|
| Example:
|
| ```python
|
| >>> x = torch.arange(12).reshape(6, 2)
|
| >>> y = torch.arange(6).reshape(6, 1)
|
| >>> tensor_dict = {"x": x, "y": y}
|
| >>> split_tensor_dict(tensor_dict, 3)
|
| [
|
| {"x": tensor([[0, 1], [2, 3]]), "y": tensor([[0], [1]])},
|
| {"x": tensor([[4, 5], [6, 7]]), "y": tensor([[2], [3]])},
|
| {"x": tensor([[ 8, 9], [10, 11]]), "y": tensor([[4], [5]])}
|
| ]
|
| ```
|
| """
|
| first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None)
|
| chunk_size = first_tensor.shape[0] // num_chunks
|
| chunks = []
|
| for i in range(num_chunks):
|
| chunk_dict = {}
|
| for key, tensor in tensor_dict.items():
|
| if tensor is not None and (isinstance(tensor, list) or tensor.ndim > 0):
|
| chunk_dict[key] = tensor[i * chunk_size : (i + 1) * chunk_size]
|
| elif tensor is not None and tensor.ndim == 0:
|
| chunk_dict[key] = tensor
|
| else:
|
| chunk_dict[key] = None
|
| chunks.append(chunk_dict)
|
| return chunks
|
|
|
|
|
| def shuffle_sequence_dict(seq_dict: dict[str, Sequence | None]) -> dict[str, Sequence | None]:
|
| """
|
| Shuffles all sequence-like values in a dictionary along the first dimension in unison.
|
|
|
| Example:
|
| ```python
|
| >>> x = torch.arange(6).reshape(3, 2)
|
| >>> y = ["a", "b", "c"]
|
| >>> seq_dict = {"x": x, "y": y}
|
| >>> shuffle_sequence_dict(seq_dict)
|
| {'x': tensor([[2, 3],
|
| [0, 1],
|
| [4, 5]]),
|
| 'y': ['b', 'a', 'c']}
|
| ```
|
| """
|
|
|
| batch_size = len(next(v for v in seq_dict.values() if v is not None))
|
| permutation = torch.randperm(batch_size)
|
|
|
| def permute(v: Sequence | None) -> Sequence | None:
|
| if v is None:
|
| return None
|
| if isinstance(v, torch.Tensor) and v.ndim == 0:
|
| return v
|
| if isinstance(v, torch.Tensor) and v.ndim >= 1:
|
| return v[permutation]
|
| return [v[i] for i in permutation]
|
|
|
| return {key: permute(val) for key, val in seq_dict.items()}
|
|
|
|
|
def nanmin(tensor: torch.Tensor) -> torch.Tensor:
    """
    Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors.

    Args:
        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`: Minimum of the non-NaN values. NaN (with the dtype and device of `tensor`) if there is no
        non-NaN value.
    """
    kept = tensor[~torch.isnan(tensor)]
    if kept.numel() == 0:
        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
    return torch.min(kept)
|
|
|
|
|
def nanmax(tensor: torch.Tensor) -> torch.Tensor:
    """
    Compute the maximum value of a tensor, ignoring NaNs. This function only supports 1D tensors.

    Args:
        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`: Maximum of the non-NaN values. NaN (with the dtype and device of `tensor`) if there is no
        non-NaN value.
    """
    kept = tensor[~torch.isnan(tensor)]
    if kept.numel() == 0:
        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
    return torch.max(kept)
|
|
|
|
|
def identity(x):
    """Return `x` unchanged."""
    return x
|
|
|
|
|
| def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor | list[torch.Tensor]]:
|
| """
|
| Splits `batch["pixel_values"]` into a list of tensors, one per sample, based on `batch["num_images"]`.
|
|
|
| For models with `image_grid_thw` (e.g. Qwen), the grid dimensions determine how many rows of `pixel_values` belong
|
| to each image. For models with `image_position_ids` instead (e.g. Gemma), `pixel_values` is indexed directly by
|
| image count.
|
| """
|
| if "pixel_values" not in batch or "num_images" not in batch:
|
| return batch
|
|
|
| num_images = batch["num_images"]
|
| pixel_values = batch["pixel_values"]
|
|
|
| if "image_grid_thw" in batch:
|
| lengths = batch["image_grid_thw"].prod(-1).tolist()
|
| if sum(lengths) != pixel_values.size(0):
|
| raise ValueError(
|
| f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}"
|
| )
|
|
|
| boundaries = [0, *accumulate(num_images)]
|
| image_grid_thw = batch["image_grid_thw"]
|
| sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(num_images))]
|
| split_pixel_values = list(torch.split(pixel_values, sections, dim=0))
|
| split_image_grid_thw = list(torch.split(image_grid_thw, num_images, dim=0))
|
| return {**batch, "pixel_values": split_pixel_values, "image_grid_thw": split_image_grid_thw}
|
|
|
| if "image_position_ids" in batch:
|
| image_position_ids = batch["image_position_ids"]
|
| split_pixel_values = list(torch.split(pixel_values, num_images, dim=0))
|
| split_image_position_ids = list(torch.split(image_position_ids, num_images, dim=0))
|
| return {**batch, "pixel_values": split_pixel_values, "image_position_ids": split_image_position_ids}
|
|
|
| return batch
|
|
|
|
|
| def unsplit_pixel_values_by_grid(batch: dict[str, torch.Tensor | list[torch.Tensor]]) -> dict[str, torch.Tensor]:
|
| """
|
| Opposite of `split_pixel_values_by_grid`. Merges a list of tensors in `batch["pixel_values"]` back into a single
|
| tensor along the first dimension.
|
| """
|
| pixel_values = batch.get("pixel_values")
|
| if isinstance(pixel_values, list):
|
| merged = torch.cat(pixel_values, dim=0)
|
| batch = {**batch, "pixel_values": merged}
|
|
|
| image_grid_thw = batch.get("image_grid_thw")
|
| if isinstance(image_grid_thw, list):
|
| merged = torch.cat(image_grid_thw, dim=0)
|
| batch = {**batch, "image_grid_thw": merged}
|
|
|
| image_position_ids = batch.get("image_position_ids")
|
| if isinstance(image_position_ids, list):
|
| merged = torch.cat(image_position_ids, dim=0)
|
| batch = {**batch, "image_position_ids": merged}
|
|
|
| return batch
|
|
|
|
|
# Type variable constraining `remove_none_values` to return the same kind of container it receives.
TListOrMapping = TypeVar("TListOrMapping", list, Mapping)


def remove_none_values(example: TListOrMapping) -> TListOrMapping:
    """
    Recursively drop dictionary entries whose value is `None` from a nested structure of lists and dictionaries.

    Only mapping entries are removed: `None` items inside a list are kept. Nested `dict` and `list` values are
    processed recursively; a mapping input is returned as a plain `dict`.

    Args:
        example (`list` or `Mapping`):
            Input nested structure (list or dictionary) from which to remove `None`.

    Raises:
        `TypeError`: If `example` is neither a list nor a mapping.

    Examples:
    ```python
    >>> dataset = dataset.with_transform(remove_none_values)
    ```
    ```python
    >>> example = [
    ...     {
    ...         "a": {"aa": None, "ab": 1},
    ...         "b": "my_string",
    ...     }
    ... ]
    >>> remove_none_values(example)
    [{'a': {'ab': 1}, 'b': 'my_string'}]
    ```
    """

    def clean(value):
        """Recurse into nested containers; return any other value as is."""
        if isinstance(value, (dict, list)):
            return remove_none_values(value)
        return value

    if isinstance(example, list):
        return [clean(item) for item in example]
    if isinstance(example, Mapping):
        return {key: clean(value) for key, value in example.items() if value is not None}
    raise TypeError("Input must be a list or a dictionary.")
|
|
|
|
|
def create_model_from_path(
    model_id: str, architecture: _BaseAutoModelClass | None = None, **kwargs
) -> PreTrainedModel:
    """
    Instantiate a pretrained model from a local path or a Hub identifier.

    Args:
        model_id (`str`):
            Path to the model. Can be either a local directory or a model identifier from the Hugging Face Hub.
        architecture (`_BaseAutoModelClass` or `None`, *optional*):
            Class whose `from_pretrained` method is used to build the model. If `None`, the class is looked up in
            `transformers` using the first entry of the `architectures` field of the model's configuration.
        kwargs (`dict`):
            Keyword arguments forwarded to `from_pretrained`. `dtype` may be a `torch.dtype`, `None`, `'auto'`, or one
            of the strings `'bfloat16'`, `'float16'`, `'float32'` (strings are converted to the matching
            `torch.dtype`). If `dtype` is not given, `torch.float32` is used. If `device_map` is not given, `'auto'`
            is used.

    Returns:
        [`~transformers.PreTrainedModel`]:
            The instantiated model.

    Raises:
        `ValueError`:
            If `dtype` is none of the accepted values listed above.
    """
    requested_dtype = kwargs.get("dtype", "float32")
    # Values that `from_pretrained` can consume directly are forwarded untouched.
    passes_through = isinstance(requested_dtype, torch.dtype) or requested_dtype == "auto" or requested_dtype is None
    if not passes_through:
        is_known_name = isinstance(requested_dtype, str) and requested_dtype in ("bfloat16", "float16", "float32")
        if not is_known_name:
            raise ValueError(
                "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing "
                f"a valid `torch.dtype` (e.g., 'float32'), but got {requested_dtype}."
            )
        kwargs["dtype"] = getattr(torch, requested_dtype)

    kwargs.setdefault("device_map", "auto")

    model_class = architecture
    if model_class is None:
        # No class given: resolve it from the first architecture named in the model configuration.
        model_config = AutoConfig.from_pretrained(model_id)
        model_class = getattr(transformers, model_config.architectures[0])
    return model_class.from_pretrained(model_id, **kwargs)
|
|
|
|
|
| def hash_module(module: torch.nn.Module) -> str:
|
| h = hashlib.sha256()
|
| for _, tensor in sorted(module.state_dict().items()):
|
| tensor = tensor.cpu()
|
| h.update(str(tensor.dtype).encode())
|
| if tensor.dtype in [torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2]:
|
| tensor = tensor.to(torch.float32)
|
| h.update(tensor.numpy().tobytes())
|
| return h.hexdigest()
|
|
|
|
|
def get_config_model_id(config: PretrainedConfig) -> str:
    """
    Return the model identifier stored on a model configuration.

    Args:
        config ([`~transformers.PreTrainedConfig`]):
            Configuration from which to extract the model identifier.

    Returns:
        `str`:
            The value of the configuration's `_name_or_path` attribute, or an empty string when the configuration has
            no such attribute.
    """
    try:
        return config._name_or_path
    except AttributeError:
        return ""
|
|
|
|
|
@contextmanager
def use_adapter(model: "PeftModel", adapter_name: str | None):
    """
    Context manager that temporarily switches the active adapter of a PEFT model.

    Args:
        model ([`~peft.PeftModel`]):
            PEFT model to manage.
        adapter_name (`str` or `None`):
            Name of the adapter to activate inside the context. If `None`, all adapters are disabled inside the
            context instead.

    Raises:
        `ImportError`:
            If PEFT is not installed.

    Example:
    ```python
    >>> from trl.trainer.utils import use_adapter
    >>> from peft import AutoPeftModelForCausalLM
    >>> import torch

    >>> model = AutoPeftModelForCausalLM.from_pretrained("path/to/model")
    >>> input_ids = torch.tensor([[1, 2, 3]])
    >>> with use_adapter(model, "adapter_name"):
    ...     outputs = model(input_ids)
    ```
    """
    if not is_peft_available():
        raise ImportError(
            "You're trying to use a PEFT adapter but PEFT is not installed. Please install it with `pip install peft`."
        )

    if adapter_name is not None:
        # Remember the adapter that was active on entry so it is restored even if the body raises.
        adapter_on_entry = model.active_adapter
        model.set_adapter(adapter_name)
        try:
            yield
        finally:
            model.set_adapter(adapter_on_entry)
        return

    # No adapter requested: run the body with adapters disabled.
    with model.disable_adapter():
        yield
|
|
|
|
|
| def start_event_loop_in_daemon(
|
| name: str | None = None,
|
| ) -> tuple[threading.Thread, asyncio.AbstractEventLoop, threading.Event]:
|
| """
|
| This function creates a new daemon thread that runs the provided event loop.
|
|
|
| Args:
|
| name (`str`, *optional*):
|
| Name of the thread. If `None`, the default thread naming will be used.
|
|
|
| Returns:
|
| `threading.Thread`:
|
| The thread running the event loop.
|
| `asyncio.AbstractEventLoop`:
|
| The event loop being run in the thread.
|
| `threading.Event`:
|
| An event that is set when the loop is ready.
|
| """
|
| loop = asyncio.new_event_loop()
|
| loop_ready_event = threading.Event()
|
|
|
| def run_loop():
|
| asyncio.set_event_loop(loop)
|
| loop_ready_event.set()
|
| loop.run_forever()
|
|
|
| thread = threading.Thread(target=run_loop, name=name, daemon=True)
|
| thread.start()
|
| return thread, loop, loop_ready_event
|
|
|
|
|
| def shutdown_event_loop_in_daemon(
|
| thread: threading.Thread | None,
|
| loop: asyncio.AbstractEventLoop | None,
|
| ) -> None:
|
| """
|
| Shutdown an asyncio event loop running in a separate thread.
|
|
|
| This function stops the event loop and waits for the associated thread to finish execution.
|
|
|
| Args:
|
| thread (`threading.Thread`):
|
| The thread running the event loop.
|
| loop (`asyncio.AbstractEventLoop`):
|
| The asyncio event loop to shut down.
|
| """
|
| if loop is None or thread is None:
|
| return
|
| loop.call_soon_threadsafe(loop.stop)
|
| thread.join(timeout=5)
|
|
|
|
|
class _ChunkedLogProbFunction(torch.autograd.Function):
    """Compute per-token log-probs and entropy without materializing [N, V] logits.

    Processes the lm_head in chunks of `chunk_size` vocabulary rows and uses online logsumexp, so only an
    `[N, chunk_size]` slice of the logits exists at any time. The backward pass recomputes the logits chunk by chunk
    from the saved inputs and the saved log-partition values.
    """

    @staticmethod
    def forward(
        ctx,
        last_hidden: torch.Tensor,
        weight: torch.Tensor,
        targets: torch.Tensor,
        temperature: float,
        chunk_size: int,
        logit_scale: float = 1.0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            last_hidden (`torch.Tensor`):
                2D tensor `[N, H]` of hidden states.
            weight (`torch.Tensor`):
                2D tensor `[V, H]` (lm_head weight). Each chunk is cast to the dtype of `last_hidden` before the
                matmul.
            targets (`torch.Tensor`):
                Integer tensor `[N]` of target vocabulary indices. An index outside `[0, V)` never matches a chunk, so
                its target logit stays `0`.
            temperature (`float`):
                Logits are multiplied by `logit_scale / temperature`.
            chunk_size (`int`):
                Number of vocabulary rows processed per step.
            logit_scale (`float`, *optional*, defaults to `1.0`):
                Extra multiplicative factor applied to the logits.

        Returns:
            `tuple[torch.Tensor, torch.Tensor]`:
                Float32 tensors `[N]`: the log-probability of each target and the entropy of each row's distribution.
        """
        device = last_hidden.device
        N, _ = last_hidden.shape
        vocab, _ = weight.shape
        inv_t = logit_scale / temperature

        # Running statistics of the online logsumexp, all kept in float32.
        max_old = torch.full((N,), float("-inf"), device=device, dtype=torch.float32)
        sum_exp = torch.zeros((N,), device=device, dtype=torch.float32)
        x_sum_exp = torch.zeros((N,), device=device, dtype=torch.float32)
        target_logit = torch.zeros((N,), device=device, dtype=torch.float32)

        # Buffers reused by every chunk: matmul output in the hidden dtype and its float32 copy.
        mm_buf = torch.empty((N, chunk_size), device=device, dtype=last_hidden.dtype)
        logits_buf = torch.empty((N, chunk_size), device=device, dtype=torch.float32)

        for start in range(0, vocab, chunk_size):
            end = min(start + chunk_size, vocab)
            C = end - start

            # Scaled float32 logits for vocabulary rows [start, end).
            w_chunk = weight[start:end].to(last_hidden.dtype)
            torch.mm(last_hidden, w_chunk.t(), out=mm_buf[:, :C])
            logits_chunk = logits_buf[:, :C]
            logits_chunk.copy_(mm_buf[:, :C])
            logits_chunk.mul_(inv_t)

            # Online logsumexp: rescale the accumulators to the new running maximum before adding this chunk.
            chunk_max = logits_chunk.amax(dim=-1)
            max_new = torch.maximum(max_old, chunk_max)
            rescale = torch.exp(max_old - max_new)
            chunk_exp = torch.exp(logits_chunk - max_new.unsqueeze(-1))

            sum_exp = sum_exp * rescale + chunk_exp.sum(dim=-1)
            x_sum_exp = x_sum_exp * rescale + (chunk_exp * logits_chunk).sum(dim=-1)
            max_old = max_new

            # Pick up the target logit for rows whose target falls in this chunk. The index is clamped so the gather
            # stays in bounds; the boolean mask zeroes the contribution of rows that do not belong to the chunk.
            in_chunk_cond = (targets >= start) & (targets < end)
            local_idx = torch.clamp(targets - start, 0, end - start - 1)

            target_logit += logits_chunk[torch.arange(N, device=device), local_idx] * in_chunk_cond

        log_z = max_old + torch.log(sum_exp)
        logprobs = target_logit - log_z
        # Entropy = log Z - E_p[logit], where the expectation is x_sum_exp / sum_exp.
        entropy = log_z - x_sum_exp / sum_exp

        ctx.save_for_backward(last_hidden, weight, targets, log_z)
        ctx.temperature = temperature
        ctx.chunk_size = chunk_size
        ctx.logit_scale = logit_scale

        return logprobs, entropy

    @staticmethod
    def backward(ctx, grad_logprobs: torch.Tensor, grad_entropy: torch.Tensor):
        """
        Recompute the logits chunk by chunk and accumulate the gradients of the log-probs.

        `grad_entropy` is not used: no gradient is propagated through the entropy output.

        Returns:
            Gradients for `last_hidden` and `weight` (cast back to their dtypes), followed by `None` for the four
            remaining forward arguments.
        """
        hidden, weight, labels, log_z = ctx.saved_tensors
        temperature: float = ctx.temperature
        chunk_size: int = ctx.chunk_size
        logit_scale: float = ctx.logit_scale
        inv_t = logit_scale / temperature

        N, _ = hidden.shape
        vocab = weight.shape[0]

        # Gradients are accumulated in float32 and cast back on return.
        grad_hidden = torch.zeros(hidden.shape, device=hidden.device, dtype=torch.float32)
        grad_weight = torch.zeros(weight.shape, device=weight.device, dtype=torch.float32)

        # Buffers reused by every chunk, mirroring the forward pass.
        mm_buf = torch.empty((N, chunk_size), device=hidden.device, dtype=hidden.dtype)
        logits_buf = torch.empty((N, chunk_size), device=hidden.device, dtype=torch.float32)

        g = grad_logprobs.to(torch.float32)
        row_idx = torch.arange(N, device=hidden.device)
        # Loop-invariant float32 view of the hidden states used for the weight gradient.
        hidden_f32 = hidden.float()

        for start in range(0, vocab, chunk_size):
            end = min(start + chunk_size, vocab)
            C = end - start
            # Cast exactly as in `forward`: `torch.mm` requires matching dtypes, and the gradient must correspond to
            # the weights that were actually used to produce the logits.
            w_chunk = weight[start:end].to(hidden.dtype)

            torch.mm(hidden, w_chunk.t(), out=mm_buf[:, :C])
            logits_chunk = logits_buf[:, :C]
            logits_chunk.copy_(mm_buf[:, :C])
            logits_chunk.mul_(inv_t)
            probs = torch.exp(logits_chunk - log_z.unsqueeze(-1))

            # d logprob / d logits = onehot(target) - softmax(logits), weighted by the incoming gradient.
            grad_logits = (-g).unsqueeze(-1) * probs

            in_chunk_cond = (labels >= start) & (labels < end)
            local_idx = torch.clamp(labels - start, 0, end - start - 1)

            grad_logits[row_idx, local_idx] += g * in_chunk_cond
            # Chain rule through the `logit_scale / temperature` scaling.
            grad_logits = grad_logits * inv_t

            grad_hidden.add_(grad_logits @ w_chunk.float())
            grad_weight[start:end].add_(grad_logits.t() @ hidden_f32)

        return grad_hidden.to(hidden.dtype), grad_weight.to(weight.dtype), None, None, None, None
|
|
|
|
|
| def patch_chunked_lm_head(model: torch.nn.Module, chunk_size: int, temperature: float) -> None:
|
| if getattr(model.config, "final_logit_softcapping", None) is not None:
|
| raise NotImplementedError(
|
| "The model uses `final_logit_softcapping` which is not yet supported. Please open an issue if you "
|
| "want your model to be supported."
|
| )
|
|
|
| def _chunked_forward(
|
| self: torch.nn.Module,
|
| input_ids: torch.Tensor | None = None,
|
| attention_mask: torch.Tensor | None = None,
|
| labels: torch.Tensor | None = None,
|
| completion_mask: torch.Tensor | None = None,
|
| use_cache: bool = False,
|
| **kwargs,
|
| ) -> dict[str, torch.Tensor]:
|
| assert labels is not None, "requires labels to not be None for logprob computation"
|
|
|
| outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=use_cache, **kwargs)
|
|
|
| logit_scale = getattr(self.config, "logit_scale", 1.0)
|
| hidden_states = outputs.last_hidden_state
|
|
|
|
|
| hidden_states = hidden_states[:, :-1, :]
|
| labels = labels[:, 1:]
|
|
|
| b, s, h = hidden_states.shape
|
| hidden_flat = hidden_states.reshape(b * s, h).contiguous()
|
| targets_flat = labels.reshape(b * s).contiguous()
|
|
|
|
|
| valid_mask = None
|
| if completion_mask is not None:
|
| completion_mask = completion_mask[:, 1:]
|
| valid_mask = completion_mask.bool().reshape(b * s)
|
| hidden_flat = hidden_flat[valid_mask]
|
| targets_flat = targets_flat[valid_mask]
|
|
|
| logprobs_valid, entropy_valid = _ChunkedLogProbFunction.apply(
|
| hidden_flat, self.lm_head.weight, targets_flat, temperature, chunk_size, logit_scale
|
| )
|
|
|
| if valid_mask is not None:
|
| logprobs = torch.zeros(b * s, device=logprobs_valid.device, dtype=logprobs_valid.dtype)
|
| entropy = torch.zeros(b * s, device=entropy_valid.device, dtype=entropy_valid.dtype)
|
| logprobs[valid_mask] = logprobs_valid
|
| entropy[valid_mask] = entropy_valid
|
| else:
|
| logprobs = logprobs_valid
|
| entropy = entropy_valid
|
|
|
| return {
|
| "log_probs": logprobs.reshape(b, s),
|
| "entropy": entropy.reshape(b, s),
|
| }
|
|
|
| model.forward = types.MethodType(_chunked_forward, model)
|
|
|