""" Utility functions for the training Dream model. References: https://github.com/zhijie-group/Discrete-Diffusion-Forcing/blob/main/D2F-train/utils/loss.py """ import functools import math import numpy as np import torch import torch.nn.functional as F from torch import Tensor @torch.no_grad() def get_prompt_lengths_from_labels( labels: Tensor, attention_mask: Tensor | None = None, ignore_index: int = -100, ) -> Tensor: """ labels: (B, T) int64, with ignore_index where we ignore loss (prompt/user/system/pad) attention_mask: optional (B, T) 1/0; if given, will treat masked-out (0) as non-real tokens Returns: (B,) int64 prompt lengths = index of first (labels != ignore_index) per sample. If a sample has no supervised tokens, length = number of real tokens. """ B, T = labels.shape device = labels.device supervised = labels.ne(ignore_index) if attention_mask is not None: supervised = supervised & attention_mask.bool() idx_grid = torch.arange(T, device=device).expand(B, T) first_idx = torch.where(supervised, idx_grid, torch.full_like(idx_grid, T)).min(dim=1).values if attention_mask is None: return first_idx real_len = attention_mask.sum(dim=1).to(torch.long) return torch.where((labels.ne(ignore_index) & attention_mask.bool()).any(dim=1), first_idx, real_len) @torch.no_grad() def simple_uniform_mask( input_ids: Tensor, # (B, L), int64 prompt_lengths: Tensor, # (B,), int64 — number of tokens to keep unmasked at the left mask_id: int, # token id to write where masked p: float | None = None, # fixed mask rate; if None, sample per-sample in [p_min, p_max] p_min: float = 0.0, p_max: float = 1.0, protect_eos_id: int | None = None, # treated as EOS id for the tail rule pad_id: int | None = None, ensure_at_least_one: bool = True, eps: float = 1e-6, # tiny floor for probabilities ) -> tuple[Tensor, Tensor, Tensor]: """ Returns: noisy: (B, L) int64 — input_ids with some tail tokens replaced by mask_id masked: (B, L) bool — True where we replaced a token (incurs loss) p_samples: (B,) float32 — per-sample mask probabilities used """ B, L = input_ids.shape device = input_ids.device noisy = input_ids.clone() masked = torch.zeros_like(input_ids, dtype=torch.bool) p_mask_tensor = torch.zeros((B, L), device=device, dtype=torch.float32) # choose per-sample p if p is None: p_samples = torch.rand(B, device=device) * (p_max - p_min) + p_min else: p = float(p) p_samples = torch.full((B,), p, device=device) for i in range(B): pl = int(prompt_lengths[i].item()) if pl >= L: continue # nothing to mask # ---- Eligible region: [pl, L). Exclude PAD only here. Do NOT exclude EOS now. ---- tail_tokens = input_ids[i, pl:L] elig = torch.ones_like(tail_tokens, dtype=torch.bool) if pad_id is not None: elig &= tail_tokens != pad_id if not elig.any(): continue # i.i.d. Bernoulli with per-sample prob pi = float(torch.clamp(p_samples[i], eps, 1.0 - eps).item()) randv = torch.rand(elig.shape, device=device) tail_mask = (randv < pi) & elig # optionally guarantee at least one masked token per sample if ensure_at_least_one and not tail_mask.any(): # pick a random eligible index to force-mask idxs = torch.nonzero(elig, as_tuple=False).squeeze(1) force_idx = idxs[torch.randint(0, len(idxs), (1,), device=device)] tail_mask[force_idx] = True # provisional write-back BEFORE EOS rule noisy[i, pl:L] = torch.where( tail_mask, torch.tensor(mask_id, device=device, dtype=noisy.dtype), tail_tokens, ) masked[i, pl:L] = tail_mask p_mask_tensor[i, pl:L] = torch.where(elig, torch.tensor(pi, device=device), torch.tensor(0.0, device=device)) # ---- EOS tail rule (apply only if EOS is distinct from PAD) ---- if protect_eos_id is not None and (pad_id is None or protect_eos_id != pad_id): # Find first EOS at/after prompt eos_positions = input_ids[i, :] == protect_eos_id # First EOS index in the entire sequence if eos_positions.any(): first_eos_idx = int(torch.argmax(eos_positions.to(torch.uint8)).item()) else: first_eos_idx = L # no EOS # Tail exists only if EOS is not the last token if first_eos_idx < L - 1: # Check whether that first EOS was masked was_first_eos_masked = False if first_eos_idx >= pl: was_first_eos_masked = bool(masked[i, first_eos_idx].item()) else: # EOS lies inside the prompt region; it couldn't be masked by the sampling was_first_eos_masked = False # Build tail slice [first_eos_idx, L) tail_slice = slice(first_eos_idx, L) if was_first_eos_masked: # Case A: mask entire EOS tail; loss applies there noisy[i, tail_slice] = torch.tensor(mask_id, device=device, dtype=noisy.dtype) masked[i, tail_slice] = True # For consistency, set per-token prob on the tail to pi where we forced masking p_mask_tensor[i, tail_slice] = pi else: # Case B: force EOS on the tail; no loss there noisy[i, tail_slice] = torch.tensor(protect_eos_id, device=device, dtype=noisy.dtype) masked[i, tail_slice] = False p_mask_tensor[i, tail_slice] = 0.0 return noisy, masked, p_samples def _shift_logits(logits: Tensor) -> Tensor: """ https://github.com/zhijie-group/Discrete-Diffusion-Forcing/blob/eed9750ab081cdc302daa9d8305478988f3f5a17/D2F-train/utils/util.py#L145C1-L150C26 """ shifted_logits = torch.zeros_like(logits) shifted_logits[:, 1:, :] = logits[:, :-1, :] shifted_logits[:, 0, :] = 1.0 return shifted_logits def _context_adaptive_reweight(seq_len: int, distribution: str = "symmetric-geometric", **kwargs) -> Tensor: """ Create context-adaptive reweighting matrix W of shape (seq_len, seq_len) https://github.com/DreamLM/Dream/blob/fd91b8f1d47c5cbe4a8a1674fd9b98045e79d9db/src/trainer/fsdp_sft_trainer.py#L93 """ position_ids_l = np.arange(seq_len).reshape(-1, 1) position_ids_r = np.arange(seq_len).reshape(1, -1) distance = position_ids_l - position_ids_r distance = torch.from_numpy(distance) def geometric_distribution(k, cart_p=0.8, **_): if not 0 < cart_p <= 1: raise ValueError("p must be between 0 and 1") res = (math.log(cart_p) + (k.abs() - 1) * math.log(1 - cart_p)).exp() * 0.5 res.masked_fill_(k == 0, 0) return res if distribution == "symmetric-geometric": matrix = geometric_distribution(distance, **kwargs) else: raise ValueError(f"Unknown distribution {distribution}") return matrix @functools.lru_cache(maxsize=64) def _cached_cart_matrix(seq_len: int, cart_p: float, distribution: str) -> Tensor: """ Get cached context-adaptive reweighting matrix W of shape (seq_len, seq_len) """ W = _context_adaptive_reweight(seq_len, distribution=distribution, cart_p=cart_p) return W # CPU float tensor; we'll .to(device,dtype) at use time def loss_function( logits: Tensor, # (B, L, V) labels: Tensor, # (B, L) masked: Tensor, # (B, L) bool or float; True/1.0 => include in loss vocab_size: int, *, t: Tensor | None = None, # (B,) in [0,1], per-sample time time_weighting: str = "cart", # "none" | "original" | "linear" | "cart" cart_p: float = 0.5, # for cart time weighting cart_distribution: str = "symmetric-geometric", # for cart time weighting token_reweighting: bool = False, # optional difficulty weighting alpha: float = 1.0, gamma: float = 0.0, ignore_index: int = -100, eps: float = 1e-6, ) -> Tensor: """ Cross-entropy on masked positions with optional time- and token-reweighting. time_weighting: - "none": w_t = 1 - "original": w_t = 1 / t - "linear": w_t = 1 - t - "cart": w_t from context-adaptive reweighting matrix We normalize by the sum of (masked * w_t) so the scale stays consistent. """ B, L, _ = logits.shape shifted_logits = _shift_logits(logits) # (B, L, V) # per-token CE without reduction per_tok = F.cross_entropy( shifted_logits.view(-1, vocab_size), labels.view(-1), ignore_index=ignore_index, reduction="none", ).view_as(labels) # (B, L) # base mask: include only selected tokens and not ignore_index base_mask = masked.to(per_tok.dtype) # (B, L) if ignore_index is not None: base_mask = base_mask * (labels.ne(ignore_index)).to(per_tok.dtype) # time weights (per-sample -> per-token broadcast) if t is None or time_weighting == "none": w_t = 1.0 time_w = torch.ones_like(per_tok) else: t = t.to(per_tok.device, dtype=per_tok.dtype) if time_weighting == "original": w_t = 1.0 / t.clamp_min(eps) # upweight small t (early timesteps) time_w = w_t.view(-1, 1).expand_as(per_tok) # (B, L) elif time_weighting == "linear": w_t = (1.0 - t).clamp_min(0.0) # downweight large t time_w = w_t.view(-1, 1).expand_as(per_tok) # (B, L) elif time_weighting == "cart": W = _cached_cart_matrix(L, float(cart_p), str(cart_distribution)).to( per_tok.device, dtype=per_tok.dtype ) # (L, L) w_pos = base_mask @ W.T # (B, L) @ (L, L) -> (B, L) # normalize so mean weight over included tokens is 1 (stable scale) mass = base_mask.sum(dim=1, keepdim=True).clamp_min(1.0) # (B, 1) mean_w = (w_pos * base_mask).sum(dim=1, keepdim=True) / mass # (B, 1) time_w = (w_pos / (mean_w + eps)).where(mass > 0, torch.ones_like(w_pos)) # (B, L) else: raise ValueError(f"Unknown time_weighting: {time_weighting}") weighted = per_tok * base_mask * time_w # optional difficulty-based token reweighting (like alpha*(1-exp(-loss))**gamma * loss) if token_reweighting and gamma != 0.0: weighted = alpha * (1.0 - torch.exp(-weighted)).pow(gamma) * weighted elif token_reweighting: weighted = alpha * weighted # normalize by effective weight mass (masked * time_w), not just masked count denom = (base_mask * time_w).sum().clamp_min(1.0) loss = weighted.sum() / denom return loss