| """ |
| Utility functions for the training Dream model. |
| |
| References: https://github.com/zhijie-group/Discrete-Diffusion-Forcing/blob/main/D2F-train/utils/loss.py |
| """ |
|
|
| import functools |
| import math |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor |
|
|
|
|
| @torch.no_grad() |
| def get_prompt_lengths_from_labels( |
| labels: Tensor, |
| attention_mask: Tensor | None = None, |
| ignore_index: int = -100, |
| ) -> Tensor: |
| """ |
| labels: (B, T) int64, with ignore_index where we ignore loss (prompt/user/system/pad) |
| attention_mask: optional (B, T) 1/0; if given, will treat masked-out (0) as non-real tokens |
| |
| Returns: (B,) int64 prompt lengths = index of first (labels != ignore_index) per sample. |
| If a sample has no supervised tokens, length = number of real tokens. |
| """ |
| B, T = labels.shape |
| device = labels.device |
|
|
| supervised = labels.ne(ignore_index) |
| if attention_mask is not None: |
| supervised = supervised & attention_mask.bool() |
|
|
| idx_grid = torch.arange(T, device=device).expand(B, T) |
| first_idx = torch.where(supervised, idx_grid, torch.full_like(idx_grid, T)).min(dim=1).values |
|
|
| if attention_mask is None: |
| return first_idx |
|
|
| real_len = attention_mask.sum(dim=1).to(torch.long) |
| return torch.where((labels.ne(ignore_index) & attention_mask.bool()).any(dim=1), first_idx, real_len) |
|
|
|
|
@torch.no_grad()
def simple_uniform_mask(
    input_ids: Tensor,
    prompt_lengths: Tensor,
    mask_id: int,
    p: float | None = None,
    p_min: float = 0.0,
    p_max: float = 1.0,
    protect_eos_id: int | None = None,
    pad_id: int | None = None,
    ensure_at_least_one: bool = True,
    eps: float = 1e-6,
) -> tuple[Tensor, Tensor, Tensor]:
    """
    Uniformly mask response tokens (positions at/after the prompt length)
    with a per-sample probability, optionally keeping everything after the
    first EOS token consistent.

    Args:
        input_ids: (B, L) int64 token ids.
        prompt_lengths: (B,) per-sample prompt length; positions before it
            are never masked.
        mask_id: token id used to replace masked positions.
        p: fixed mask probability for all samples; if None, one probability
            per sample is drawn uniformly from [p_min, p_max].
        p_min / p_max: range for the per-sample probability draw.
        protect_eos_id: if set (and distinct from pad_id), everything after
            the first occurrence of this id is forced to a consistent state
            (see EOS-protection block below).
        pad_id: if set, padding tokens in the tail are never masked.
        ensure_at_least_one: force at least one masked token per sample so
            every sample contributes loss.
        eps: clamp margin keeping the applied probability inside (0, 1).

    Returns:
        noisy: (B, L) int64 — input_ids with some tail tokens replaced by mask_id
        masked: (B, L) bool — True where we replaced a token (incurs loss)
        p_samples: (B,) float32 — per-sample mask probabilities used
            (as drawn; the value actually applied is clamped to [eps, 1-eps])
    """
    B, L = input_ids.shape
    device = input_ids.device

    noisy = input_ids.clone()
    masked = torch.zeros_like(input_ids, dtype=torch.bool)
    # Per-position probability actually applied. NOTE(review): filled in
    # below but never returned — dead output unless a caller is expected to
    # be added; confirm intent.
    p_mask_tensor = torch.zeros((B, L), device=device, dtype=torch.float32)

    # One mask probability per sample: uniform in [p_min, p_max], or the
    # fixed p when provided.
    if p is None:
        p_samples = torch.rand(B, device=device) * (p_max - p_min) + p_min
    else:
        p = float(p)
        p_samples = torch.full((B,), p, device=device)

    for i in range(B):
        pl = int(prompt_lengths[i].item())
        if pl >= L:
            # Entire sequence is prompt: nothing to mask for this sample.
            continue

        # Eligible tokens are the response tail, minus padding if pad_id is set.
        tail_tokens = input_ids[i, pl:L]
        elig = torch.ones_like(tail_tokens, dtype=torch.bool)
        if pad_id is not None:
            elig &= tail_tokens != pad_id
        if not elig.any():
            continue

        # Clamp the applied probability away from 0 and 1 so masking is
        # neither impossible nor certain.
        pi = float(torch.clamp(p_samples[i], eps, 1.0 - eps).item())
        randv = torch.rand(elig.shape, device=device)
        tail_mask = (randv < pi) & elig

        # Guarantee at least one masked token so the sample incurs loss.
        if ensure_at_least_one and not tail_mask.any():
            idxs = torch.nonzero(elig, as_tuple=False).squeeze(1)
            force_idx = idxs[torch.randint(0, len(idxs), (1,), device=device)]
            tail_mask[force_idx] = True

        # Write the masked tail back; non-masked positions keep their token.
        noisy[i, pl:L] = torch.where(
            tail_mask,
            torch.tensor(mask_id, device=device, dtype=noisy.dtype),
            tail_tokens,
        )
        masked[i, pl:L] = tail_mask
        p_mask_tensor[i, pl:L] = torch.where(elig, torch.tensor(pi, device=device), torch.tensor(0.0, device=device))

        # EOS protection: force a consistent state from the first EOS onward.
        # Skipped when EOS doubles as the pad token.
        if protect_eos_id is not None and (pad_id is None or protect_eos_id != pad_id):
            # Locate the first EOS anywhere in the sequence (prompt included).
            eos_positions = input_ids[i, :] == protect_eos_id
            # argmax on uint8 returns the index of the first True.
            if eos_positions.any():
                first_eos_idx = int(torch.argmax(eos_positions.to(torch.uint8)).item())
            else:
                first_eos_idx = L

            # Only act when at least one token follows the first EOS.
            if first_eos_idx < L - 1:
                # Was the first EOS itself masked above? Only possible when it
                # lies in the response tail.
                was_first_eos_masked = False
                if first_eos_idx >= pl:
                    was_first_eos_masked = bool(masked[i, first_eos_idx].item())
                else:
                    # EOS inside the prompt is never masked.
                    was_first_eos_masked = False

                # Everything from the first EOS to the end, inclusive.
                tail_slice = slice(first_eos_idx, L)

                if was_first_eos_masked:
                    # EOS hidden: hide the whole tail so the model must predict
                    # where the sequence ends; all of it incurs loss.
                    noisy[i, tail_slice] = torch.tensor(mask_id, device=device, dtype=noisy.dtype)
                    masked[i, tail_slice] = True
                    # pi is the clamped probability computed earlier this iteration.
                    p_mask_tensor[i, tail_slice] = pi
                else:
                    # EOS visible: overwrite the tail with EOS and exclude it
                    # from the loss. NOTE(review): this also overwrites any pad
                    # tokens after the EOS with protect_eos_id — confirm intended.
                    noisy[i, tail_slice] = torch.tensor(protect_eos_id, device=device, dtype=noisy.dtype)
                    masked[i, tail_slice] = False
                    p_mask_tensor[i, tail_slice] = 0.0

    return noisy, masked, p_samples
