| | """ |
| | Utility functions for training the Dream model. |
| | |
| | References: https://github.com/zhijie-group/Discrete-Diffusion-Forcing/blob/main/D2F-train/utils/loss.py |
| | """ |
| |
|
| | import functools |
| | import math |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torch import Tensor |
| |
|
| |
|
| | @torch.no_grad() |
| | def get_prompt_lengths_from_labels( |
| | labels: Tensor, |
| | attention_mask: Tensor | None = None, |
| | ignore_index: int = -100, |
| | ) -> Tensor: |
| | """ |
| | labels: (B, T) int64, with ignore_index where we ignore loss (prompt/user/system/pad) |
| | attention_mask: optional (B, T) 1/0; if given, will treat masked-out (0) as non-real tokens |
| | |
| | Returns: (B,) int64 prompt lengths = index of first (labels != ignore_index) per sample. |
| | If a sample has no supervised tokens, length = number of real tokens. |
| | """ |
| | B, T = labels.shape |
| | device = labels.device |
| |
|
| | supervised = labels.ne(ignore_index) |
| | if attention_mask is not None: |
| | supervised = supervised & attention_mask.bool() |
| |
|
| | idx_grid = torch.arange(T, device=device).expand(B, T) |
| | first_idx = torch.where(supervised, idx_grid, torch.full_like(idx_grid, T)).min(dim=1).values |
| |
|
| | if attention_mask is None: |
| | return first_idx |
| |
|
| | real_len = attention_mask.sum(dim=1).to(torch.long) |
| | return torch.where((labels.ne(ignore_index) & attention_mask.bool()).any(dim=1), first_idx, real_len) |
| |
|
| |
|
| | @torch.no_grad() |
| | def simple_uniform_mask( |
| | input_ids: Tensor, |
| | prompt_lengths: Tensor, |
| | mask_id: int, |
| | p: float | None = None, |
| | p_min: float = 0.0, |
| | p_max: float = 1.0, |
| | protect_eos_id: int | None = None, |
| | pad_id: int | None = None, |
| | ensure_at_least_one: bool = True, |
| | eps: float = 1e-6, |
| | ) -> tuple[Tensor, Tensor, Tensor]: |
| | """ |
| | Returns: |
| | noisy: (B, L) int64 — input_ids with some tail tokens replaced by mask_id |
| | masked: (B, L) bool — True where we replaced a token (incurs loss) |
| | p_samples: (B,) float32 — per-sample mask probabilities used |
| | """ |
| | B, L = input_ids.shape |
| | device = input_ids.device |
| |
|
| | noisy = input_ids.clone() |
| | masked = torch.zeros_like(input_ids, dtype=torch.bool) |
| | p_mask_tensor = torch.zeros((B, L), device=device, dtype=torch.float32) |
| |
|
| | |
| | if p is None: |
| | p_samples = torch.rand(B, device=device) * (p_max - p_min) + p_min |
| | else: |
| | p = float(p) |
| | p_samples = torch.full((B,), p, device=device) |
| |
|
| | for i in range(B): |
| | pl = int(prompt_lengths[i].item()) |
| | if pl >= L: |
| | continue |
| |
|
| | |
| | tail_tokens = input_ids[i, pl:L] |
| | elig = torch.ones_like(tail_tokens, dtype=torch.bool) |
| | if pad_id is not None: |
| | elig &= tail_tokens != pad_id |
| | if not elig.any(): |
| | continue |
| |
|
| | |
| | pi = float(torch.clamp(p_samples[i], eps, 1.0 - eps).item()) |
| | randv = torch.rand(elig.shape, device=device) |
| | tail_mask = (randv < pi) & elig |
| |
|
| | |
| | if ensure_at_least_one and not tail_mask.any(): |
| | |
| | idxs = torch.nonzero(elig, as_tuple=False).squeeze(1) |
| | force_idx = idxs[torch.randint(0, len(idxs), (1,), device=device)] |
| | tail_mask[force_idx] = True |
| |
|
| | |
| | noisy[i, pl:L] = torch.where( |
| | tail_mask, |
| | torch.tensor(mask_id, device=device, dtype=noisy.dtype), |
| | tail_tokens, |
| | ) |
| | masked[i, pl:L] = tail_mask |
| | p_mask_tensor[i, pl:L] = torch.where(elig, torch.tensor(pi, device=device), torch.tensor(0.0, device=device)) |
| |
|
| | |
| | if protect_eos_id is not None and (pad_id is None or protect_eos_id != pad_id): |
| | |
| | eos_positions = input_ids[i, :] == protect_eos_id |
| | |
| | if eos_positions.any(): |
| | first_eos_idx = int(torch.argmax(eos_positions.to(torch.uint8)).item()) |
| | else: |
| | first_eos_idx = L |
| |
|
| | |
| | if first_eos_idx < L - 1: |
| | |
| | was_first_eos_masked = False |
| | if first_eos_idx >= pl: |
| | was_first_eos_masked = bool(masked[i, first_eos_idx].item()) |
| | else: |
| | |
| | was_first_eos_masked = False |
| |
|
| | |
| | tail_slice = slice(first_eos_idx, L) |
| |
|
| | if was_first_eos_masked: |
| | |
| | noisy[i, tail_slice] = torch.tensor(mask_id, device=device, dtype=noisy.dtype) |
| | masked[i, tail_slice] = True |
| | |
| | p_mask_tensor[i, tail_slice] = pi |
| | else: |
| | |
| | noisy[i, tail_slice] = torch.tensor(protect_eos_id, device=device, dtype=noisy.dtype) |
| | masked[i, tail_slice] = False |
| | p_mask_tensor[i, tail_slice] = 0.0 |
| |
|
| | return noisy, masked, p_samples |
| |
|
| |
|
| | def _shift_logits(logits: Tensor) -> Tensor: |
| | """ |
| | https://github.com/zhijie-group/Discrete-Diffusion-Forcing/blob/eed9750ab081cdc302daa9d8305478988f3f5a17/D2F-train/utils/util.py#L145C1-L150C26 |
| | """ |
| | shifted_logits = torch.zeros_like(logits) |
| | shifted_logits[:, 1:, :] = logits[:, :-1, :] |
| | shifted_logits[:, 0, :] = 1.0 |
| |
|
| | return shifted_logits |
| |
|
| |
|
| | def _context_adaptive_reweight(seq_len: int, distribution: str = "symmetric-geometric", **kwargs) -> Tensor: |
| | """ |
| | Create context-adaptive reweighting matrix W of shape (seq_len, seq_len) |
| | https://github.com/DreamLM/Dream/blob/fd91b8f1d47c5cbe4a8a1674fd9b98045e79d9db/src/trainer/fsdp_sft_trainer.py#L93 |
| | """ |
| | position_ids_l = np.arange(seq_len).reshape(-1, 1) |
| | position_ids_r = np.arange(seq_len).reshape(1, -1) |
| | distance = position_ids_l - position_ids_r |
| | distance = torch.from_numpy(distance) |
| |
|
| | def geometric_distribution(k, cart_p=0.8, **_): |
| | if not 0 < cart_p <= 1: |
| | raise ValueError("p must be between 0 and 1") |
| | res = (math.log(cart_p) + (k.abs() - 1) * math.log(1 - cart_p)).exp() * 0.5 |
| | res.masked_fill_(k == 0, 0) |
| | return res |
| |
|
| | if distribution == "symmetric-geometric": |
| | matrix = geometric_distribution(distance, **kwargs) |
| | else: |
| | raise ValueError(f"Unknown distribution {distribution}") |
| | return matrix |
| |
|
| |
|
@functools.lru_cache(maxsize=64)
def _cached_cart_matrix(seq_len: int, cart_p: float, distribution: str) -> Tensor:
    """
    Memoized wrapper around _context_adaptive_reweight: returns the
    (seq_len, seq_len) reweighting matrix, cached per
    (seq_len, cart_p, distribution). Callers must not mutate the returned
    tensor in place, since the same object is shared across calls.
    """
    return _context_adaptive_reweight(seq_len, distribution=distribution, cart_p=cart_p)
