# Source header (extraction residue from the hosting page, kept as a comment):
# high-u — "Add 8-bit quantized model" (commit 126cf46)
"""
Utility functions for the training Dream model.
References: https://github.com/zhijie-group/Discrete-Diffusion-Forcing/blob/main/D2F-train/utils/loss.py
"""
import functools
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
@torch.no_grad()
def get_prompt_lengths_from_labels(
labels: Tensor,
attention_mask: Tensor | None = None,
ignore_index: int = -100,
) -> Tensor:
"""
labels: (B, T) int64, with ignore_index where we ignore loss (prompt/user/system/pad)
attention_mask: optional (B, T) 1/0; if given, will treat masked-out (0) as non-real tokens
Returns: (B,) int64 prompt lengths = index of first (labels != ignore_index) per sample.
If a sample has no supervised tokens, length = number of real tokens.
"""
B, T = labels.shape
device = labels.device
supervised = labels.ne(ignore_index)
if attention_mask is not None:
supervised = supervised & attention_mask.bool()
idx_grid = torch.arange(T, device=device).expand(B, T)
first_idx = torch.where(supervised, idx_grid, torch.full_like(idx_grid, T)).min(dim=1).values
if attention_mask is None:
return first_idx
real_len = attention_mask.sum(dim=1).to(torch.long)
return torch.where((labels.ne(ignore_index) & attention_mask.bool()).any(dim=1), first_idx, real_len)
@torch.no_grad()
def simple_uniform_mask(
input_ids: Tensor, # (B, L), int64
prompt_lengths: Tensor, # (B,), int64 — number of tokens to keep unmasked at the left
mask_id: int, # token id to write where masked
p: float | None = None, # fixed mask rate; if None, sample per-sample in [p_min, p_max]
p_min: float = 0.0,
p_max: float = 1.0,
protect_eos_id: int | None = None, # treated as EOS id for the tail rule
pad_id: int | None = None,
ensure_at_least_one: bool = True,
eps: float = 1e-6, # tiny floor for probabilities
) -> tuple[Tensor, Tensor, Tensor]:
"""
Returns:
noisy: (B, L) int64 — input_ids with some tail tokens replaced by mask_id
masked: (B, L) bool — True where we replaced a token (incurs loss)
p_samples: (B,) float32 — per-sample mask probabilities used
"""
B, L = input_ids.shape
device = input_ids.device
noisy = input_ids.clone()
masked = torch.zeros_like(input_ids, dtype=torch.bool)
p_mask_tensor = torch.zeros((B, L), device=device, dtype=torch.float32)
# choose per-sample p
if p is None:
p_samples = torch.rand(B, device=device) * (p_max - p_min) + p_min
else:
p = float(p)
p_samples = torch.full((B,), p, device=device)
for i in range(B):
pl = int(prompt_lengths[i].item())
if pl >= L:
continue # nothing to mask
# ---- Eligible region: [pl, L). Exclude PAD only here. Do NOT exclude EOS now. ----
tail_tokens = input_ids[i, pl:L]
elig = torch.ones_like(tail_tokens, dtype=torch.bool)
if pad_id is not None:
elig &= tail_tokens != pad_id
if not elig.any():
continue
# i.i.d. Bernoulli with per-sample prob
pi = float(torch.clamp(p_samples[i], eps, 1.0 - eps).item())
randv = torch.rand(elig.shape, device=device)
tail_mask = (randv < pi) & elig
# optionally guarantee at least one masked token per sample
if ensure_at_least_one and not tail_mask.any():
# pick a random eligible index to force-mask
idxs = torch.nonzero(elig, as_tuple=False).squeeze(1)
force_idx = idxs[torch.randint(0, len(idxs), (1,), device=device)]
tail_mask[force_idx] = True
# provisional write-back BEFORE EOS rule
noisy[i, pl:L] = torch.where(
tail_mask,
torch.tensor(mask_id, device=device, dtype=noisy.dtype),
tail_tokens,
)
masked[i, pl:L] = tail_mask
p_mask_tensor[i, pl:L] = torch.where(elig, torch.tensor(pi, device=device), torch.tensor(0.0, device=device))
# ---- EOS tail rule (apply only if EOS is distinct from PAD) ----
if protect_eos_id is not None and (pad_id is None or protect_eos_id != pad_id):
# Find first EOS at/after prompt
eos_positions = input_ids[i, :] == protect_eos_id
# First EOS index in the entire sequence
if eos_positions.any():
first_eos_idx = int(torch.argmax(eos_positions.to(torch.uint8)).item())
else:
first_eos_idx = L # no EOS
# Tail exists only if EOS is not the last token
if first_eos_idx < L - 1:
# Check whether that first EOS was masked
was_first_eos_masked = False
if first_eos_idx >= pl:
was_first_eos_masked = bool(masked[i, first_eos_idx].item())
else:
# EOS lies inside the prompt region; it couldn't be masked by the sampling
was_first_eos_masked = False
# Build tail slice [first_eos_idx, L)
tail_slice = slice(first_eos_idx, L)
if was_first_eos_masked:
# Case A: mask entire EOS tail; loss applies there
noisy[i, tail_slice] = torch.tensor(mask_id, device=device, dtype=noisy.dtype)
masked[i, tail_slice] = True
# For consistency, set per-token prob on the tail to pi where we forced masking
p_mask_tensor[i, tail_slice] = pi
else:
# Case B: force EOS on the tail; no loss there
noisy[i, tail_slice] = torch.tensor(protect_eos_id, device=device, dtype=noisy.dtype)
masked[i, tail_slice] = False
p_mask_tensor[i, tail_slice] = 0.0
return noisy, masked, p_samples
def _shift_logits(logits: Tensor) -> Tensor:
"""
https://github.com/zhijie-group/Discrete-Diffusion-Forcing/blob/eed9750ab081cdc302daa9d8305478988f3f5a17/D2F-train/utils/util.py#L145C1-L150C26
"""
shifted_logits = torch.zeros_like(logits)
shifted_logits[:, 1:, :] = logits[:, :-1, :]
shifted_logits[:, 0, :] = 1.0
return shifted_logits
def _context_adaptive_reweight(seq_len: int, distribution: str = "symmetric-geometric", **kwargs) -> Tensor:
"""
Create context-adaptive reweighting matrix W of shape (seq_len, seq_len)
https://github.com/DreamLM/Dream/blob/fd91b8f1d47c5cbe4a8a1674fd9b98045e79d9db/src/trainer/fsdp_sft_trainer.py#L93
"""
position_ids_l = np.arange(seq_len).reshape(-1, 1)
position_ids_r = np.arange(seq_len).reshape(1, -1)
distance = position_ids_l - position_ids_r
distance = torch.from_numpy(distance)
def geometric_distribution(k, cart_p=0.8, **_):
if not 0 < cart_p <= 1:
raise ValueError("p must be between 0 and 1")
res = (math.log(cart_p) + (k.abs() - 1) * math.log(1 - cart_p)).exp() * 0.5
res.masked_fill_(k == 0, 0)
return res
if distribution == "symmetric-geometric":
matrix = geometric_distribution(distance, **kwargs)
else:
raise ValueError(f"Unknown distribution {distribution}")
return matrix
@functools.lru_cache(maxsize=64)
def _cached_cart_matrix(seq_len: int, cart_p: float, distribution: str) -> Tensor:
    """
    Memoized wrapper around _context_adaptive_reweight.

    Returns the (seq_len, seq_len) CPU float reweighting matrix; callers move
    it to the target device/dtype at use time.
    """
    return _context_adaptive_reweight(seq_len, distribution=distribution, cart_p=cart_p)