|
|
|
|
| def _shift_logits(logits: Tensor) -> Tensor: |
| """ |
| https://github.com/zhijie-group/Discrete-Diffusion-Forcing/blob/eed9750ab081cdc302daa9d8305478988f3f5a17/D2F-train/utils/util.py#L145C1-L150C26 |
| """ |
| shifted_logits = torch.zeros_like(logits) |
| shifted_logits[:, 1:, :] = logits[:, :-1, :] |
| shifted_logits[:, 0, :] = 1.0 |
|
|
| return shifted_logits |
|
|
|
|
| def _context_adaptive_reweight(seq_len: int, distribution: str = "symmetric-geometric", **kwargs) -> Tensor: |
| """ |
| Create context-adaptive reweighting matrix W of shape (seq_len, seq_len) |
| https://github.com/DreamLM/Dream/blob/fd91b8f1d47c5cbe4a8a1674fd9b98045e79d9db/src/trainer/fsdp_sft_trainer.py#L93 |
| """ |
| position_ids_l = np.arange(seq_len).reshape(-1, 1) |
| position_ids_r = np.arange(seq_len).reshape(1, -1) |
| distance = position_ids_l - position_ids_r |
| distance = torch.from_numpy(distance) |
|
|
| def geometric_distribution(k, cart_p=0.8, **_): |
| if not 0 < cart_p <= 1: |
| raise ValueError("p must be between 0 and 1") |
| res = (math.log(cart_p) + (k.abs() - 1) * math.log(1 - cart_p)).exp() * 0.5 |
| res.masked_fill_(k == 0, 0) |
| return res |
|
|
| if distribution == "symmetric-geometric": |
| matrix = geometric_distribution(distance, **kwargs) |
| else: |
| raise ValueError(f"Unknown distribution {distribution}") |
| return matrix |
|
|
|
|
@functools.lru_cache(maxsize=64)
def _cached_cart_matrix(seq_len: int, cart_p: float, distribution: str) -> Tensor:
    """
    Memoized accessor for the (seq_len, seq_len) context-adaptive reweighting
    matrix. Cache hits return the same tensor object, so callers should treat
    the result as read-only.
    """
    return _context_adaptive_reweight(seq_len, distribution=distribution, cart_p=cart_p)