| |
|
| |
|
def loss_function(
    logits: Tensor,
    labels: Tensor,
    masked: Tensor,
    vocab_size: int,
    *,
    t: Tensor | None = None,
    time_weighting: str = "cart",
    cart_p: float = 0.5,
    cart_distribution: str = "symmetric-geometric",
    token_reweighting: bool = False,
    alpha: float = 1.0,
    gamma: float = 0.0,
    ignore_index: int = -100,
    eps: float = 1e-6,
) -> Tensor:
    """
    Cross-entropy on masked positions with optional time- and token-reweighting.

    Args:
        logits: (B, L, V) model outputs; shifted right by one before the loss
            so logits[:, i-1] predicts labels[:, i].
        labels: (B, L) int64 targets, ignore_index where no loss applies.
        masked: (B, L) bool, True at positions that incur loss.
        vocab_size: V, used to flatten logits for cross-entropy.
        t: optional (B,) timesteps used by the time weighting.
        time_weighting:
            - "none": w_t = 1
            - "original": w_t = 1 / t
            - "linear": w_t = 1 - t
            - "cart": w_t from context-adaptive reweighting matrix
        cart_p, cart_distribution: parameters of the cart weighting matrix.
        token_reweighting: apply focal-style per-token reweighting.
        alpha, gamma: scale / focusing parameters of the focal reweighting.
        ignore_index: label value excluded from the loss.
        eps: numerical floor for divisions.

    Returns:
        Scalar loss. We normalize by the sum of (masked * w_t) so the scale
        stays consistent across weighting schemes.

    Raises:
        ValueError: for an unknown time_weighting.
    """
    _, L, _ = logits.shape
    shifted_logits = _shift_logits(logits)

    # Per-token CE, unreduced; ignore_index positions come back as 0.
    per_tok = F.cross_entropy(
        shifted_logits.view(-1, vocab_size),
        labels.view(-1),
        ignore_index=ignore_index,
        reduction="none",
    ).view_as(labels)

    # Loss applies only at masked, supervised positions.
    base_mask = masked.to(per_tok.dtype)
    if ignore_index is not None:
        base_mask = base_mask * (labels.ne(ignore_index)).to(per_tok.dtype)

    if t is None or time_weighting == "none":
        time_w = torch.ones_like(per_tok)
    else:
        t = t.to(per_tok.device, dtype=per_tok.dtype)
        if time_weighting == "original":
            # Standard discrete-diffusion weighting: upweight small t.
            time_w = (1.0 / t.clamp_min(eps)).view(-1, 1).expand_as(per_tok)
        elif time_weighting == "linear":
            time_w = (1.0 - t).clamp_min(0.0).view(-1, 1).expand_as(per_tok)
        elif time_weighting == "cart":
            # Each position is weighted by how much masked "context mass"
            # sits nearby, then normalized so the mean weight over a row's
            # masked positions is ~1.
            W = _cached_cart_matrix(L, float(cart_p), str(cart_distribution)).to(
                per_tok.device, dtype=per_tok.dtype
            )
            w_pos = base_mask @ W.T
            # clamp_min(1.0) guards the division for rows with no masked
            # positions; their weights are zeroed by base_mask anyway.
            # (A previous `.where(mass > 0, ...)` fallback was unreachable
            # for the same reason and has been removed.)
            mass = base_mask.sum(dim=1, keepdim=True).clamp_min(1.0)
            mean_w = (w_pos * base_mask).sum(dim=1, keepdim=True) / mass
            time_w = w_pos / (mean_w + eps)
        else:
            raise ValueError(f"Unknown time_weighting: {time_weighting}")

    weighted = per_tok * base_mask * time_w

    # Optional focal-style reweighting: downweight easy (low-loss) tokens.
    if token_reweighting and gamma != 0.0:
        weighted = alpha * (1.0 - torch.exp(-weighted)).pow(gamma) * weighted
    elif token_reweighting:
        weighted = alpha * weighted

    # clamp_min keeps the division safe when nothing is masked.
    denom = (base_mask * time_w).sum().clamp_min(1.0)
    return weighted.sum() / denom
| |
|