def loss_function(
logits: Tensor, # (B, L, V)
labels: Tensor, # (B, L)
masked: Tensor, # (B, L) bool or float; True/1.0 => include in loss
vocab_size: int,
*,
t: Tensor | None = None, # (B,) in [0,1], per-sample time
time_weighting: str = "cart", # "none" | "original" | "linear" | "cart"
cart_p: float = 0.5, # for cart time weighting
cart_distribution: str = "symmetric-geometric", # for cart time weighting
token_reweighting: bool = False, # optional difficulty weighting
alpha: float = 1.0,
gamma: float = 0.0,
ignore_index: int = -100,
eps: float = 1e-6,
) -> Tensor:
"""
Cross-entropy on masked positions with optional time- and token-reweighting.
time_weighting:
- "none": w_t = 1
- "original": w_t = 1 / t
- "linear": w_t = 1 - t
- "cart": w_t from context-adaptive reweighting matrix
We normalize by the sum of (masked * w_t) so the scale stays consistent.
"""
B, L, _ = logits.shape
shifted_logits = _shift_logits(logits) # (B, L, V)
# per-token CE without reduction
per_tok = F.cross_entropy(
shifted_logits.view(-1, vocab_size),
labels.view(-1),
ignore_index=ignore_index,
reduction="none",
).view_as(labels) # (B, L)
# base mask: include only selected tokens and not ignore_index
base_mask = masked.to(per_tok.dtype) # (B, L)
if ignore_index is not None:
base_mask = base_mask * (labels.ne(ignore_index)).to(per_tok.dtype)
# time weights (per-sample -> per-token broadcast)
if t is None or time_weighting == "none":
w_t = 1.0
time_w = torch.ones_like(per_tok)
else:
t = t.to(per_tok.device, dtype=per_tok.dtype)
if time_weighting == "original":
w_t = 1.0 / t.clamp_min(eps) # upweight small t (early timesteps)
time_w = w_t.view(-1, 1).expand_as(per_tok) # (B, L)
elif time_weighting == "linear":
w_t = (1.0 - t).clamp_min(0.0) # downweight large t
time_w = w_t.view(-1, 1).expand_as(per_tok) # (B, L)
elif time_weighting == "cart":
W = _cached_cart_matrix(L, float(cart_p), str(cart_distribution)).to(
per_tok.device, dtype=per_tok.dtype
) # (L, L)
w_pos = base_mask @ W.T # (B, L) @ (L, L) -> (B, L)
# normalize so mean weight over included tokens is 1 (stable scale)
mass = base_mask.sum(dim=1, keepdim=True).clamp_min(1.0) # (B, 1)
mean_w = (w_pos * base_mask).sum(dim=1, keepdim=True) / mass # (B, 1)
time_w = (w_pos / (mean_w + eps)).where(mass > 0, torch.ones_like(w_pos)) # (B, L)
else:
raise ValueError(f"Unknown time_weighting: {time_weighting}")
weighted = per_tok * base_mask * time_w
# optional difficulty-based token reweighting (like alpha*(1-exp(-loss))**gamma * loss)
if token_reweighting and gamma != 0.0:
weighted = alpha * (1.0 - torch.exp(-weighted)).pow(gamma) * weighted
elif token_reweighting:
weighted = alpha * weighted
# normalize by effective weight mass (masked * time_w), not just masked count
denom = (base_mask * time_w).sum().clamp_min(1.0)
loss = weighted.sum() / denom
return loss