|
|
|
|
def loss_function(
    logits: Tensor,
    labels: Tensor,
    masked: Tensor,
    vocab_size: int,
    *,
    t: Tensor | None = None,
    time_weighting: str = "cart",
    cart_p: float = 0.5,
    cart_distribution: str = "symmetric-geometric",
    token_reweighting: bool = False,
    alpha: float = 1.0,
    gamma: float = 0.0,
    ignore_index: int = -100,
    eps: float = 1e-6,
) -> Tensor:
    """
    Cross-entropy on masked positions with optional time- and token-reweighting.

    Args:
        logits: (B, L, V) raw model outputs; shifted one step right internally
            so the prediction for token i aligns with label i.
        labels: (B, L) target ids with ``ignore_index`` at unsupervised positions.
        masked: (B, L) bool, True where a token was replaced and incurs loss.
        vocab_size: V, used to flatten logits for cross-entropy.
        t: optional (B,) timesteps; required by "original" and "linear" weighting.
        time_weighting:
            - "none": w_t = 1
            - "original": w_t = 1 / t
            - "linear": w_t = 1 - t
            - "cart": w_t from context-adaptive reweighting matrix
        cart_p / cart_distribution: parameters of the cached CART matrix.
        token_reweighting: if True, apply focal-style per-token reweighting
            controlled by ``alpha`` and ``gamma``.
        ignore_index: label value excluded from the loss.
        eps: numerical floor for divisions.

    Returns:
        Scalar loss. We normalize by the sum of (masked * w_t) so the scale
        stays consistent across weighting schemes.

    Raises:
        ValueError: for an unknown ``time_weighting`` (when ``t`` is given).
    """
    _, L, _ = logits.shape
    shifted_logits = _shift_logits(logits)

    # Per-token CE; positions equal to ignore_index come back as exactly 0.
    per_tok = F.cross_entropy(
        shifted_logits.view(-1, vocab_size),
        labels.view(-1),
        ignore_index=ignore_index,
        reduction="none",
    ).view_as(labels)

    # Loss only where a token was masked AND is supervised.
    base_mask = masked.to(per_tok.dtype)
    if ignore_index is not None:
        base_mask = base_mask * labels.ne(ignore_index).to(per_tok.dtype)

    if t is None or time_weighting == "none":
        # (Removed a dead `w_t = 1.0` assignment here — it was never read.)
        time_w = torch.ones_like(per_tok)
    else:
        t = t.to(per_tok.device, dtype=per_tok.dtype)
        if time_weighting == "original":
            w_t = 1.0 / t.clamp_min(eps)
            time_w = w_t.view(-1, 1).expand_as(per_tok)
        elif time_weighting == "linear":
            w_t = (1.0 - t).clamp_min(0.0)
            time_w = w_t.view(-1, 1).expand_as(per_tok)
        elif time_weighting == "cart":
            W = _cached_cart_matrix(L, float(cart_p), str(cart_distribution)).to(
                per_tok.device, dtype=per_tok.dtype
            )
            # Each masked position is weighted by its CART proximity to the
            # other masked positions of the same sample.
            w_pos = base_mask @ W.T
            # Normalize so the mean weight over masked positions is ~1.
            # Fix: the old guard compared the *clamped* mass against 0, which
            # is always true; the fallback-to-ones branch was unreachable.
            # Guard on the raw masked count instead (loss is unchanged: rows
            # with no masked tokens have base_mask == 0 everywhere, so they
            # contribute 0 to both numerator and denominator either way).
            n_masked = base_mask.sum(dim=1, keepdim=True)
            mean_w = (w_pos * base_mask).sum(dim=1, keepdim=True) / n_masked.clamp_min(1.0)
            time_w = torch.where(n_masked > 0, w_pos / (mean_w + eps), torch.ones_like(w_pos))
        else:
            raise ValueError(f"Unknown time_weighting: {time_weighting}")

    weighted = per_tok * base_mask * time_w

    # Optional focal-style reweighting: (1 - e^{-loss})^gamma down-weights
    # easy (low-loss) tokens when gamma > 0.
    if token_reweighting and gamma != 0.0:
        weighted = alpha * (1.0 - torch.exp(-weighted)).pow(gamma) * weighted
    elif token_reweighting:
        weighted = alpha * weighted

    # Normalize by total effective weight; clamp keeps a 0-mask batch finite.
    denom = (base_mask * time_w).sum().clamp_min(1.0)
    loss = weighted.sum() / denom
    return loss
|
|