"""
Noisy state builder for block-AR SAD training.

Noising levels per token, with L = number of ancestor levels
(AncestorTable.num_levels):
    0      = clean (keep leaf embedding)
    1..L   = ancestor level l (sample from LUT, use learnable ancestor embedding)
    L + 1  = mask token

Levels are sampled per token, either uniformly
(NoisyStateBuilder.sample_levels_uniform) or from the HDLM schedule
(NoisyStateBuilder.sample_levels_hdlm).

AncestorTable provides:
    - Fixed LUT (indices + probs): which ancestor cluster each token maps to
    - Learnable ancestor embeddings: the actual embedding used as noisy input
"""

from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from src.diffusion.ancestor_table import AncestorTable


class NoisyStateBuilder(nn.Module):
    """
    Builds noisy embeddings for vectorized block-AR training.

    Level convention used by every method of this class, with
    L = number of ancestor levels:
        0      = clean (keep the leaf embedding)
        1..L   = ancestor level l
        L + 1  = mask

    Args:
        vocab_size: V
        mask_token_id: id of [MASK] in the leaf vocabulary
        temp: kept for API compat, unused (temp is baked into LUT probs)
        top_k_per_level: kept for API compat, unused
        use_soft_expected: accepted for API compat; never read by this class
    """

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        temp: float = 1.0,
        top_k_per_level: Optional[List[int]] = None,
        use_soft_expected: bool = True,
    ):
        super().__init__()
        # Only these two values are stored; the remaining arguments are
        # accepted so existing call sites keep working.
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id

    def sample_levels_uniform(
        self, B: int, S: int, num_total_states: int, device: torch.device
    ) -> torch.Tensor:
        """
        Sample per-token levels uniformly from {0, 1, ..., num_total_states-1}.

        To cover the clean state, every ancestor level and the mask state,
        pass num_total_states = num_ancestor_levels + 2, i.e. mask_level + 1
        where mask_level = num_ancestor_levels + 1.

        Args:
            B: batch size
            S: sequence length
            num_total_states: exclusive upper bound of the sampled levels
            device: device on which the result is created

        Returns:
            levels: [B, S] int64
        """
        return torch.randint(0, num_total_states, (B, S), device=device)

    @staticmethod
    def sample_t(
        B: int,
        device: torch.device,
        eps: float = 1e-3,
        low_discrepancy: bool = False,
        rank: int = 0,
        world_size: int = 1,
    ) -> torch.Tensor:
        """
        Sample one noise level t per sequence (shape [B]) in [eps, 1-eps).

        Change 3 (MDLM): low-discrepancy stratified sampling. A single phase
        u is drawn per call and the batch covers the grid {u, u+1/B, ...}
        (mod 1). In DDP the phase is additionally shifted by
        rank/(world_size*B) so that the ranks together cover finer strata
        (analogous to Kingma's VDM).
        NOTE(review): u is drawn from the local RNG, so the finer joint
        coverage presumably relies on every rank drawing the same u —
        confirm against the training setup.

        Change 2 (BD3-LMs / Soft-Masked Diffusion): the sample is affinely
        rescaled from [0, 1) into [eps, 1-eps) to avoid high-variance ELBO
        gradients at extreme mask rates (t→0 or t→1).

        Args:
            B: number of values to draw
            device: device on which the result is created
            eps: margin kept away from 0 and 1
            low_discrepancy: use the stratified grid instead of i.i.d. draws
            rank: this process' rank (only used when world_size > 1)
            world_size: number of DDP processes

        Returns:
            t: [B] float tensor
        """
        if low_discrepancy:
            u = torch.rand((), device=device)
            # B > 0 keeps an empty batch from dividing by zero; the draw of
            # u above is left in place so the RNG stream is unchanged.
            if world_size > 1 and B > 0:
                u = (u + rank / (world_size * B)) % 1.0
            grid = torch.arange(B, device=device, dtype=torch.float32) / B
            t = (grid + u) % 1.0
        else:
            t = torch.rand(B, device=device)

        return (1 - 2 * eps) * t + eps

    def sample_levels_hdlm(
        self,
        t: torch.Tensor,
        S: int,
        num_ancestor_levels: int,
        gamma: float = 1.0,
    ) -> torch.Tensor:
        """
        HDLM 3-state (clean / ancestor / mask) hierarchical schedule,
        γ=1 form from HierarchicalDiffusion.get_alpha_betapi:

            α_t = 1 - t                     # P(clean, keep original)
            c_t = -(1-t) * log(1-t)         # P(ancestor state)
            m_t = t + (1-t) * log(1-t)      # P(mask)

        Analytically the three terms sum to 1; they are renormalized after
        clamping to absorb numerical error.

        Each sequence has its own t (shape [B]); each token within the
        sequence is drawn i.i.d. from the resulting categorical. The ancestor
        mass c_t is split evenly among levels 1..L.

        Args:
            t: [B] noise levels, e.g. from sample_t()
            S: sequence length
            num_ancestor_levels: L, must be >= 1
            gamma: only γ=1 supported

        Returns:
            levels: [B, S] int64, values in {0, 1..L, L+1=mask_level}

        Raises:
            NotImplementedError: if gamma != 1.0
            ValueError: if num_ancestor_levels < 1
        """
        if gamma != 1.0:
            raise NotImplementedError("sample_levels_hdlm only supports γ=1")
        # Explicit raise rather than `assert`, which is removed under -O.
        if num_ancestor_levels < 1:
            raise ValueError(
                f"num_ancestor_levels must be >= 1, got {num_ancestor_levels}"
            )

        B = t.shape[0]
        eps = 1e-8

        # Clamp 1 - t away from zero so the log stays finite at t == 1.
        one_m_t = (1.0 - t).clamp(min=eps)
        log_1m_t = one_m_t.log()
        alpha_t = one_m_t
        c_t = -one_m_t * log_1m_t
        # Rounding can push the mask mass slightly below zero near t == 0.
        m_t = (t + one_m_t * log_1m_t).clamp(min=0.0)

        # Renormalize so the three states sum to one after clamping.
        total = alpha_t + c_t + m_t
        alpha_t = alpha_t / total
        c_t = c_t / total
        m_t = m_t / total

        # Spread the ancestor mass evenly over the L ancestor levels.
        c_per_level = c_t / num_ancestor_levels

        # One row per sequence: [clean, level 1, ..., level L, mask].
        probs = torch.stack(
            [alpha_t] + [c_per_level] * num_ancestor_levels + [m_t],
            dim=-1,
        )

        # Repeat each sequence's distribution for all S tokens and draw one
        # category per (sequence, token) row.
        probs_flat = probs.unsqueeze(1).expand(B, S, -1).reshape(B * S, -1)
        sampled = torch.multinomial(probs_flat, num_samples=1).squeeze(-1)
        return sampled.view(B, S).long()

    def build_noisy_embeddings(
        self,
        input_ids: torch.Tensor,
        levels: torch.Tensor,
        ancestor_table: "AncestorTable",
        leaf_embeddings: torch.Tensor,
        mask_embedding: torch.Tensor,
        hierarchy=None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
        """
        Build noisy embeddings from LUT-based ancestor table.

        For each position:
            level 0:     leaf_embeddings[token_id]  (clean)
            level 1..L:  learnable ancestor embedding sampled via LUT
            level L+1:   mask_embedding             (fully masked)
        where L = ancestor_table.num_levels. A position whose level is
        outside 0..L+1 is not touched: it keeps its leaf embedding and is
        reported as not corrupted.

        Args:
            input_ids:       [B, S] int64
            levels:          [B, S] int64 (0 .. num_ancestor_levels+1)
            ancestor_table:  AncestorTable; this method reads `num_levels`
                             and calls `lut_indices(l)`, `lut_probs(l)` and
                             `ancestor_embeddings(l)`
            leaf_embeddings: [V, d]
            mask_embedding:  [d]
            hierarchy:       ignored (kept for API compat)

        Returns:
            noisy_embs:             [B, S, d]
            ancestor_log_probs:     [B, S] log-prob of chosen ancestor (0 at clean/mask)
            ancestor_probs_per_lvl: list with one None per ancestor level
            corrupt_mask:           [B, S] bool – True at positions level >= 1
        """
        B, S = input_ids.shape
        dtype = leaf_embeddings.dtype
        device = input_ids.device

        mask_level = ancestor_table.num_levels + 1

        # Start from the clean embeddings; clone so the in-place writes
        # below do not write into the gathered tensor's source.
        noisy_embs = leaf_embeddings[input_ids].clone()

        ancestor_log_probs = torch.zeros(B, S, device=device, dtype=dtype)
        corrupt_mask = torch.zeros(B, S, dtype=torch.bool, device=device)
        ancestor_probs_per_lvl: List[Optional[torch.Tensor]] = []

        # Fully masked positions.
        mask_pos = levels == mask_level
        if mask_pos.any():
            noisy_embs[mask_pos] = mask_embedding.to(dtype)
            corrupt_mask[mask_pos] = True

        # Ancestor levels 1..L.
        for level in range(1, mask_level):
            # Exactly one placeholder per level, whether or not it is used.
            ancestor_probs_per_lvl.append(None)

            pos_l = levels == level
            if not pos_l.any():
                continue

            flat_ids = input_ids[pos_l]
            rows = torch.arange(flat_ids.shape[0], device=device)

            # Per-token candidate ancestors and their probabilities.
            lut_idx = ancestor_table.lut_indices(level)[flat_ids]
            lut_prob = ancestor_table.lut_probs(level)[flat_ids]

            # Draw one LUT slot per token and map it to the ancestor id.
            sampled_local = torch.multinomial(lut_prob, num_samples=1).squeeze(1)
            sampled_global = lut_idx[rows, sampled_local]

            anc_emb = ancestor_table.ancestor_embeddings(level)
            noisy_embs[pos_l] = anc_emb[sampled_global].to(dtype)

            # Log-probability of the drawn slot; the clamp keeps log finite.
            chosen_lp = lut_prob[rows, sampled_local].clamp(min=1e-8).log()
            ancestor_log_probs[pos_l] = chosen_lp.to(dtype)

            corrupt_mask[pos_l] = True

        return noisy_embs, ancestor_log_probs, ancestor_probs_per_lvl, corrupt_mask